Forráskód Böngészése

feat: 添加模型评估工具集和分类器界面

refactor: 重构推理模块输出目录结构
docs: 添加评估工具文档和README
chore: 更新依赖项和.gitignore
Hannnk 1 hete
szülő
commit
f39e0c94dc

+ 1 - 0
.gitignore

@@ -34,6 +34,7 @@ env/
 # Project specific
 Output/
 models/
+evaluation_results/
 Data
 *.onnx
 *.pkl

+ 538 - 0
README_evaluation.md

@@ -0,0 +1,538 @@
+# 模型性能评估工具
+
+这是一个用于评估UAV和鸟类检测模型性能的完整工具集,包括误报率、漏报率、分类准确率和推理速度等关键指标的评估。
+
+## 文件说明
+
+### 核心文件
+
+- **`model_evaluator.py`** - 主要的模型评估类,包含所有评估逻辑
+- **`run_evaluation.py`** - 简化的运行脚本,提供命令行接口
+- **`evaluation_config.json`** - 配置文件,可自定义评估参数
+- **`generate_test_data.py`** - 测试数据生成器(当没有真实数据时使用)
+- **`example_balanced_evaluation.py`** - 平衡采样评估示例脚本
+
+### 功能特性
+
+✅ **多种评估指标**
+- 混淆矩阵(TP, FP, FN, TN)
+- 精确率、召回率、F1分数
+- 准确率、误报率、漏报率
+- 按类别的详细指标
+
+✅ **推理速度评估**
+- 平均推理时间
+- FPS(每秒处理帧数)
+- 总处理时间统计
+
+✅ **灵活的数据加载**
+- 支持文件夹结构(BIRDS/UAV/OTHERS)
+- 支持标注文件格式
+- 可配置采样大小
+
+✅ **可视化和报告**
+- 自动生成详细的评估报告
+- 混淆矩阵热图
+- 性能指标图表
+
+✅ **多模型支持**
+- uav_and_bird
+- Anti_UAV
+- UAV-250411
+
+## 快速开始
+
+### 1. 环境准备
+
+确保已安装必要的依赖:
+
+```bash
+pip install opencv-python numpy matplotlib seaborn pandas tqdm
+```
+
+### 2. 数据准备
+
+将测试数据按以下结构组织:
+```
+data/
+├── BIRDS/          # 鸟类图像
+├── UAV/            # 无人机图像
+└── OTHERS/         # 其他类型图像(背景等)
+```
+
+### 3. 运行评估
+
+**基本用法:**
+```bash
+python run_evaluation.py --data-dir ./data
+```
+
+**平衡采样评估(推荐):**
+```bash
+# 使用2:2:6比例进行平衡采样
+python run_evaluation.py \
+    --data-dir ./data \
+    --sample-size 1000 \
+    --balanced-sampling \
+    --bird-ratio 0.2 \
+    --uav-ratio 0.2 \
+    --others-ratio 0.6
+```
+
+**自定义参数:**
+```bash
+python run_evaluation.py \
+    --model-type uav_and_bird \
+    --threshold 0.6 \
+    --data-dir ./test_data \
+    --sample-size 1000 \
+    --output-dir ./results
+```
+
+**完整参数列表:**
+```bash
+python run_evaluation.py \
+    --model-type uav_and_bird \
+    --threshold 0.5 \
+    --data-dir ./data \
+    --sample-size 1000 \
+    --balanced-sampling \
+    --bird-ratio 0.2 \
+    --uav-ratio 0.2 \
+    --others-ratio 0.6 \
+    --output-dir ./results \
+    --config ./evaluation_config.json \
+    --generate-data
+```
+
+**参数说明:**
+- `--model-type`: 模型类型 (uav_and_bird, Anti_UAV, UAV-250411)
+- `--threshold`: 置信度阈值 (0.0-1.0)
+- `--data-dir`: 测试数据目录
+- `--sample-size`: 采样大小,None表示使用全部数据
+- `--balanced-sampling`: 启用平衡采样
+- `--bird-ratio`: 鸟类样本比例 (0.0-1.0)
+- `--uav-ratio`: 无人机样本比例 (0.0-1.0)
+- `--others-ratio`: 其他样本比例 (0.0-1.0)
+- `--output-dir`: 输出目录
+- `--config`: 配置文件路径
+- `--generate-data`: 强制重新生成测试数据
+
+**运行示例脚本:**
+```bash
+# 运行平衡采样示例
+python example_balanced_evaluation.py
+```
+
+#### 完整参数列表
+
+```bash
+python run_evaluation.py --help
+```
+
+参数说明:
+- `--config, -c`: 配置文件路径(默认:evaluation_config.json)
+- `--data-dir, -d`: 测试数据目录(默认:Data)
+- `--model-type, -m`: 模型类型(uav_and_bird/Anti_UAV/UAV-250411)
+- `--threshold, -t`: 置信度阈值(默认:0.5)
+- `--sample-size, -s`: 采样大小(None表示全部数据)
+- `--generate-data, -g`: 强制重新生成测试数据
+- `--output-dir, -o`: 输出目录(默认:evaluation_results)
+
+## 配置文件
+
+可以通过修改 `evaluation_config.json` 来自定义评估参数:
+
+```json
+{
+    "evaluation_settings": {
+        "model_type": "uav_and_bird",
+        "confidence_threshold": 0.5,
+        "iou_threshold": 0.5,
+        "sample_size": null,
+        "balanced_sampling": {
+            "enabled": false,
+            "bird_ratio": 0.2,
+            "uav_ratio": 0.2,
+            "others_ratio": 0.6,
+            "description": "平衡采样配置:按指定比例抽取各类别样本"
+        }
+    },
+    "data_settings": {
+        "data_directory": "Data",
+        "use_folder_structure": true,
+        "annotation_file": null
+    },
+    "output_settings": {
+        "output_directory": "evaluation_results",
+        "generate_plots": true,
+        "save_detailed_results": true
+    }
+}
+```
+
+### 平衡采样配置
+
+平衡采样是一种重要的评估策略,可以确保各类别样本按指定比例分布,避免数据不平衡对评估结果的影响:
+
+- **enabled**: 是否启用平衡采样
+- **bird_ratio**: 鸟类样本比例(默认0.2,即20%)
+- **uav_ratio**: 无人机样本比例(默认0.2,即20%)
+- **others_ratio**: 其他样本比例(默认0.6,即60%)
+
+**为什么使用平衡采样?**
+
+1. **避免数据偏差**:如果某类样本过多,可能导致评估结果偏向该类别
+2. **更真实的性能评估**:按实际应用场景的比例分布进行评估
+3. **提高评估效率**:在有限的计算资源下获得更有代表性的结果
+4. **便于对比分析**:不同模型在相同数据分布下的对比更公平
+
+## 高级用法
+
+### 编程接口
+
+```python
+from model_evaluator import ModelEvaluator
+
+# 初始化评估器
+evaluator = ModelEvaluator(
+    model_type='uav_and_bird',
+    threshold=0.5,
+    iou_threshold=0.5
+)
+
+# 加载ground truth
+ground_truth = evaluator.load_ground_truth_from_folder_structure('Data')
+
+# 运行评估
+metrics = evaluator.evaluate_dataset(ground_truth, sample_size=100)
+
+# 生成报告
+report_file = evaluator.generate_report(metrics, 'results')
+plot_file = evaluator.plot_results(metrics, 'results')
+```
+
+### 自定义标注文件
+
+如果有自定义的标注文件,格式应为:
+
+```json
+[
+    {
+        "image_path": "path/to/image1.jpg",
+        "label": "uav",
+        "bbox": [x, y, width, height]  // 可选
+    },
+    {
+        "image_path": "path/to/image2.jpg",
+        "label": "bird",
+        "bbox": [x, y, width, height]  // 可选
+    }
+]
+```
+
+然后使用:
+
+```python
+ground_truth = evaluator.load_ground_truth_from_file('annotations.json')
+```
+
+## 评估指标详解
+
+### 混淆矩阵(Confusion Matrix)
+
+混淆矩阵是评估分类模型性能的基础,包含四个关键指标:
+
+- **真正例(True Positive, TP)**:模型正确地将正样本(如UAV或鸟类)识别为正类的数量。
+  - 例如:无人机被正确识别为无人机。
+
+- **假正例(False Positive, FP)**:模型错误地将负样本识别为正类的数量,即**误报**。
+  - 例如:其他物体被错误识别为无人机或鸟类。
+
+- **假负例(False Negative, FN)**:模型错误地将正样本识别为负类的数量,即**漏报**。
+  - 例如:无人机或鸟类未被识别出来。
+
+- **真负例(True Negative, TN)**:模型正确地将负样本识别为负类的数量。
+  - 例如:其他物体被正确识别为非无人机、非鸟类。
+
+### 性能指标
+
+#### 精确率(Precision)
+
+精确率衡量模型预测为正类的样本中实际为正类的比例,反映了模型的**准确性**。
+
+- **计算公式**:Precision = TP / (TP + FP)
+- **含义**:在所有被模型识别为无人机/鸟类的检测结果中,真正是无人机/鸟类的比例。
+- **重要性**:精确率高意味着误报少,适用于对误报敏感的场景。
+
+#### 召回率(Recall)
+
+召回率衡量模型能够正确识别出的实际正样本的比例,反映了模型的**完整性**。
+
+- **计算公式**:Recall = TP / (TP + FN)
+- **含义**:在所有实际的无人机/鸟类中,被模型正确识别出的比例。
+- **重要性**:召回率高意味着漏报少,适用于对漏报敏感的场景。
+
+#### F1分数(F1 Score)
+
+F1分数是精确率和召回率的调和平均值,提供了模型性能的综合评价。
+
+- **计算公式**:F1 = 2 * (Precision * Recall) / (Precision + Recall)
+- **含义**:平衡考虑精确率和召回率的综合指标。
+- **重要性**:当需要在精确率和召回率之间取得平衡时,F1分数是一个很好的指标。
+
+#### 准确率(Accuracy)
+
+准确率衡量模型正确预测的样本比例。
+
+- **计算公式**:Accuracy = (TP + TN) / (TP + TN + FP + FN)
+- **含义**:所有预测中正确预测的比例。
+- **重要性**:提供整体性能的概览,但在类别不平衡时可能产生误导。
+
+#### 误报率(False Positive Rate, FPR)
+
+误报率衡量模型将负样本错误分类为正样本的比例。
+
+- **计算公式**:FPR = FP / (FP + TN)
+- **含义**:在所有实际的非目标物体中,被错误识别为目标的比例。
+- **重要性**:在安防监控等场景中,低误报率至关重要,避免频繁的虚假警报。
+
+#### 漏报率(False Negative Rate, FNR)
+
+漏报率衡量模型将正样本错误分类为负样本的比例。
+
+- **计算公式**:FNR = FN / (TP + FN)
+- **含义**:在所有实际的目标物体中,未被识别出的比例。
+- **重要性**:在安全关键应用中,低漏报率至关重要,确保不会遗漏重要目标。
+
+### 按类别评估
+
+对于多类别检测(如区分UAV和鸟类),系统会计算每个类别的单独指标:
+
+- **类别精确率**:特定类别的精确率。
+- **类别召回率**:特定类别的召回率。
+- **类别F1分数**:特定类别的F1分数。
+
+这些指标帮助分析模型对不同类别的识别能力,发现模型在特定类别上的优势和不足。
+
+### 推理速度指标
+
+#### 平均推理时间
+
+- **计算方法**:所有图像的推理时间总和除以图像数量。
+- **单位**:秒/图像。
+- **含义**:处理单张图像平均需要的时间。
+
+#### FPS(每秒帧数)
+
+- **计算方法**:1除以平均推理时间。
+- **单位**:帧/秒。
+- **含义**:模型每秒能处理的图像数量,反映实时处理能力。
+
+#### 总处理时间
+
+- **计算方法**:所有图像的处理时间总和,包括预处理、推理和后处理。
+- **单位**:秒。
+- **含义**:评估整个数据集所需的总时间。
+
+### IoU(交并比)
+
+IoU是评估目标检测边界框准确性的关键指标。
+
+- **计算公式**:IoU = (预测框与真实框的交集面积) / (预测框与真实框的并集面积)
+- **取值范围**:0到1,值越大表示预测框与真实框重叠度越高。
+- **阈值设置**:通常设置IoU阈值(如0.5)来判断检测是否正确。
+- **重要性**:评估模型定位目标的精度,IoU高表示不仅检测出目标,而且定位准确。
+
+### 实际应用场景中的指标意义
+
+#### 在UAV检测中的应用
+
+**高精确率的重要性**:
+- 在机场安全监控中,误报会导致不必要的航班延误和资源浪费。
+- 在军事应用中,误报可能引发错误的防御响应。
+- 建议目标:精确率 > 90%
+
+**高召回率的重要性**:
+- 在边境监控中,漏报意味着潜在的安全威胁未被发现。
+- 在禁飞区监控中,漏报可能导致安全事故。
+- 建议目标:召回率 > 95%
+
+#### 在鸟类检测中的应用
+
+**生态保护场景**:
+- 高召回率确保不遗漏珍稀鸟类的观测记录。
+- 适度的精确率可以接受,因为误报的代价相对较低。
+
+**机场鸟击防护**:
+- 需要平衡精确率和召回率,既要避免误报导致的不必要干预,也要确保及时发现鸟类威胁。
+- 建议目标:F1分数 > 85%
+
+#### 推理速度的实际要求
+
+**实时监控系统**:
+- 视频监控:需要至少25-30 FPS以确保流畅的实时处理。
+- 无人机巡逻:10-15 FPS通常足够,因为飞行速度相对较慢。
+
+**批量处理系统**:
+- 历史数据分析:推理速度要求较低,更注重准确性。
+- 大规模数据集处理:需要考虑总处理时间和资源消耗。
+
+### 指标权衡和优化建议
+
+#### 精确率与召回率的权衡
+
+**提高精确率的方法**:
+- 提高置信度阈值(但可能降低召回率)
+- 改进模型训练数据质量
+- 使用更严格的后处理算法
+
+**提高召回率的方法**:
+- 降低置信度阈值(但可能降低精确率)
+- 增加训练数据的多样性
+- 使用数据增强技术
+
+#### 速度与准确性的权衡
+
+**提高推理速度**:
+- 使用模型量化技术
+- 优化输入图像分辨率
+- 使用GPU加速
+- 考虑使用轻量级模型架构
+
+**保持准确性**:
+- 在速度优化时监控性能指标变化
+- 使用知识蒸馏技术
+- 采用渐进式优化策略
+
+### 评估结果解读指南
+
+#### 优秀模型的指标范围
+
+**UAV检测模型**:
+- 精确率:85-95%
+- 召回率:90-98%
+- F1分数:87-96%
+- 推理速度:>20 FPS(实时应用)
+
+**鸟类检测模型**:
+- 精确率:80-90%
+- 召回率:85-95%
+- F1分数:82-92%
+- 推理速度:>15 FPS(实时应用)
+
+#### 问题诊断
+
+**高精确率,低召回率**:
+- 问题:模型过于保守,漏报较多
+- 解决方案:降低置信度阈值,增加训练数据
+
+**低精确率,高召回率**:
+- 问题:模型过于激进,误报较多
+- 解决方案:提高置信度阈值,改进模型训练
+
+**精确率和召回率都低**:
+- 问题:模型性能整体较差
+- 解决方案:重新训练模型,检查数据质量,调整模型架构
+
+**推理速度慢**:
+- 问题:无法满足实时应用需求
+- 解决方案:模型优化,硬件升级,算法改进
+
+## 输出结果
+
+评估完成后,会在输出目录中生成:
+
+1. **`evaluation_report_YYYYMMDD_HHMMSS.txt`** - 详细的文本报告
+2. **`evaluation_plots_YYYYMMDD_HHMMSS.png`** - 可视化图表
+3. **`detailed_results_YYYYMMDD_HHMMSS.json`** - 详细的JSON结果(如果启用)
+
+### 报告内容
+
+- 评估配置和参数
+- 数据集统计信息
+- 混淆矩阵
+- 整体性能指标
+- 按类别的性能指标
+- 推理速度统计
+- 错误分析
+
+### 图表内容
+
+- 混淆矩阵热图
+- 性能指标柱状图
+- 按类别的精确率/召回率对比
+- 推理时间分布
+
+## 故障排除
+
+### 常见问题
+
+1. **模型加载失败**
+   - 检查模型文件是否存在
+   - 确认ONNX Runtime已正确安装
+   - 验证模型类型参数是否正确
+
+2. **找不到测试数据**
+   - 确认数据目录路径正确
+   - 检查文件夹结构是否符合要求
+   - 尝试使用 `--generate-data` 生成测试数据
+
+3. **内存不足**
+   - 使用 `--sample-size` 限制处理的图像数量
+   - 减小图像分辨率
+   - 关闭可视化功能
+
+4. **推理速度慢**
+   - 检查是否启用了GPU加速
+   - 考虑使用较小的图像尺寸
+   - 优化模型参数
+
+### 调试模式
+
+在代码中设置详细日志:
+
+```python
+import logging
+logging.basicConfig(level=logging.DEBUG)
+```
+
+## 扩展功能
+
+### 添加新的评估指标
+
+在 `ModelEvaluator` 类中添加新的方法:
+
+```python
+def calculate_custom_metric(self, results):
+    # 实现自定义指标计算
+    pass
+```
+
+### 支持新的模型类型
+
+在 `inference.py` 中添加新的模型处理逻辑,然后在评估器中引用。
+
+### 自定义可视化
+
+修改 `plot_results` 方法来添加新的图表类型。
+
+## 性能优化建议
+
+1. **批处理**: 对于大量图像,考虑实现批处理推理
+2. **并行处理**: 使用多进程处理图像预处理
+3. **缓存**: 缓存预处理结果以避免重复计算
+4. **GPU加速**: 确保ONNX Runtime使用GPU提供程序
+
+## 贡献
+
+欢迎提交问题报告和功能请求。如果要贡献代码,请:
+
+1. Fork 项目
+2. 创建功能分支
+3. 提交更改
+4. 创建 Pull Request
+
+## 许可证
+
+本项目采用 MIT 许可证。详见 LICENSE 文件。

+ 56 - 0
evaluation_config.json

@@ -0,0 +1,56 @@
+{
+  "evaluation_settings": {
+    "model_type": "uav_and_bird",
+    "model_path": null,
+    "confidence_threshold": 0.5,
+    "iou_threshold": 0.5,
+    "sample_size": null,
+    "random_seed": 42,
+    "calculate_class_metrics": true,
+    "calculate_speed_metrics": true,
+    "balanced_sampling": {
+      "enabled": false,
+      "bird_ratio": 0.2,
+      "uav_ratio": 0.2,
+      "others_ratio": 0.6,
+      "description": "平衡采样配置:按指定比例抽取各类别样本"
+    }
+  },
+  "data_settings": {
+    "data_directory": "d:\\PythonProject\\R360-UAVmodelTool\\Data",
+    "annotation_file": null,
+    "use_folder_structure": true,
+    "supported_formats": ["jpg", "jpeg", "png", "bmp", "tiff"]
+  },
+  "output_settings": {
+    "output_directory": "evaluation_results",
+    "save_detailed_results": true,
+    "generate_plots": true,
+    "save_confusion_matrix": true,
+    "save_per_class_metrics": true
+  },
+  "evaluation_metrics": {
+    "calculate_precision": true,
+    "calculate_recall": true,
+    "calculate_f1_score": true,
+    "calculate_accuracy": true,
+    "calculate_fpr": true,
+    "calculate_fnr": true,
+    "calculate_speed_metrics": true
+  },
+  "visualization_settings": {
+    "plot_confusion_matrix": true,
+    "plot_performance_metrics": true,
+    "plot_inference_time_distribution": true,
+    "plot_per_class_comparison": true,
+    "figure_size": [15, 12],
+    "dpi": 300
+  },
+  "advanced_settings": {
+    "enable_gpu_acceleration": true,
+    "batch_processing": false,
+    "verbose_logging": true,
+    "save_failed_cases": true,
+    "calculate_bbox_metrics": false
+  }
+}

+ 369 - 0
generate_test_data.py

@@ -0,0 +1,369 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+测试数据生成器
+用于生成模拟的测试数据,方便测试模型评估功能
+
+作者: AI Assistant
+创建时间: 2024
+"""
+
+import os
+import cv2
+import numpy as np
+import json
+import random
+from datetime import datetime
+from typing import List, Dict, Tuple
+
+class TestDataGenerator:
+    """
+    测试数据生成器
+    生成模拟的UAV和鸟类图像用于测试
+    """
+    
+    def __init__(self, output_dir: str = None):
+        """
+        初始化生成器
+        
+        Args:
+            output_dir: 输出目录
+        """
+        if output_dir is None:
+            output_dir = os.path.join(os.path.dirname(__file__), 'Data')
+        
+        self.output_dir = output_dir
+        self.image_size = (640, 480)  # 默认图像尺寸
+        
+        # 创建目录结构
+        self.birds_dir = os.path.join(output_dir, 'BIRDS')
+        self.uav_dir = os.path.join(output_dir, 'UAV')
+        self.others_dir = os.path.join(output_dir, 'OTHERS')
+        
+        os.makedirs(self.birds_dir, exist_ok=True)
+        os.makedirs(self.uav_dir, exist_ok=True)
+        os.makedirs(self.others_dir, exist_ok=True)
+    
+    def generate_bird_image(self, image_id: int) -> str:
+        """
+        生成鸟类图像(模拟)
+        
+        Args:
+            image_id: 图像ID
+            
+        Returns:
+            生成的图像路径
+        """
+        # 创建蓝天背景
+        img = np.ones((self.image_size[1], self.image_size[0], 3), dtype=np.uint8)
+        img[:, :] = [135, 206, 235]  # 天蓝色背景
+        
+        # 添加一些云朵
+        for _ in range(random.randint(2, 5)):
+            center = (random.randint(50, self.image_size[0]-50), 
+                     random.randint(50, self.image_size[1]-50))
+            radius = random.randint(20, 60)
+            cv2.circle(img, center, radius, (255, 255, 255), -1)
+        
+        # 绘制鸟类形状(简化的V形)
+        bird_x = random.randint(100, self.image_size[0]-100)
+        bird_y = random.randint(100, self.image_size[1]-100)
+        bird_size = random.randint(15, 40)
+        
+        # 绘制鸟的翅膀(V形)
+        pts1 = np.array([[bird_x-bird_size, bird_y], 
+                        [bird_x, bird_y-bird_size//2], 
+                        [bird_x, bird_y+bird_size//2]], np.int32)
+        pts2 = np.array([[bird_x+bird_size, bird_y], 
+                        [bird_x, bird_y-bird_size//2], 
+                        [bird_x, bird_y+bird_size//2]], np.int32)
+        
+        cv2.fillPoly(img, [pts1], (50, 50, 50))  # 深灰色鸟
+        cv2.fillPoly(img, [pts2], (50, 50, 50))
+        
+        # 添加噪声
+        noise = np.random.normal(0, 10, img.shape).astype(np.uint8)
+        img = cv2.add(img, noise)
+        
+        # 保存图像
+        filename = f"bird_{image_id:04d}.jpg"
+        filepath = os.path.join(self.birds_dir, filename)
+        cv2.imwrite(filepath, img)
+        
+        return filepath
+    
+    def generate_uav_image(self, image_id: int) -> str:
+        """
+        生成无人机图像(模拟)
+        
+        Args:
+            image_id: 图像ID
+            
+        Returns:
+            生成的图像路径
+        """
+        # 创建蓝天背景
+        img = np.ones((self.image_size[1], self.image_size[0], 3), dtype=np.uint8)
+        img[:, :] = [135, 206, 235]  # 天蓝色背景
+        
+        # 添加一些云朵
+        for _ in range(random.randint(1, 3)):
+            center = (random.randint(50, self.image_size[0]-50), 
+                     random.randint(50, self.image_size[1]-50))
+            radius = random.randint(30, 80)
+            cv2.circle(img, center, radius, (255, 255, 255), -1)
+        
+        # 绘制无人机形状(简化的十字形)
+        uav_x = random.randint(100, self.image_size[0]-100)
+        uav_y = random.randint(100, self.image_size[1]-100)
+        uav_size = random.randint(20, 50)
+        
+        # 绘制无人机机身(矩形)
+        cv2.rectangle(img, 
+                     (uav_x-uav_size//4, uav_y-uav_size//8), 
+                     (uav_x+uav_size//4, uav_y+uav_size//8), 
+                     (80, 80, 80), -1)
+        
+        # 绘制螺旋桨臂
+        cv2.line(img, (uav_x-uav_size, uav_y), (uav_x+uav_size, uav_y), (60, 60, 60), 3)
+        cv2.line(img, (uav_x, uav_y-uav_size), (uav_x, uav_y+uav_size), (60, 60, 60), 3)
+        
+        # 绘制螺旋桨(圆形)
+        for dx, dy in [(-uav_size, 0), (uav_size, 0), (0, -uav_size), (0, uav_size)]:
+            cv2.circle(img, (uav_x+dx, uav_y+dy), uav_size//6, (40, 40, 40), -1)
+        
+        # 添加噪声
+        noise = np.random.normal(0, 8, img.shape).astype(np.uint8)
+        img = cv2.add(img, noise)
+        
+        # 保存图像
+        filename = f"uav_{image_id:04d}.jpg"
+        filepath = os.path.join(self.uav_dir, filename)
+        cv2.imwrite(filepath, img)
+        
+        return filepath
+    
+    def generate_other_image(self, image_id: int) -> str:
+        """
+        生成其他类型图像(负样本)
+        
+        Args:
+            image_id: 图像ID
+            
+        Returns:
+            生成的图像路径
+        """
+        # 随机选择背景类型
+        bg_type = random.choice(['sky', 'landscape', 'urban'])
+        
+        if bg_type == 'sky':
+            # 纯天空背景
+            img = np.ones((self.image_size[1], self.image_size[0], 3), dtype=np.uint8)
+            img[:, :] = [135, 206, 235]  # 天蓝色
+            
+            # 添加云朵
+            for _ in range(random.randint(3, 8)):
+                center = (random.randint(0, self.image_size[0]), 
+                         random.randint(0, self.image_size[1]))
+                radius = random.randint(40, 100)
+                cv2.circle(img, center, radius, (255, 255, 255), -1)
+        
+        elif bg_type == 'landscape':
+            # 风景背景
+            img = np.ones((self.image_size[1], self.image_size[0], 3), dtype=np.uint8)
+            
+            # 天空部分
+            img[:self.image_size[1]//2, :] = [135, 206, 235]
+            
+            # 地面部分
+            img[self.image_size[1]//2:, :] = [34, 139, 34]  # 森林绿
+            
+            # 添加一些树木形状
+            for _ in range(random.randint(5, 10)):
+                tree_x = random.randint(0, self.image_size[0])
+                tree_y = random.randint(self.image_size[1]//2, self.image_size[1])
+                tree_height = random.randint(30, 80)
+                cv2.rectangle(img, 
+                             (tree_x-5, tree_y), 
+                             (tree_x+5, tree_y-tree_height), 
+                             (139, 69, 19), -1)  # 棕色树干
+                cv2.circle(img, (tree_x, tree_y-tree_height), 15, (0, 100, 0), -1)  # 绿色树冠
+        
+        else:  # urban
+            # 城市背景
+            img = np.ones((self.image_size[1], self.image_size[0], 3), dtype=np.uint8)
+            img[:, :] = [169, 169, 169]  # 灰色背景
+            
+            # 添加建筑物
+            for _ in range(random.randint(3, 6)):
+                building_x = random.randint(0, self.image_size[0]-50)
+                building_y = random.randint(self.image_size[1]//3, self.image_size[1])
+                building_w = random.randint(30, 80)
+                building_h = random.randint(50, 150)
+                
+                cv2.rectangle(img, 
+                             (building_x, building_y), 
+                             (building_x+building_w, building_y-building_h), 
+                             (105, 105, 105), -1)  # 深灰色建筑
+        
+        # 添加噪声
+        noise = np.random.normal(0, 12, img.shape).astype(np.uint8)
+        img = cv2.add(img, noise)
+        
+        # 保存图像
+        filename = f"other_{image_id:04d}.jpg"
+        filepath = os.path.join(self.others_dir, filename)
+        cv2.imwrite(filepath, img)
+        
+        return filepath
+    
+    def generate_dataset(self, num_birds: int = 50, num_uavs: int = 50, num_others: int = 100) -> Dict:
+        """
+        生成完整的测试数据集
+        
+        Args:
+            num_birds: 鸟类图像数量
+            num_uavs: 无人机图像数量
+            num_others: 其他图像数量
+            
+        Returns:
+            数据集信息字典
+        """
+        print(f"开始生成测试数据集...")
+        print(f"鸟类图像: {num_birds} 张")
+        print(f"无人机图像: {num_uavs} 张")
+        print(f"其他图像: {num_others} 张")
+        print(f"输出目录: {self.output_dir}")
+        
+        dataset_info = {
+            'creation_time': datetime.now().isoformat(),
+            'total_images': num_birds + num_uavs + num_others,
+            'categories': {
+                'birds': {'count': num_birds, 'files': []},
+                'uavs': {'count': num_uavs, 'files': []},
+                'others': {'count': num_others, 'files': []}
+            }
+        }
+        
+        # 生成鸟类图像
+        print("\n生成鸟类图像...")
+        for i in range(num_birds):
+            filepath = self.generate_bird_image(i)
+            dataset_info['categories']['birds']['files'].append(filepath)
+            if (i + 1) % 10 == 0:
+                print(f"已生成 {i + 1}/{num_birds} 张鸟类图像")
+        
+        # 生成无人机图像
+        print("\n生成无人机图像...")
+        for i in range(num_uavs):
+            filepath = self.generate_uav_image(i)
+            dataset_info['categories']['uavs']['files'].append(filepath)
+            if (i + 1) % 10 == 0:
+                print(f"已生成 {i + 1}/{num_uavs} 张无人机图像")
+        
+        # 生成其他图像
+        print("\n生成其他图像...")
+        for i in range(num_others):
+            filepath = self.generate_other_image(i)
+            dataset_info['categories']['others']['files'].append(filepath)
+            if (i + 1) % 20 == 0:
+                print(f"已生成 {i + 1}/{num_others} 张其他图像")
+        
+        # 保存数据集信息
+        info_file = os.path.join(self.output_dir, 'dataset_info.json')
+        with open(info_file, 'w', encoding='utf-8') as f:
+            json.dump(dataset_info, f, ensure_ascii=False, indent=2)
+        
+        print(f"\n数据集生成完成!")
+        print(f"总计生成 {dataset_info['total_images']} 张图像")
+        print(f"数据集信息已保存到: {info_file}")
+        
+        return dataset_info
+    
+    def create_annotation_file(self, dataset_info: Dict) -> str:
+        """
+        创建标注文件
+        
+        Args:
+            dataset_info: 数据集信息
+            
+        Returns:
+            标注文件路径
+        """
+        annotations = {
+            'info': {
+                'description': 'Generated test dataset for UAV and bird detection',
+                'version': '1.0',
+                'creation_date': dataset_info['creation_time']
+            },
+            'images': []
+        }
+        
+        # 添加鸟类图像标注
+        for filepath in dataset_info['categories']['birds']['files']:
+            annotations['images'].append({
+                'path': filepath,
+                'label': 'bird',
+                'has_target': True,
+                'bboxes': []  # 简化版本,不包含具体边界框
+            })
+        
+        # 添加无人机图像标注
+        for filepath in dataset_info['categories']['uavs']['files']:
+            annotations['images'].append({
+                'path': filepath,
+                'label': 'uav',
+                'has_target': True,
+                'bboxes': []
+            })
+        
+        # 添加其他图像标注
+        for filepath in dataset_info['categories']['others']['files']:
+            annotations['images'].append({
+                'path': filepath,
+                'label': 'others',
+                'has_target': False,
+                'bboxes': []
+            })
+        
+        # 保存标注文件
+        annotation_file = os.path.join(self.output_dir, 'annotations.json')
+        with open(annotation_file, 'w', encoding='utf-8') as f:
+            json.dump(annotations, f, ensure_ascii=False, indent=2)
+        
+        print(f"标注文件已保存到: {annotation_file}")
+        return annotation_file
+
+def main():
+    """
+    主函数 - 生成测试数据
+    """
+    print("测试数据生成器")
+    print("=" * 50)
+    
+    # 配置参数
+    output_dir = r"d:\PythonProject\R360-UAVmodelTool\Data"
+    num_birds = 30
+    num_uavs = 30
+    num_others = 60
+    
+    # 初始化生成器
+    generator = TestDataGenerator(output_dir)
+    
+    # 生成数据集
+    dataset_info = generator.generate_dataset(num_birds, num_uavs, num_others)
+    
+    # 创建标注文件
+    annotation_file = generator.create_annotation_file(dataset_info)
+    
+    print("\n数据集结构:")
+    print(f"├── BIRDS/     ({num_birds} 张图像)")
+    print(f"├── UAV/       ({num_uavs} 张图像)")
+    print(f"├── OTHERS/    ({num_others} 张图像)")
+    print(f"├── dataset_info.json")
+    print(f"└── annotations.json")
+    
+    print("\n现在可以运行 model_evaluator.py 进行模型评估!")
+
+if __name__ == "__main__":
+    main()

+ 365 - 0
image_classifier.py

@@ -0,0 +1,365 @@
+import tkinter as tk
+from tkinter import ttk, filedialog, messagebox
+from PIL import Image, ImageTk, ImageEnhance
+import os
+import shutil
+from pathlib import Path
+import numpy as np
+from collections import Counter
+import cv2
+
+class ImageClassifier:
+    def __init__(self, root):
+        self.root = root
+        self.root.title("图片分类器")
+        self.root.geometry("1200x800")
+        
+        # 当前图片路径
+        self.current_image_path = None
+        # 当前文件夹路径
+        self.current_folder_path = None
+        # 图片列表
+        self.image_list = []
+        # 当前图片索引
+        self.current_index = 0
+        # 图片处理参数
+        self.zoom_factor = 1.0
+        self.rotation_angle = 0
+        self.brightness_factor = 1.0
+        # 天空检测参数
+        self.sky_threshold = 0.7  # 天空占比阈值
+        self.blue_threshold = 0.6  # 蓝色占比阈值
+        # 操作历史记录
+        self.operation_history = []  # 存储操作历史
+        
+        self.setup_ui()
+        self.setup_keyboard_shortcuts()
+        
+    def setup_ui(self):
+        # 创建主框架
+        main_frame = ttk.Frame(self.root)
+        main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
+        
+        # 左侧控制面板
+        control_frame = ttk.Frame(main_frame)
+        control_frame.pack(side=tk.LEFT, fill=tk.Y, padx=5)
+        
+        # 选择文件夹按钮
+        ttk.Button(control_frame, text="选择文件夹 (F)", command=self.select_folder).pack(pady=5)
+        
+        # 分类按钮
+        ttk.Button(control_frame, text="有无人机 (1)", command=lambda: self.classify_image("drone")).pack(pady=5)
+        ttk.Button(control_frame, text="有鸟类 (2)", command=lambda: self.classify_image("bird")).pack(pady=5)
+        ttk.Button(control_frame, text="有鸟类和无人机 (3)", command=lambda: self.classify_image("both")).pack(pady=5)
+        ttk.Button(control_frame, text="无目标 (4)", command=lambda: self.classify_image("none")).pack(pady=5)
+        
+        # 导航按钮
+        nav_frame = ttk.Frame(control_frame)
+        nav_frame.pack(pady=10)
+        ttk.Button(nav_frame, text="上一张 (←)", command=self.prev_image).pack(side=tk.LEFT, padx=5)
+        ttk.Button(nav_frame, text="下一张 (→)", command=self.next_image).pack(side=tk.LEFT, padx=5)
+        
+        # 撤回按钮
+        ttk.Button(control_frame, text="撤回操作 (Z)", command=self.undo_operation).pack(pady=5)
+        
+        # 图片处理控制
+        process_frame = ttk.LabelFrame(control_frame, text="图片处理")
+        process_frame.pack(pady=10, fill=tk.X)
+        
+        # 缩放控制
+        ttk.Label(process_frame, text="缩放:").pack()
+        self.zoom_scale = ttk.Scale(process_frame, from_=0.1, to=3.0, orient=tk.HORIZONTAL, 
+                                  command=self.update_zoom)
+        self.zoom_scale.set(1.0)
+        self.zoom_scale.pack(fill=tk.X, padx=5)
+        
+        # 旋转控制
+        ttk.Label(process_frame, text="旋转:").pack()
+        self.rotation_scale = ttk.Scale(process_frame, from_=0, to=360, orient=tk.HORIZONTAL,
+                                      command=self.update_rotation)
+        self.rotation_scale.set(0)
+        self.rotation_scale.pack(fill=tk.X, padx=5)
+        
+        # 亮度控制
+        ttk.Label(process_frame, text="亮度:").pack()
+        self.brightness_scale = ttk.Scale(process_frame, from_=0.1, to=2.0, orient=tk.HORIZONTAL,
+                                        command=self.update_brightness)
+        self.brightness_scale.set(1.0)
+        self.brightness_scale.pack(fill=tk.X, padx=5)
+        
+        # 天空检测
+        ttk.Label(process_frame, text="天空检测:").pack()
+        self.auto_detect_var = tk.BooleanVar(value=False)
+        ttk.Checkbutton(process_frame, text="自动检测天空图片", 
+                       variable=self.auto_detect_var,
+                       command=self.toggle_auto_detect).pack()
+        
+        # 天空阈值控制
+        ttk.Label(process_frame, text="天空阈值:").pack()
+        self.threshold_scale = ttk.Scale(process_frame, from_=0.5, to=0.95, orient=tk.HORIZONTAL,
+                                       command=self.update_threshold)
+        self.threshold_scale.set(0.7)
+        self.threshold_scale.pack(fill=tk.X, padx=5)
+        
+        # 重置按钮
+        ttk.Button(process_frame, text="重置图片 (R)", command=self.reset_image).pack(pady=5)
+        
+        # 右侧图片显示区域
+        self.image_frame = ttk.Frame(main_frame)
+        self.image_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)
+        
+        # 图片标签
+        self.image_label = ttk.Label(self.image_frame)
+        self.image_label.pack(fill=tk.BOTH, expand=True)
+        
+        # 状态标签
+        self.status_label = ttk.Label(self.root, text="请选择文件夹")
+        self.status_label.pack(side=tk.BOTTOM, pady=5)
+        
+        # 添加图片信息显示
+        self.info_label = ttk.Label(self.root, text="")
+        self.info_label.pack(side=tk.BOTTOM, pady=5)
+        
+    def setup_keyboard_shortcuts(self):
+        self.root.bind('<F5>', lambda e: self.select_folder())
+        self.root.bind('<Left>', lambda e: self.prev_image())
+        self.root.bind('<Right>', lambda e: self.next_image())
+        self.root.bind('1', lambda e: self.classify_image("drone"))
+        self.root.bind('2', lambda e: self.classify_image("bird"))
+        self.root.bind('3', lambda e: self.classify_image("both"))
+        self.root.bind('4', lambda e: self.classify_image("none"))
+        self.root.bind('r', lambda e: self.reset_image())
+        self.root.bind('z', lambda e: self.undo_operation())
+        
+    def update_zoom(self, value):
+        self.zoom_factor = float(value)
+        self.show_current_image()
+        
+    def update_rotation(self, value):
+        self.rotation_angle = float(value)
+        self.show_current_image()
+        
+    def update_brightness(self, value):
+        self.brightness_factor = float(value)
+        self.show_current_image()
+        
+    def reset_image(self):
+        self.zoom_factor = 1.0
+        self.rotation_angle = 0
+        self.brightness_factor = 1.0
+        self.zoom_scale.set(1.0)
+        self.rotation_scale.set(0)
+        self.brightness_scale.set(1.0)
+        self.show_current_image()
+        
+    def select_folder(self):
+        folder_path = filedialog.askdirectory()
+        if folder_path:
+            self.current_folder_path = folder_path
+            self.load_images()
+            
+    def load_images(self):
+        self.image_list = []
+        for file in os.listdir(self.current_folder_path):
+            if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
+                self.image_list.append(os.path.join(self.current_folder_path, file))
+        
+        if self.image_list:
+            self.current_index = 0
+            self.show_current_image()
+            self.status_label.config(text=f"已加载 {len(self.image_list)} 张图片")
+        else:
+            messagebox.showinfo("提示", "所选文件夹中没有图片")
+            
+    def show_current_image(self):
+        if 0 <= self.current_index < len(self.image_list):
+            self.current_image_path = self.image_list[self.current_index]
+            
+            # 打开图片
+            image = Image.open(self.current_image_path)
+            
+            # 如果启用了自动检测,检查是否为天空图片
+            if self.auto_detect_var.get():
+                if self.is_sky_image(image):
+                    self.classify_image("none")
+                    return
+            
+            # 应用亮度调整
+            enhancer = ImageEnhance.Brightness(image)
+            image = enhancer.enhance(self.brightness_factor)
+            
+            # 应用旋转
+            image = image.rotate(self.rotation_angle, expand=True)
+            
+            # 计算调整后的尺寸,保持宽高比
+            width, height = image.size
+            max_size = (800, 600)
+            ratio = min(max_size[0]/width, max_size[1]/height)
+            new_size = (int(width*ratio*self.zoom_factor), int(height*ratio*self.zoom_factor))
+            image = image.resize(new_size, Image.Resampling.LANCZOS)
+            
+            # 转换为PhotoImage
+            photo = ImageTk.PhotoImage(image)
+            
+            # 更新图片显示
+            self.image_label.config(image=photo)
+            self.image_label.image = photo  # 保持引用
+            
+            # 更新状态
+            self.status_label.config(text=f"图片 {self.current_index + 1}/{len(self.image_list)}")
+            
+            # 显示图片信息
+            file_size = os.path.getsize(self.current_image_path) / 1024  # 转换为KB
+            self.info_label.config(text=f"文件名: {os.path.basename(self.current_image_path)} | 尺寸: {width}x{height} | 大小: {file_size:.1f}KB")
+            
+    def next_image(self):
+        if self.image_list:
+            self.current_index = (self.current_index + 1) % len(self.image_list)
+            self.show_current_image()
+            
+    def prev_image(self):
+        if self.image_list:
+            self.current_index = (self.current_index - 1) % len(self.image_list)
+            self.show_current_image()
+            
+    def classify_image(self, category):
+        if not self.current_image_path:
+            return
+            
+        # 创建分类文件夹
+        category_folders = {
+            "drone": "有无人机",
+            "bird": "有鸟类",
+            "both": "有鸟类和无人机",
+            "none": "无目标"
+        }
+        
+        target_folder = os.path.join(self.current_folder_path, category_folders[category])
+        os.makedirs(target_folder, exist_ok=True)
+        
+        # 移动图片到对应文件夹
+        filename = os.path.basename(self.current_image_path)
+        target_path = os.path.join(target_folder, filename)
+        
+        try:
+            # 记录操作历史(在移动之前)
+            operation_record = {
+                'action': 'classify',
+                'original_path': self.current_image_path,
+                'target_path': target_path,
+                'category': category,
+                'index': self.current_index,
+                'filename': filename
+            }
+            
+            shutil.move(self.current_image_path, target_path)
+            
+            # 添加到操作历史
+            self.operation_history.append(operation_record)
+            
+            # 限制历史记录数量(最多保留50条)
+            if len(self.operation_history) > 50:
+                self.operation_history.pop(0)
+            
+            self.image_list.pop(self.current_index)
+            if self.image_list:
+                self.current_index = self.current_index % len(self.image_list)
+                self.show_current_image()
+            else:
+                self.image_label.config(image='')
+                self.status_label.config(text="所有图片已分类完成")
+                self.info_label.config(text="")
+        except Exception as e:
+            messagebox.showerror("错误", f"移动文件时出错:{str(e)}")
+
+    def is_sky_image(self, image):
+        # 转换为OpenCV格式
+        img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+        
+        # 转换为HSV颜色空间
+        hsv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2HSV)
+        
+        # 定义天空的HSV范围(包括蓝色和灰色)
+        # 蓝色天空范围
+        lower_blue = np.array([100, 50, 50])
+        upper_blue = np.array([130, 255, 255])
+        
+        # 灰色天空范围
+        lower_gray = np.array([0, 0, 50])
+        upper_gray = np.array([180, 30, 200])
+        
+        # 创建掩码
+        mask_blue = cv2.inRange(hsv, lower_blue, upper_blue)
+        mask_gray = cv2.inRange(hsv, lower_gray, upper_gray)
+        
+        # 合并掩码
+        mask = cv2.bitwise_or(mask_blue, mask_gray)
+        
+        # 计算天空像素占比
+        sky_ratio = np.sum(mask > 0) / (mask.shape[0] * mask.shape[1])
+        
+        # 计算图片上半部分的天空占比
+        upper_half = mask[:mask.shape[0]//2, :]
+        upper_sky_ratio = np.sum(upper_half > 0) / (upper_half.shape[0] * upper_half.shape[1])
+        
+        # 计算图片的纹理特征
+        gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
+        texture = cv2.Laplacian(gray, cv2.CV_64F).var()
+        
+        # 判断是否为天空图片
+        # 1. 整体天空占比超过阈值
+        # 2. 上半部分主要是天空
+        # 3. 纹理变化较小(天空通常比较均匀)
+        return (sky_ratio > self.sky_threshold and 
+                upper_sky_ratio > self.blue_threshold and 
+                texture < 500)  # 纹理阈值可以根据需要调整
+        
+    def toggle_auto_detect(self):
+        if self.auto_detect_var.get() and self.current_image_path:
+            if self.is_sky_image(Image.open(self.current_image_path)):
+                self.classify_image("none")
+                messagebox.showinfo("提示", "检测到天空图片,已自动分类为'无目标'")
+                
+    def update_threshold(self, value):
+        self.sky_threshold = float(value)
+        
+    def undo_operation(self):
+        """撤回最后一次分类操作"""
+        if not self.operation_history:
+            messagebox.showinfo("提示", "没有可撤回的操作")
+            return
+            
+        # 获取最后一次操作
+        last_operation = self.operation_history.pop()
+        
+        if last_operation['action'] == 'classify':
+            try:
+                # 将文件移回原位置
+                shutil.move(last_operation['target_path'], last_operation['original_path'])
+                
+                # 重新加载图片列表
+                self.load_images()
+                
+                # 尝试定位到撤回的图片
+                try:
+                    restored_index = self.image_list.index(last_operation['original_path'])
+                    self.current_index = restored_index
+                except ValueError:
+                    # 如果找不到,就显示第一张
+                    self.current_index = 0
+                
+                if self.image_list:
+                    self.show_current_image()
+                    messagebox.showinfo("成功", f"已撤回对 '{last_operation['filename']}' 的分类操作")
+                else:
+                    messagebox.showinfo("提示", "撤回成功,但当前文件夹中没有图片")
+                    
+            except Exception as e:
+                messagebox.showerror("错误", f"撤回操作失败:{str(e)}")
+                # 如果撤回失败,将操作重新加入历史记录
+                self.operation_history.append(last_operation)
+
+if __name__ == "__main__":
+    root = tk.Tk()
+    app = ImageClassifier(root)
+    root.mainloop()

+ 679 - 0
model_evaluator.py

@@ -0,0 +1,679 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+模型性能评估工具
+用于评估UAV和鸟类检测模型的各项性能指标
+包括:误报率、漏报率、分类准确率、推理速度等
+
+作者: AI Assistant
+创建时间: 2024
+"""
+
+import os
+import cv2
+import numpy as np
+import json
+import time
+import random
+from datetime import datetime
+from typing import List, Dict, Tuple, Optional
+from glob import glob
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+import seaborn as sns
+from collections import defaultdict
+import pandas as pd
+
+# 导入现有的推理模块
+from src.core.inference import ONNXDetector
+
+class ModelEvaluator:
+    """
+    模型性能评估器
+    """
+    
+    def __init__(self, model_path: str = None, model_type: str = 'uav_and_bird', 
+                 threshold: float = 0.5, iou_threshold: float = 0.5):
+        """
+        初始化评估器
+        
+        Args:
+            model_path: 模型文件路径
+            model_type: 模型类型
+            threshold: 置信度阈值
+            iou_threshold: IoU阈值
+        """
+        self.model_type = model_type
+        self.threshold = threshold
+        self.iou_threshold = iou_threshold
+        
+        # 初始化检测器
+        self.detector = ONNXDetector(
+            threshold=threshold,
+            model_type=model_type,
+            model_path=model_path,
+            save_empty=False
+        )
+        
+        # 评估结果存储
+        self.evaluation_results = {
+            'true_positives': 0,
+            'false_positives': 0,
+            'false_negatives': 0,
+            'true_negatives': 0,
+            'inference_times': [],
+            'detailed_results': [],
+            'class_metrics': defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})
+        }
+        
+        # 类别映射
+        if model_type == 'uav_and_bird':
+            self.class_names = {0: 'Bird', 1: 'Drone'}
+            self.class_mapping = {'bird': 0, 'drone': 1, 'none': -1}
+        else:
+            self.class_names = {0: 'Drone'}
+            self.class_mapping = {'drone': 0, 'none': -1}
+    
+    def load_ground_truth_from_folder_structure(self, data_dir: str) -> List[Dict]:
+        """
+        从文件夹结构加载ground truth
+        假设文件夹结构为:
+        data_dir/
+        ├── bird/
+        ├── drone/
+        └── none/
+        
+        Args:
+            data_dir: 数据目录路径
+            
+        Returns:
+            包含图像路径和标签的字典列表
+        """
+        ground_truth = []
+        
+        # 支持的图像格式
+        image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']
+        
+        # 遍历各个类别文件夹
+        for class_folder in ['bird', 'drone', 'none']:
+            folder_path = os.path.join(data_dir, class_folder)
+            if not os.path.exists(folder_path):
+                continue
+                
+            # 获取该文件夹下的所有图像
+            images = set()  # 使用集合避免重复计算
+            for ext in image_extensions:
+                # 搜索小写扩展名
+                images.update(glob(os.path.join(folder_path, ext)))
+                # 搜索大写扩展名
+                images.update(glob(os.path.join(folder_path, ext.upper())))
+            
+            # 确定类别标签
+            if class_folder == 'bird':
+                label = 'bird'
+            elif class_folder == 'drone':
+                label = 'drone'
+            else:
+                label = 'none'
+            
+            # 添加到ground truth
+            for img_path in images:
+                ground_truth.append({
+                    'image_path': img_path,
+                    'label': label,
+                    'class_id': self.class_mapping.get(label, -1),
+                    'has_target': label != 'none'
+                })
+        
+        return ground_truth
+    
+    def load_ground_truth_from_annotation(self, annotation_file: str) -> List[Dict]:
+        """
+        从标注文件加载ground truth
+        支持COCO格式或自定义JSON格式
+        
+        Args:
+            annotation_file: 标注文件路径
+            
+        Returns:
+            包含图像路径和标签的字典列表
+        """
+        with open(annotation_file, 'r', encoding='utf-8') as f:
+            annotations = json.load(f)
+        
+        ground_truth = []
+        
+        # 处理自定义格式
+        if 'images' in annotations:
+            for img_info in annotations['images']:
+                ground_truth.append({
+                    'image_path': img_info['path'],
+                    'label': img_info.get('label', 'none'),
+                    'class_id': self.class_mapping.get(img_info.get('label', 'none'), -1),
+                    'has_target': img_info.get('has_target', False),
+                    'bboxes': img_info.get('bboxes', [])
+                })
+        
+        return ground_truth
+    
+    def calculate_iou(self, box1: List[float], box2: List[float]) -> float:
+        """
+        计算两个边界框的IoU
+        
+        Args:
+            box1: [x1, y1, x2, y2]
+            box2: [x1, y1, x2, y2]
+            
+        Returns:
+            IoU值
+        """
+        x1 = max(box1[0], box2[0])
+        y1 = max(box1[1], box2[1])
+        x2 = min(box1[2], box2[2])
+        y2 = min(box1[3], box2[3])
+        
+        if x2 <= x1 or y2 <= y1:
+            return 0.0
+        
+        intersection = (x2 - x1) * (y2 - y1)
+        area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
+        area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
+        union = area1 + area2 - intersection
+        
+        return intersection / union if union > 0 else 0.0
+    
+    def evaluate_single_image(self, image_info: Dict) -> Dict:
+        """
+        评估单张图像
+        
+        Args:
+            image_info: 包含图像路径和ground truth的字典
+            
+        Returns:
+            评估结果字典
+        """
+        image_path = image_info['image_path']
+        gt_label = image_info['label']
+        gt_class_id = image_info['class_id']
+        gt_has_target = image_info['has_target']
+        gt_bboxes = image_info.get('bboxes', [])
+        
+        # 记录推理时间
+        start_time = time.time()
+        
+        # 进行推理
+        try:
+            # 使用detector的方法进行推理
+            input_data = self.detector.preprocess(image_path)
+            if input_data is None:
+                raise Exception("图像预处理失败")
+            
+            # 执行推理
+            detections_raw = self.detector.inference(input_data)
+            
+            # 后处理获取检测结果
+            image_orig = cv2.imread(image_path)
+            detection_count, processed_img, detections_list = self.detector.postprocess(
+                detections_raw, image_orig, image_path
+            )
+            
+            inference_time = time.time() - start_time
+            
+        except Exception as e:
+            print(f"推理失败: {image_path}, 错误: {e}")
+            return {
+                'image_path': image_path,
+                'gt_label': gt_label,
+                'prediction': 'error',
+                'inference_time': 0,
+                'tp': 0, 'fp': 0, 'fn': 0, 'tn': 0
+            }
+        
+        # 分析结果
+        has_detection = detection_count > 0
+        
+        # 计算混淆矩阵元素
+        tp = fp = fn = tn = 0
+        
+        if gt_has_target and has_detection:
+            tp = 1  # 有目标且检测到
+        elif gt_has_target and not has_detection:
+            fn = 1  # 有目标但未检测到(漏报)
+        elif not gt_has_target and has_detection:
+            fp = 1  # 无目标但检测到(误报)
+        elif not gt_has_target and not has_detection:
+            tn = 1  # 无目标且未检测到
+        
+        # 详细结果
+        result = {
+            'image_path': image_path,
+            'gt_label': gt_label,
+            'gt_has_target': gt_has_target,
+            'detection_count': detection_count,
+            'has_detection': has_detection,
+            'inference_time': inference_time,
+            'tp': tp, 'fp': fp, 'fn': fn, 'tn': tn,
+            'detections': detections_list if 'detections_list' in locals() else []
+        }
+        
+        return result
+    
+    def balanced_sampling(self, ground_truth: List[Dict], total_samples: int, 
+                         bird_ratio: float = 0.2, drone_ratio: float = 0.2, 
+                         none_ratio: float = 0.6) -> List[Dict]:
+        """
+        按比例平衡采样
+        
+        Args:
+            ground_truth: 原始ground truth列表
+            total_samples: 总采样数量
+            bird_ratio: 鸟类样本比例
+            drone_ratio: 无人机样本比例
+            none_ratio: 无目标样本比例
+            
+        Returns:
+            平衡采样后的ground truth列表
+        """
+        # 按类别分组
+        birds = [item for item in ground_truth if item['label'] == 'bird']
+        drones = [item for item in ground_truth if item['label'] == 'drone']
+        nones = [item for item in ground_truth if item['label'] == 'none']
+        
+        # 计算各类别需要的样本数
+        bird_count = int(total_samples * bird_ratio)
+        drone_count = int(total_samples * drone_ratio)
+        none_count = int(total_samples * none_ratio)
+        
+        # 确保不超过实际可用样本数
+        bird_count = min(bird_count, len(birds))
+        drone_count = min(drone_count, len(drones))
+        none_count = min(none_count, len(nones))
+        
+        print(f"平衡采样策略:")
+        print(f"  鸟类: {bird_count}/{len(birds)} ({bird_ratio*100:.1f}%)")
+        print(f"  无人机: {drone_count}/{len(drones)} ({drone_ratio*100:.1f}%)")
+        print(f"  无目标: {none_count}/{len(nones)} ({none_ratio*100:.1f}%)")
+        print(f"  总计: {bird_count + drone_count + none_count} 张图像")
+        
+        # 随机采样各类别
+        sampled_data = []
+        if bird_count > 0:
+            sampled_data.extend(random.sample(birds, bird_count))
+        if drone_count > 0:
+            sampled_data.extend(random.sample(drones, drone_count))
+        if none_count > 0:
+            sampled_data.extend(random.sample(nones, none_count))
+        
+        # 打乱顺序
+        random.shuffle(sampled_data)
+        
+        return sampled_data
+    
+    def evaluate_dataset(self, ground_truth: List[Dict], sample_size: Optional[int] = None,
+                        balanced_sampling: bool = False, bird_ratio: float = 0.2, 
+                        drone_ratio: float = 0.2, none_ratio: float = 0.6) -> Dict:
+        """
+        评估整个数据集
+        
+        Args:
+            ground_truth: ground truth列表
+            sample_size: 采样大小,None表示使用全部数据
+            balanced_sampling: 是否使用平衡采样
+            bird_ratio: 鸟类样本比例(仅在balanced_sampling=True时有效)
+            drone_ratio: 无人机样本比例(仅在balanced_sampling=True时有效)
+            none_ratio: 无目标样本比例(仅在balanced_sampling=True时有效)
+            
+        Returns:
+            评估结果字典
+        """
+        # 采样策略
+        if sample_size and sample_size < len(ground_truth):
+            if balanced_sampling:
+                # 平衡采样
+                ground_truth = self.balanced_sampling(
+                    ground_truth, sample_size, bird_ratio, drone_ratio, none_ratio
+                )
+            else:
+                # 随机采样
+                ground_truth = random.sample(ground_truth, sample_size)
+                print(f"随机采样 {sample_size} 张图像进行评估")
+        elif balanced_sampling:
+            print("警告: 启用了平衡采样但未指定sample_size,将使用全部数据")
+        
+        print(f"开始评估 {len(ground_truth)} 张图像...")
+        
+        # 重置结果
+        self.evaluation_results = {
+            'true_positives': 0,
+            'false_positives': 0,
+            'false_negatives': 0,
+            'true_negatives': 0,
+            'inference_times': [],
+            'detailed_results': [],
+            'class_metrics': defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})
+        }
+        
+        # 逐张评估
+        for img_info in tqdm(ground_truth, desc="评估进度"):
+            result = self.evaluate_single_image(img_info)
+            
+            # 累计统计
+            self.evaluation_results['true_positives'] += result['tp']
+            self.evaluation_results['false_positives'] += result['fp']
+            self.evaluation_results['false_negatives'] += result['fn']
+            self.evaluation_results['true_negatives'] += result['tn']
+            self.evaluation_results['inference_times'].append(result['inference_time'])
+            self.evaluation_results['detailed_results'].append(result)
+            
+            # 按类别统计
+            gt_label = result['gt_label']
+            # 统计所有类别的指标,包括none类别
+            self.evaluation_results['class_metrics'][gt_label]['tp'] += result['tp']
+            self.evaluation_results['class_metrics'][gt_label]['fp'] += result['fp']
+            self.evaluation_results['class_metrics'][gt_label]['fn'] += result['fn']
+        
+        return self.calculate_metrics()
+    
+    def calculate_metrics(self) -> Dict:
+        """
+        计算各种性能指标
+        
+        Returns:
+            性能指标字典
+        """
+        tp = self.evaluation_results['true_positives']
+        fp = self.evaluation_results['false_positives']
+        fn = self.evaluation_results['false_negatives']
+        tn = self.evaluation_results['true_negatives']
+        
+        # 基本指标
+        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+        accuracy = (tp + tn) / (tp + fp + fn + tn) if (tp + fp + fn + tn) > 0 else 0
+        
+        # 误报率和漏报率
+        false_positive_rate = fp / (fp + tn) if (fp + tn) > 0 else 0
+        false_negative_rate = fn / (fn + tp) if (fn + tp) > 0 else 0
+        
+        # 推理速度统计
+        inference_times = self.evaluation_results['inference_times']
+        avg_inference_time = np.mean(inference_times) if inference_times else 0
+        fps = 1 / avg_inference_time if avg_inference_time > 0 else 0
+        
+        metrics = {
+            'confusion_matrix': {
+                'true_positives': tp,
+                'false_positives': fp,
+                'false_negatives': fn,
+                'true_negatives': tn
+            },
+            'performance_metrics': {
+                'precision': precision,
+                'recall': recall,
+                'f1_score': f1_score,
+                'accuracy': accuracy,
+                'false_positive_rate': false_positive_rate,
+                'false_negative_rate': false_negative_rate
+            },
+            'speed_metrics': {
+                'avg_inference_time': avg_inference_time,
+                'fps': fps,
+                'total_images': len(inference_times),
+                'total_time': sum(inference_times)
+            },
+            'class_metrics': {}
+        }
+        
+        # 按类别计算指标
+        for class_name, class_stats in self.evaluation_results['class_metrics'].items():
+            class_tp = class_stats['tp']
+            class_fp = class_stats['fp']
+            class_fn = class_stats['fn']
+            
+            class_precision = class_tp / (class_tp + class_fp) if (class_tp + class_fp) > 0 else 0
+            class_recall = class_tp / (class_tp + class_fn) if (class_tp + class_fn) > 0 else 0
+            class_f1 = 2 * (class_precision * class_recall) / (class_precision + class_recall) if (class_precision + class_recall) > 0 else 0
+            
+            metrics['class_metrics'][class_name] = {
+                'precision': class_precision,
+                'recall': class_recall,
+                'f1_score': class_f1,
+                'tp': class_tp,
+                'fp': class_fp,
+                'fn': class_fn
+            }
+        
+        return metrics
+    
+    def generate_report(self, metrics: Dict, output_dir: str = None) -> str:
+        """
+        生成评估报告
+        
+        Args:
+            metrics: 性能指标字典
+            output_dir: 输出目录
+            
+        Returns:
+            报告文件路径
+        """
+        if output_dir is None:
+            output_dir = os.path.join(os.path.dirname(__file__), 'evaluation_results')
+        os.makedirs(output_dir, exist_ok=True)
+        
+        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+        report_file = os.path.join(output_dir, f'evaluation_report_{timestamp}.txt')
+        
+        with open(report_file, 'w', encoding='utf-8') as f:
+            f.write("=" * 60 + "\n")
+            f.write("模型性能评估报告\n")
+            f.write("=" * 60 + "\n")
+            f.write(f"评估时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
+            f.write(f"模型类型: {self.model_type}\n")
+            f.write(f"置信度阈值: {self.threshold}\n")
+            f.write(f"IoU阈值: {self.iou_threshold}\n\n")
+            
+            # 混淆矩阵
+            cm = metrics['confusion_matrix']
+            f.write("混淆矩阵:\n")
+            f.write("-" * 30 + "\n")
+            f.write(f"真正例 (TP): {cm['true_positives']}\n")
+            f.write(f"假正例 (FP): {cm['false_positives']}\n")
+            f.write(f"假负例 (FN): {cm['false_negatives']}\n")
+            f.write(f"真负例 (TN): {cm['true_negatives']}\n\n")
+            
+            # 性能指标
+            pm = metrics['performance_metrics']
+            f.write("性能指标:\n")
+            f.write("-" * 30 + "\n")
+            f.write(f"精确率 (Precision): {pm['precision']:.4f}\n")
+            f.write(f"召回率 (Recall): {pm['recall']:.4f}\n")
+            f.write(f"F1分数: {pm['f1_score']:.4f}\n")
+            f.write(f"准确率 (Accuracy): {pm['accuracy']:.4f}\n")
+            f.write(f"误报率 (FPR): {pm['false_positive_rate']:.4f}\n")
+            f.write(f"漏报率 (FNR): {pm['false_negative_rate']:.4f}\n\n")
+            
+            # 速度指标
+            sm = metrics['speed_metrics']
+            f.write("速度指标:\n")
+            f.write("-" * 30 + "\n")
+            f.write(f"平均推理时间: {sm['avg_inference_time']:.4f} 秒\n")
+            f.write(f"推理速度 (FPS): {sm['fps']:.2f}\n")
+            f.write(f"总图像数: {sm['total_images']}\n")
+            f.write(f"总耗时: {sm['total_time']:.2f} 秒\n\n")
+            
+            # 按类别指标
+            if metrics['class_metrics']:
+                f.write("按类别性能指标:\n")
+                f.write("-" * 30 + "\n")
+                for class_name, class_metrics in metrics['class_metrics'].items():
+                    f.write(f"\n{class_name.upper()}类别:\n")
+                    f.write(f"  精确率: {class_metrics['precision']:.4f}\n")
+                    f.write(f"  召回率: {class_metrics['recall']:.4f}\n")
+                    f.write(f"  F1分数: {class_metrics['f1_score']:.4f}\n")
+                    f.write(f"  TP: {class_metrics['tp']}, FP: {class_metrics['fp']}, FN: {class_metrics['fn']}\n")
+        
+        print(f"评估报告已保存到: {report_file}")
+        return report_file
+    
+    def plot_results(self, metrics: Dict, output_dir: str = None):
+        """
+        绘制评估结果图表
+        
+        Args:
+            metrics: 性能指标字典
+            output_dir: 输出目录
+        """
+        if output_dir is None:
+            output_dir = os.path.join(os.path.dirname(__file__), 'evaluation_results')
+        os.makedirs(output_dir, exist_ok=True)
+        
+        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+        
+        # 设置中文字体
+        plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
+        plt.rcParams['axes.unicode_minus'] = False
+        
+        # 1. 混淆矩阵热力图
+        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
+        
+        cm = metrics['confusion_matrix']
+        confusion_data = np.array([[cm['true_positives'], cm['false_negatives']],
+                                  [cm['false_positives'], cm['true_negatives']]])
+        
+        sns.heatmap(confusion_data, annot=True, fmt='d', cmap='Blues',
+                   xticklabels=['预测正例', '预测负例'],
+                   yticklabels=['实际正例', '实际负例'], ax=ax1)
+        ax1.set_title('混淆矩阵')
+        
+        # 2. 性能指标柱状图
+        pm = metrics['performance_metrics']
+        metrics_names = ['精确率', '召回率', 'F1分数', '准确率']
+        metrics_values = [pm['precision'], pm['recall'], pm['f1_score'], pm['accuracy']]
+        
+        bars = ax2.bar(metrics_names, metrics_values, color=['skyblue', 'lightgreen', 'orange', 'pink'])
+        ax2.set_title('性能指标')
+        ax2.set_ylim(0, 1)
+        ax2.set_ylabel('分数')
+        
+        # 在柱状图上添加数值标签
+        for bar, value in zip(bars, metrics_values):
+            ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
+                    f'{value:.3f}', ha='center', va='bottom')
+        
+        # 3. 推理时间分布直方图
+        inference_times = self.evaluation_results['inference_times']
+        if inference_times:
+            ax3.hist(inference_times, bins=30, color='lightcoral', alpha=0.7)
+            ax3.set_title('推理时间分布')
+            ax3.set_xlabel('推理时间 (秒)')
+            ax3.set_ylabel('频次')
+            ax3.axvline(np.mean(inference_times), color='red', linestyle='--',
+                       label=f'平均值: {np.mean(inference_times):.3f}s')
+            ax3.legend()
+        
+        # 4. 按类别性能对比
+        if metrics['class_metrics']:
+            class_names = list(metrics['class_metrics'].keys())
+            precisions = [metrics['class_metrics'][name]['precision'] for name in class_names]
+            recalls = [metrics['class_metrics'][name]['recall'] for name in class_names]
+            f1_scores = [metrics['class_metrics'][name]['f1_score'] for name in class_names]
+            
+            x = np.arange(len(class_names))
+            width = 0.25
+            
+            ax4.bar(x - width, precisions, width, label='精确率', color='skyblue')
+            ax4.bar(x, recalls, width, label='召回率', color='lightgreen')
+            ax4.bar(x + width, f1_scores, width, label='F1分数', color='orange')
+            
+            ax4.set_title('按类别性能对比')
+            ax4.set_xlabel('类别')
+            ax4.set_ylabel('分数')
+            ax4.set_xticks(x)
+            ax4.set_xticklabels(class_names)
+            ax4.legend()
+            ax4.set_ylim(0, 1)
+        
+        plt.tight_layout()
+        plot_file = os.path.join(output_dir, f'evaluation_plots_{timestamp}.png')
+        plt.savefig(plot_file, dpi=300, bbox_inches='tight')
+        plt.close()
+        
+        print(f"评估图表已保存到: {plot_file}")
+        return plot_file
+
+def main():
+    """
+    主函数 - 演示如何使用评估器
+    """
+    print("模型性能评估工具")
+    print("=" * 50)
+    
+    # 配置参数
+    data_dir = r"d:\PythonProject\R360-UAVmodelTool\Data"
+    model_type = 'uav_and_bird'
+    threshold = 0.5
+    sample_size = None  # None表示使用全部数据,可以设置为具体数字进行采样
+    
+    # 初始化评估器
+    evaluator = ModelEvaluator(
+        model_type=model_type,
+        threshold=threshold
+    )
+    
+    # 加载ground truth
+    print("加载ground truth数据...")
+    ground_truth = evaluator.load_ground_truth_from_folder_structure(data_dir)
+    
+    if not ground_truth:
+        print(f"警告: 在 {data_dir} 中未找到测试数据")
+        print("请确保数据目录结构如下:")
+        print("Data/")
+        print("├── BIRDS/     # 鸟类图像")
+        print("├── UAV/       # 无人机图像")
+        print("└── OTHERS/    # 其他图像(负样本)")
+        return
+    
+    print(f"找到 {len(ground_truth)} 张测试图像")
+    
+    # 显示数据分布
+    label_counts = {}
+    for item in ground_truth:
+        label = item['label']
+        label_counts[label] = label_counts.get(label, 0) + 1
+    
+    print("数据分布:")
+    for label, count in label_counts.items():
+        print(f"  {label}: {count} 张")
+    
+    # 开始评估
+    print("\n开始模型评估...")
+    metrics = evaluator.evaluate_dataset(ground_truth, sample_size)
+    
+    # 生成报告
+    print("\n生成评估报告...")
+    report_file = evaluator.generate_report(metrics)
+    
+    # 绘制图表
+    print("生成评估图表...")
+    plot_file = evaluator.plot_results(metrics)
+    
+    # 打印关键指标
+    print("\n=== 评估结果摘要 ===")
+    pm = metrics['performance_metrics']
+    sm = metrics['speed_metrics']
+    
+    print(f"精确率: {pm['precision']:.4f}")
+    print(f"召回率: {pm['recall']:.4f}")
+    print(f"F1分数: {pm['f1_score']:.4f}")
+    print(f"准确率: {pm['accuracy']:.4f}")
+    print(f"误报率: {pm['false_positive_rate']:.4f}")
+    print(f"漏报率: {pm['false_negative_rate']:.4f}")
+    print(f"平均推理时间: {sm['avg_inference_time']:.4f} 秒")
+    print(f"推理速度: {sm['fps']:.2f} FPS")
+    
+    print(f"\n详细报告: {report_file}")
+    print(f"评估图表: {plot_file}")
+
+if __name__ == "__main__":
+    main()

+ 13 - 9
requirements.txt

@@ -1,9 +1,13 @@
-numpy>=1.21.0
-Pillow>=9.0.0
-opencv-python>=4.5.0
-onnxruntime>=1.8.0
-onnx==1.17.0
-tqdm==4.66.1
-openpyxl==3.1.5
-pytest==7.4.0
-pytest-cov==4.1.0 
+numpy>=1.26.0
+opencv-python>=4.8.0
+onnxruntime-gpu==1.15.1
+onnx>=1.17.0
+pandas>=2.0.0
+openpyxl>=3.1.0
+xlsxwriter>=3.0.0
+matplotlib>=3.7.0
+seaborn>=0.12.0
+pillow>=10.0.0
+tqdm>=4.65.0
+typing-extensions>=4.0.0
+

+ 310 - 0
run_evaluation.py

@@ -0,0 +1,310 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+模型评估运行脚本
+简化的入口点,用于快速开始模型评估
+
+使用方法:
+1. 如果没有测试数据,先运行: python generate_test_data.py
+2. 然后运行: python run_evaluation.py
+
+作者: AI Assistant
+创建时间: 2024
+"""
+
+import os
+import sys
+import json
+import argparse
+import random
+import numpy as np
+from datetime import datetime
+
+# 添加项目根目录到Python路径
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+from model_evaluator import ModelEvaluator
+from generate_test_data import TestDataGenerator
+
+def check_data_availability(data_dir: str) -> bool:
+    """
+    检查测试数据是否可用
+    
+    Args:
+        data_dir: 数据目录路径
+        
+    Returns:
+        数据是否可用
+    """
+    if not os.path.exists(data_dir):
+        return False
+    
+    # 检查各个子目录
+    subdirs = ['bird', 'drone', 'none']
+    total_images = 0
+    
+    for subdir in subdirs:
+        subdir_path = os.path.join(data_dir, subdir)
+        if os.path.exists(subdir_path):
+            # 计算图像数量
+            image_files = set()  # 使用集合避免重复计算
+            for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']:
+                import glob
+                # 搜索小写扩展名
+                image_files.update(glob.glob(os.path.join(subdir_path, ext)))
+                # 搜索大写扩展名
+                image_files.update(glob.glob(os.path.join(subdir_path, ext.upper())))
+            total_images += len(image_files)
+    
+    return total_images > 0
+
+def load_config(config_file: str = 'evaluation_config.json') -> dict:
+    """
+    加载配置文件
+    
+    Args:
+        config_file: 配置文件路径
+        
+    Returns:
+        配置字典
+    """
+    if os.path.exists(config_file):
+        with open(config_file, 'r', encoding='utf-8') as f:
+            return json.load(f)
+    else:
+        # 返回默认配置
+        return {
+            "evaluation_settings": {
+                "model_type": "uav_and_bird",
+                "confidence_threshold": 0.5,
+                "iou_threshold": 0.5,
+                "sample_size": None
+            },
+            "data_settings": {
+                "data_directory": "Data",
+                "use_folder_structure": True
+            },
+            "output_settings": {
+                "output_directory": "evaluation_results",
+                "generate_plots": True
+            }
+        }
+
+def generate_test_data_if_needed(data_dir: str, force_generate: bool = False) -> bool:
+    """
+    如果需要,生成测试数据
+    
+    Args:
+        data_dir: 数据目录
+        force_generate: 是否强制重新生成
+        
+    Returns:
+        是否成功生成或已存在数据
+    """
+    if not check_data_availability(data_dir) or force_generate:
+        print("未找到测试数据,正在生成模拟数据...")
+        
+        generator = TestDataGenerator(data_dir)
+        dataset_info = generator.generate_dataset(
+            num_birds=20,
+            num_uavs=20, 
+            num_others=40
+        )
+        generator.create_annotation_file(dataset_info)
+        
+        print("测试数据生成完成!")
+        return True
+    else:
+        print("找到现有测试数据")
+        return True
+
+def main():
+    """
+    主函数
+    """
+    parser = argparse.ArgumentParser(description='模型性能评估工具')
+    parser.add_argument('--config', '-c', default='evaluation_config.json',
+                       help='配置文件路径')
+    parser.add_argument('--data-dir', '-d', default='Data',
+                       help='测试数据目录')
+    parser.add_argument('--model-type', '-m', default='uav_and_bird',
+                       choices=['uav_and_bird', 'Anti_UAV', 'UAV-250411'],
+                       help='模型类型')
+    parser.add_argument('--threshold', '-t', type=float, default=0.5,
+                       help='置信度阈值')
+    parser.add_argument('--sample-size', '-s', type=int, default=None,
+                       help='采样大小(None表示使用全部数据)')
+    parser.add_argument('--balanced-sampling', action='store_true',
+                       help='启用平衡采样(按比例抽取样本)')
+    parser.add_argument('--bird-ratio', type=float, default=0.2,
+                       help='鸟类样本比例(默认0.2,即20%)')
+    parser.add_argument('--drone-ratio', type=float, default=0.2,
+                       help='无人机样本比例(默认0.2,即20%)')
+    parser.add_argument('--none-ratio', type=float, default=0.6,
+                       help='其他样本比例(默认0.6,即60%)')
+    parser.add_argument('--generate-data', '-g', action='store_true',
+                       help='强制重新生成测试数据')
+    parser.add_argument('--output-dir', '-o', default='evaluation_results',
+                       help='输出目录')
+    
+    args = parser.parse_args()
+    
+    # 加载配置
+    config = load_config(args.config)
+    
+    # 如果命令行没有指定,则从配置文件读取平衡采样设置
+    if not args.balanced_sampling and config['evaluation_settings'].get('balanced_sampling', {}).get('enabled', False):
+        args.balanced_sampling = True
+        balanced_config = config['evaluation_settings']['balanced_sampling']
+        args.bird_ratio = balanced_config.get('bird_ratio', 0.2)
+        args.drone_ratio = balanced_config.get('drone_ratio', 0.2)
+        args.none_ratio = balanced_config.get('none_ratio', 0.6)
+        print(f"从配置文件加载平衡采样设置")
+    
+    print("=" * 60)
+    print("UAV和鸟类检测模型性能评估工具")
+    print("=" * 60)
+    print(f"启动时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+    print(f"模型类型: {args.model_type}")
+    print(f"置信度阈值: {args.threshold}")
+    print(f"数据目录: {args.data_dir}")
+    print(f"输出目录: {args.output_dir}")
+    print(f"配置文件: {args.config}")
+    
+    # 检查并生成测试数据
+    data_dir = os.path.abspath(args.data_dir)
+    if not generate_test_data_if_needed(data_dir, args.generate_data):
+        print("错误: 无法生成或找到测试数据")
+        return 1
+    
+    # 初始化评估器
+    print("\n初始化模型评估器...")
+    try:
+        evaluator = ModelEvaluator(
+            model_type=args.model_type,
+            threshold=args.threshold,
+            iou_threshold=config['evaluation_settings'].get('iou_threshold', 0.5)
+        )
+        print("模型加载成功")
+    except Exception as e:
+        print(f"错误: 模型加载失败 - {e}")
+        return 1
+    
+    # 加载ground truth
+    print("\n加载测试数据...")
+    try:
+        ground_truth = evaluator.load_ground_truth_from_folder_structure(data_dir)
+        if not ground_truth:
+            print("错误: 未找到有效的测试数据")
+            return 1
+        
+        print(f"找到 {len(ground_truth)} 张测试图像")
+        
+        # 显示数据分布
+        label_counts = {}
+        for item in ground_truth:
+            label = item['label']
+            label_counts[label] = label_counts.get(label, 0) + 1
+        
+        print("数据分布:")
+        for label, count in label_counts.items():
+            print(f"  {label.upper()}: {count} 张")
+            
+    except Exception as e:
+        print(f"错误: 加载测试数据失败 - {e}")
+        return 1
+    
+    # 验证比例参数
+    if args.balanced_sampling:
+        total_ratio = args.bird_ratio + args.drone_ratio + args.none_ratio
+        if abs(total_ratio - 1.0) > 0.01:  # 允许小的浮点误差
+            print(f"警告: 比例总和为 {total_ratio:.3f},不等于1.0,将自动归一化")
+            # 归一化比例
+            args.bird_ratio /= total_ratio
+            args.drone_ratio /= total_ratio
+            args.none_ratio /= total_ratio
+            print(f"归一化后比例: 鸟类={args.bird_ratio:.3f}, 无人机={args.drone_ratio:.3f}, 其他={args.none_ratio:.3f}")
+    
+    # 开始评估
+    print("\n" + "=" * 40)
+    print("开始模型性能评估")
+    if args.balanced_sampling:
+        print(f"采样策略: 平衡采样 (鸟类:{args.bird_ratio*100:.1f}%, 无人机:{args.drone_ratio*100:.1f}%, 其他:{args.none_ratio*100:.1f}%)")
+    else:
+        print("采样策略: 随机采样")
+    print("=" * 40)
+    
+    try:
+        metrics = evaluator.evaluate_dataset(
+            ground_truth, 
+            sample_size=args.sample_size,
+            balanced_sampling=args.balanced_sampling,
+            bird_ratio=args.bird_ratio,
+            drone_ratio=args.drone_ratio,
+            none_ratio=args.none_ratio
+        )
+        print("\n评估完成!")
+    except Exception as e:
+        print(f"错误: 评估过程失败 - {e}")
+        return 1
+    
+    # 生成报告和图表
+    print("\n生成评估报告...")
+    try:
+        report_file = evaluator.generate_report(metrics, args.output_dir)
+        
+        if config['output_settings'].get('generate_plots', True):
+            print("生成评估图表...")
+            plot_file = evaluator.plot_results(metrics, args.output_dir)
+        
+    except Exception as e:
+        print(f"警告: 生成报告或图表时出错 - {e}")
+    
+    # 显示关键结果
+    print("\n" + "=" * 40)
+    print("评估结果摘要")
+    print("=" * 40)
+    
+    pm = metrics['performance_metrics']
+    sm = metrics['speed_metrics']
+    cm = metrics['confusion_matrix']
+    
+    print(f"\n混淆矩阵:")
+    print(f"  真正例 (TP): {cm['true_positives']}")
+    print(f"  假正例 (FP): {cm['false_positives']}")
+    print(f"  假负例 (FN): {cm['false_negatives']}")
+    print(f"  真负例 (TN): {cm['true_negatives']}")
+    
+    print(f"\n性能指标:")
+    print(f"  精确率 (Precision): {pm['precision']:.4f}")
+    print(f"  召回率 (Recall): {pm['recall']:.4f}")
+    print(f"  F1分数: {pm['f1_score']:.4f}")
+    print(f"  准确率 (Accuracy): {pm['accuracy']:.4f}")
+    print(f"  误报率 (FPR): {pm['false_positive_rate']:.4f}")
+    print(f"  漏报率 (FNR): {pm['false_negative_rate']:.4f}")
+    
+    print(f"\n速度指标:")
+    print(f"  平均推理时间: {sm['avg_inference_time']:.4f} 秒")
+    print(f"  推理速度: {sm['fps']:.2f} FPS")
+    print(f"  总处理图像: {sm['total_images']} 张")
+    print(f"  总耗时: {sm['total_time']:.2f} 秒")
+    
+    # 按类别显示结果
+    if metrics['class_metrics']:
+        print(f"\n按类别性能:")
+        for class_name, class_metrics in metrics['class_metrics'].items():
+            print(f"  {class_name.upper()}类别:")
+            print(f"    精确率: {class_metrics['precision']:.4f}")
+            print(f"    召回率: {class_metrics['recall']:.4f}")
+            print(f"    F1分数: {class_metrics['f1_score']:.4f}")
+    
+    print(f"\n详细报告已保存到: {report_file}")
+    if 'plot_file' in locals():
+        print(f"评估图表已保存到: {plot_file}")
+    
+    print("\n评估完成!")
+    return 0
+
+if __name__ == "__main__":
+    exit_code = main()
+    sys.exit(exit_code)

+ 0 - 129
src/core/ali_image_validation.py

@@ -1,129 +0,0 @@
-import os
-import base64
-import json
-import time
-from PIL import Image
-from io import BytesIO
-import logging
-from openai import OpenAI
-from retry import retry
-import requests
-
-class AliImageValidator:
-    def __init__(self, api_key="sk-ccfcdd12fd434d0dab1406958663df9d"):
-        self.client = OpenAI(
-            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
-            api_key=api_key or os.environ.get("DASHSCOPE_API_KEY")
-        )
-        if not self.client.api_key:
-            raise ValueError("未找到API密钥,请设置环境变量DASHSCOPE_API_KEY或通过参数传递")
-
-        # 日志配置已移除
-
-
-    @retry(tries=3, delay=1, backoff=2, max_delay=5, exceptions=(Exception,))
-    def analyze_image(self, image_url):
-        if not image_url.startswith(('http://', 'https://')):
-            return {
-                'is_false_positive': False,
-                'uav_detected': False,
-                'probability': 0.0,
-                'reason': '无效的图片URL',
-                'response_time': 0.0
-            }
-
-        system_prompt = """
-你是一个专业的图像分析专家,请严格按照以下要求响应:
-1. 输出必须为合法的JSON格式,不能是纯自然语言段落
-2. 包含三个字段:
-   probability(误报概率0-1)
-   is_uav(是否是无人机 bool)
-   reason(分析理由)
-3. 不要包含任何额外说明
-
-请分析该监控画面:
-- 判断是否为安全误报
-- 识别画面中是否存在无人机
-- 给出详细分析理由
-"""
-
-        try:
-            start_time = time.time()
-            model_name = "qwen-vl-max-latest"
-            # 发送请求到阿里云
-            response = self.client.chat.completions.create(
-                model=model_name,
-                messages=[
-                    {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
-                    {
-                        "role": "user",
-                        "content": [
-                            {
-                                "type": "image_url",
-                                "image_url": {
-                                    "url": image_url
-                                }
-                            },
-                            {"type": "text", "text": "请分析该监控画面是否为误报"}
-                        ]
-                    }
-                ]
-            )
-            end_time = time.time()
-            response_time = end_time - start_time
-            
-            try:
-                result = response.choices[0].message.content
-                # 获取响应内容
-                try:
-                    result_json = json.loads(result.split('```json')[1].split('```')[0].strip())  # 提取markdown代码块中的JSON
-                    
-                    return {
-                        'is_false_positive': result_json.get('is_false_positive', False),
-                        'uav_detected': result_json.get('is_uav', False),
-                        'probability': result_json.get('probability', 0.0),
-                        'reason': result_json.get('reason', '无分析结果'),
-                        'response_time': response_time
-                    }
-                except (IndexError, json.JSONDecodeError, KeyError) as e:
-                    # 响应解析失败
-                    return {
-                        'is_false_positive': "误报" in result,
-                        'uav_detected': False,
-                        'probability': 0.0,
-                        'reason': '响应格式异常',
-                        'response_time': response_time
-                    }
-            except (json.JSONDecodeError, ValueError) as e:
-                # 响应解析失败
-                return {
-                    'is_false_positive': "误报" in result,
-                    'uav_detected': False,
-                    'probability': 0.0,
-                    'reason': '解析失败',
-                    'response_time': response_time
-                }
-
-        except Exception as e:
-            end_time = time.time()
-            response_time = end_time - start_time
-            # API调用失败
-            return {
-                'is_false_positive': False,
-                'uav_detected': False,
-                'probability': 0.0,
-                'reason': 'API调用失败',
-                'response_time': response_time
-            }
-
-if __name__ == "__main__":
-    import argparse
-    
-    parser = argparse.ArgumentParser(description='阿里云图像误报分析工具')
-    parser.add_argument('image_url', help='需要分析的图片URL地址')
-    args = parser.parse_args()
-
-    validator = AliImageValidator()
-    result = validator.analyze_image(args.image_url)
-    
-    # 分析完成

+ 0 - 0
src/core/detect.py


+ 25 - 12
src/core/inference.py

@@ -10,6 +10,7 @@ from datetime import datetime
 from tqdm import tqdm
 from src.analysis.report_generator import ReportGenerator
 import onnx
+import numpy as np
 
 class ONNXDetector:
     def select_input_directory(self) -> Optional[str]:
@@ -89,8 +90,13 @@ class ONNXDetector:
         else:
             self.output_dir = os.path.join(output_base, output_dir or f"output_{timestamp}")
             
-        # 创建目标存储目录
-        self.targets_dir = os.path.join(self.output_dir, 'targets')
+        # 创建三个存储目录:原图、画框图、目标
+        self.original_dir = os.path.join(self.output_dir, 'original')  # 原图目录
+        self.annotated_dir = os.path.join(self.output_dir, 'annotated')  # 画框图目录
+        self.targets_dir = os.path.join(self.output_dir, 'targets')  # 目标目录
+        
+        os.makedirs(self.original_dir, exist_ok=True)
+        os.makedirs(self.annotated_dir, exist_ok=True)
         os.makedirs(self.targets_dir, exist_ok=True)
 
     def preprocess(self, image_path: str) -> Optional[np.ndarray]:
@@ -430,18 +436,25 @@ class ONNXDetector:
         }
 
         filename = f"{os.path.splitext(os.path.basename(image_path))[0]}.jpg"
-        # 创建输出图片目录
-        output_img_dir = os.path.join(self.output_dir, 'imgs')
-        os.makedirs(output_img_dir, exist_ok=True)
-        output_path = os.path.join(output_img_dir, filename)
         
-        # 根据设置保存图片
+        # 保存原图到original目录(仅当检测到目标时)
+        if valid_count > 0:
+            original_img = cv2.imread(image_path)
+            original_output_path = os.path.join(self.original_dir, filename)
+            cv2.imwrite(original_output_path, original_img)
+        else:
+            original_output_path = None
+        
+        # 保存画框图到annotated目录
+        annotated_output_path = os.path.join(self.annotated_dir, filename)
+        
+        # 根据设置保存画框图
         if valid_count > 0 or self.save_empty:
-            cv2.imwrite(output_path, processed_img)
-            # 输出图像已保存
+            cv2.imwrite(annotated_output_path, processed_img)
+            # 画框图像已保存
         else:
-            output_path = None
-            # 未保存图像 (无检测且save_empty=False)
+            annotated_output_path = None
+            # 未保存画框图像 (无检测且save_empty=False)
             
         self.image_count += 1
         
@@ -488,7 +501,7 @@ def get_model_type(model_path: str) -> str:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='UAV Detection using ONNX model')
     parser.add_argument('--input', type=str, help='输入图像路径或目录')
-    parser.add_argument('--threshold', type=float, default=0.5, help='检测置信度阈值')
+    parser.add_argument('--threshold', type=float, default=0.4, help='检测置信度阈值')
     parser.add_argument('--output', type=str, default=None, help='输出目录路径,默认为输入目录名+_results')
     parser.add_argument('--max-bbox-ratio', type=float, default=0.05,
                         help='检测框最大面积比例阈值,默认0.05')

+ 462 - 0
target_based_classifier.py

@@ -0,0 +1,462 @@
+import tkinter as tk
+from tkinter import ttk, filedialog, messagebox
+from PIL import Image, ImageTk, ImageEnhance
+import os
+import shutil
+from pathlib import Path
+import numpy as np
+from collections import Counter
+import cv2
+
+class TargetBasedClassifier:
+    def __init__(self, root):
+        self.root = root
+        self.root.title("基于目标的图片分类器")
+        self.root.geometry("1400x900")
+        
+        # 当前工作目录(包含original、targets、annotated三个文件夹)
+        self.work_directory = None
+        # 三个子目录路径
+        self.original_dir = None
+        self.targets_dir = None
+        self.annotated_dir = None
+        
+        # 图片列表和索引
+        self.image_list = []  # 存储original文件夹中的图片路径
+        self.current_index = 0
+        
+        # 当前显示的图片信息
+        self.current_original_path = None
+        self.current_targets = []  # 当前图片对应的目标切片列表
+        
+        # 图片处理参数
+        self.zoom_factor = 1.0
+        self.rotation_angle = 0
+        self.brightness_factor = 1.0
+        
+        # 操作历史记录
+        self.operation_history = []
+        
+        self.setup_ui()
+        self.setup_keyboard_shortcuts()
+        
+    def setup_ui(self):
+        # 创建主框架
+        main_frame = ttk.Frame(self.root)
+        main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
+        
+        # 左侧控制面板
+        control_frame = ttk.Frame(main_frame)
+        control_frame.pack(side=tk.LEFT, fill=tk.Y, padx=5)
+        
+        # 选择工作目录按钮
+        ttk.Button(control_frame, text="选择工作目录 (F)", command=self.select_work_directory).pack(pady=5)
+        
+        # 分类按钮
+        ttk.Button(control_frame, text="无人机 (1)", command=lambda: self.classify_image("drone")).pack(pady=5)
+        ttk.Button(control_frame, text="鸟类 (2)", command=lambda: self.classify_image("bird")).pack(pady=5)
+        ttk.Button(control_frame, text="有人机 (3)", command=lambda: self.classify_image("manned")).pack(pady=5)
+        ttk.Button(control_frame, text="其他 (4)", command=lambda: self.classify_image("other")).pack(pady=5)
+        ttk.Button(control_frame, text="无目标 (5)", command=lambda: self.classify_image("none")).pack(pady=5)
+        
+        # 导航按钮
+        nav_frame = ttk.Frame(control_frame)
+        nav_frame.pack(pady=10)
+        ttk.Button(nav_frame, text="上一张 (←)", command=self.prev_image).pack(side=tk.LEFT, padx=5)
+        ttk.Button(nav_frame, text="下一张 (→)", command=self.next_image).pack(side=tk.LEFT, padx=5)
+        
+        # 撤回按钮
+        ttk.Button(control_frame, text="撤回操作 (6)", command=self.undo_operation).pack(pady=5)
+        
+        # 图片处理控制
+        process_frame = ttk.LabelFrame(control_frame, text="图片处理")
+        process_frame.pack(pady=10, fill=tk.X)
+        
+        # 缩放控制
+        ttk.Label(process_frame, text="缩放:").pack()
+        self.zoom_scale = ttk.Scale(process_frame, from_=0.1, to=3.0, orient=tk.HORIZONTAL, 
+                                  command=self.update_zoom)
+        self.zoom_scale.set(1.0)
+        self.zoom_scale.pack(fill=tk.X, padx=5)
+        
+        # 旋转控制
+        ttk.Label(process_frame, text="旋转:").pack()
+        self.rotation_scale = ttk.Scale(process_frame, from_=0, to=360, orient=tk.HORIZONTAL,
+                                      command=self.update_rotation)
+        self.rotation_scale.set(0)
+        self.rotation_scale.pack(fill=tk.X, padx=5)
+        
+        # 亮度控制
+        ttk.Label(process_frame, text="亮度:").pack()
+        self.brightness_scale = ttk.Scale(process_frame, from_=0.1, to=2.0, orient=tk.HORIZONTAL,
+                                        command=self.update_brightness)
+        self.brightness_scale.set(1.0)
+        self.brightness_scale.pack(fill=tk.X, padx=5)
+        
+        # 重置按钮
+        ttk.Button(process_frame, text="重置图片 (R)", command=self.reset_image).pack(pady=5)
+        
+        # 右侧显示区域
+        display_frame = ttk.Frame(main_frame)
+        display_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)
+        
+        # 上方:标注图片显示
+        original_frame = ttk.LabelFrame(display_frame, text="标注图片(带检测框)")
+        original_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 5))
+        
+        self.original_label = ttk.Label(original_frame)
+        self.original_label.pack(fill=tk.BOTH, expand=True)
+        
+        # 下方:目标切片显示(仅作参考)
+        targets_frame = ttk.LabelFrame(display_frame, text="检测到的目标(仅供参考)")
+        targets_frame.pack(fill=tk.X, pady=(5, 0))
+        
+        # 创建可滚动的目标显示区域
+        targets_canvas = tk.Canvas(targets_frame, height=200)
+        targets_scrollbar = ttk.Scrollbar(targets_frame, orient="horizontal", command=targets_canvas.xview)
+        self.targets_scroll_frame = ttk.Frame(targets_canvas)
+        
+        self.targets_scroll_frame.bind(
+            "<Configure>",
+            lambda e: targets_canvas.configure(scrollregion=targets_canvas.bbox("all"))
+        )
+        
+        targets_canvas.create_window((0, 0), window=self.targets_scroll_frame, anchor="nw")
+        targets_canvas.configure(xscrollcommand=targets_scrollbar.set)
+        
+        targets_canvas.pack(side="top", fill="both", expand=True)
+        targets_scrollbar.pack(side="bottom", fill="x")
+        
+        # 状态标签
+        self.status_label = ttk.Label(self.root, text="请选择包含original、targets、annotated文件夹的工作目录(显示标注图片,分类原图)")
+        self.status_label.pack(side=tk.BOTTOM, pady=5)
+        
+        # 图片信息标签
+        self.info_label = ttk.Label(self.root, text="")
+        self.info_label.pack(side=tk.BOTTOM, pady=5)
+        
+    def setup_keyboard_shortcuts(self):
+        self.root.bind('<F5>', lambda e: self.select_work_directory())
+        self.root.bind('<Left>', lambda e: self.prev_image())
+        self.root.bind('<Right>', lambda e: self.next_image())
+        self.root.bind('1', lambda e: self.classify_image("drone"))
+        self.root.bind('2', lambda e: self.classify_image("bird"))
+        self.root.bind('3', lambda e: self.classify_image("manned"))
+        self.root.bind('4', lambda e: self.classify_image("other"))
+        self.root.bind('5', lambda e: self.classify_image("none"))
+        self.root.bind('6', lambda e: self.undo_operation())
+        self.root.bind('r', lambda e: self.reset_image())
+        
+    def update_zoom(self, value):
+        self.zoom_factor = float(value)
+        self.show_current_image()
+        
+    def update_rotation(self, value):
+        self.rotation_angle = float(value)
+        self.show_current_image()
+        
+    def update_brightness(self, value):
+        self.brightness_factor = float(value)
+        self.show_current_image()
+        
+    def reset_image(self):
+        self.zoom_factor = 1.0
+        self.rotation_angle = 0
+        self.brightness_factor = 1.0
+        self.zoom_scale.set(1.0)
+        self.rotation_scale.set(0)
+        self.brightness_scale.set(1.0)
+        self.show_current_image()
+        
+    def select_work_directory(self):
+        """选择包含original、targets、annotated三个文件夹的工作目录"""
+        folder_path = filedialog.askdirectory(title="选择包含original、targets、annotated文件夹的工作目录")
+        if folder_path:
+            # 检查是否包含必要的子文件夹
+            original_path = os.path.join(folder_path, 'original')
+            targets_path = os.path.join(folder_path, 'targets')
+            annotated_path = os.path.join(folder_path, 'annotated')
+            
+            if not all(os.path.exists(path) for path in [original_path, targets_path, annotated_path]):
+                messagebox.showerror("错误", "所选目录必须包含original、targets、annotated三个文件夹")
+                return
+                
+            self.work_directory = folder_path
+            self.original_dir = original_path
+            self.targets_dir = targets_path
+            self.annotated_dir = annotated_path
+            
+            self.load_images()
+            
+    def load_images(self):
+        """加载original文件夹中的图片"""
+        self.image_list = []
+        
+        if not self.original_dir:
+            return
+            
+        for file in os.listdir(self.original_dir):
+            if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
+                self.image_list.append(os.path.join(self.original_dir, file))
+        
+        if self.image_list:
+            self.current_index = 0
+            self.show_current_image()
+            self.status_label.config(text=f"已加载 {len(self.image_list)} 张图片")
+        else:
+            messagebox.showinfo("提示", "original文件夹中没有图片")
+            
+    def find_target_images(self, original_filename):
+        """查找与原图对应的目标切片"""
+        if not self.targets_dir:
+            return []
+            
+        base_name = os.path.splitext(original_filename)[0]
+        target_files = []
+        
+        for file in os.listdir(self.targets_dir):
+            if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
+                if file.startswith(base_name + '_'):
+                    target_files.append(os.path.join(self.targets_dir, file))
+                    
+        return sorted(target_files)
+        
+    def find_annotated_image(self, original_filename):
+        """查找与原图对应的标注图片(带检测框的图片)"""
+        if not self.annotated_dir:
+            return None
+            
+        base_name = os.path.splitext(original_filename)[0]
+        
+        # 尝试多种可能的文件扩展名
+        for ext in ['.png', '.jpg', '.jpeg', '.bmp', '.gif']:
+            annotated_path = os.path.join(self.annotated_dir, base_name + ext)
+            if os.path.exists(annotated_path):
+                return annotated_path
+                
+        return None
+        
+    def show_current_image(self):
+        """显示当前图片和对应的目标切片"""
+        if 0 <= self.current_index < len(self.image_list):
+            self.current_original_path = self.image_list[self.current_index]
+            
+            # 显示标注图片(带检测框)
+            self.display_original_image()
+            
+            # 显示目标切片
+            original_filename = os.path.basename(self.current_original_path)
+            self.current_targets = self.find_target_images(original_filename)
+            self.display_target_images()
+            
+            # 更新状态信息
+            self.status_label.config(text=f"图片 {self.current_index + 1}/{len(self.image_list)} | 检测到 {len(self.current_targets)} 个目标")
+            
+            # 显示图片信息
+            if os.path.exists(self.current_original_path):
+                file_size = os.path.getsize(self.current_original_path) / 1024
+                image = Image.open(self.current_original_path)
+                width, height = image.size
+                self.info_label.config(text=f"文件名: {original_filename} | 尺寸: {width}x{height} | 大小: {file_size:.1f}KB")
+            
+    def display_original_image(self):
+        """显示标注图片(带检测框),如果没有标注图片则显示原图"""
+        if not self.current_original_path or not os.path.exists(self.current_original_path):
+            self.original_label.config(image='', text="图片不存在")
+            return
+            
+        try:
+            # 首先尝试查找标注图片
+            original_filename = os.path.basename(self.current_original_path)
+            annotated_path = self.find_annotated_image(original_filename)
+            
+            # 如果有标注图片就显示标注图片,否则显示原图
+            if annotated_path and os.path.exists(annotated_path):
+                image = Image.open(annotated_path)
+                display_text = "(显示标注图片)"
+            else:
+                image = Image.open(self.current_original_path)
+                display_text = "(无标注图片,显示原图)"
+            
+            # 应用亮度调整
+            enhancer = ImageEnhance.Brightness(image)
+            image = enhancer.enhance(self.brightness_factor)
+            
+            # 应用旋转
+            image = image.rotate(self.rotation_angle, expand=True)
+            
+            # 计算调整后的尺寸,保持宽高比
+            width, height = image.size
+            max_size = (600, 400)
+            ratio = min(max_size[0]/width, max_size[1]/height)
+            new_size = (int(width*ratio*self.zoom_factor), int(height*ratio*self.zoom_factor))
+            image = image.resize(new_size, Image.Resampling.LANCZOS)
+            
+            # 转换为PhotoImage
+            photo = ImageTk.PhotoImage(image)
+            
+            # 更新图片显示
+            self.original_label.config(image=photo, text="")
+            self.original_label.image = photo  # 保持引用
+            
+            # 更新框架标题以显示当前显示的图片类型
+            parent_frame = self.original_label.master
+            if hasattr(parent_frame, 'config'):
+                parent_frame.config(text=f"标注图片(带检测框){display_text}")
+            
+        except Exception as e:
+            self.original_label.config(image='', text=f"图片加载失败: {str(e)}")
+            
+    def display_target_images(self):
+        """显示目标切片"""
+        # 清空之前的目标显示
+        for widget in self.targets_scroll_frame.winfo_children():
+            widget.destroy()
+            
+        if not self.current_targets:
+            no_target_label = ttk.Label(self.targets_scroll_frame, text="未检测到目标")
+            no_target_label.pack(side=tk.LEFT, padx=10)
+            return
+            
+        # 显示每个目标切片
+        for i, target_path in enumerate(self.current_targets):
+            try:
+                # 加载目标图片
+                target_image = Image.open(target_path)
+                
+                # 调整大小以适应显示
+                target_image.thumbnail((150, 150), Image.Resampling.LANCZOS)
+                target_photo = ImageTk.PhotoImage(target_image)
+                
+                # 创建目标显示框架
+                target_frame = ttk.Frame(self.targets_scroll_frame)
+                target_frame.pack(side=tk.LEFT, padx=5, pady=5)
+                
+                # 显示目标图片
+                target_label = ttk.Label(target_frame, image=target_photo)
+                target_label.image = target_photo  # 保持引用
+                target_label.pack()
+                
+                # 显示目标文件名
+                filename_label = ttk.Label(target_frame, text=os.path.basename(target_path), 
+                                         font=('Arial', 8))
+                filename_label.pack()
+                
+            except Exception as e:
+                error_label = ttk.Label(self.targets_scroll_frame, text=f"目标{i+1}加载失败")
+                error_label.pack(side=tk.LEFT, padx=5)
+                
+    def next_image(self):
+        if self.image_list:
+            self.current_index = (self.current_index + 1) % len(self.image_list)
+            self.show_current_image()
+            
+    def prev_image(self):
+        if self.image_list:
+            self.current_index = (self.current_index - 1) % len(self.image_list)
+            self.show_current_image()
+            
+    def classify_image(self, category):
+        """分类当前图片(仅移动原图)"""
+        if not self.current_original_path:
+            return
+            
+        # 创建分类文件夹
+        category_folders = {
+            "drone": "drone",
+            "bird": "bird",
+            "manned": "manned",
+            "other": "others",
+            "none": "none"
+        }
+        
+        target_folder = os.path.join(self.work_directory, category_folders[category])
+        os.makedirs(target_folder, exist_ok=True)
+        
+        # 只移动原图
+        original_filename = os.path.basename(self.current_original_path)
+        target_path = os.path.join(target_folder, original_filename)
+        
+        try:
+            # 记录操作历史(在移动之前)
+            operation_record = {
+                'action': 'classify',
+                'category': category,
+                'original_source': self.current_original_path,
+                'original_target': target_path,
+                'index': self.current_index,
+                'original_filename': original_filename
+            }
+            
+            # 执行文件移动(仅移动原图)
+            if os.path.exists(self.current_original_path):
+                shutil.move(self.current_original_path, target_path)
+                operation_record['moved'] = True
+            else:
+                operation_record['moved'] = False
+                
+            # 添加到操作历史
+            self.operation_history.append(operation_record)
+            
+            # 限制历史记录数量
+            if len(self.operation_history) > 50:
+                self.operation_history.pop(0)
+                
+            # 从列表中移除当前图片
+            self.image_list.pop(self.current_index)
+            
+            if self.image_list:
+                self.current_index = self.current_index % len(self.image_list)
+                self.show_current_image()
+            else:
+                self.original_label.config(image='', text="所有图片已分类完成")
+                # 清空目标显示
+                for widget in self.targets_scroll_frame.winfo_children():
+                    widget.destroy()
+                self.status_label.config(text="所有图片已分类完成")
+                self.info_label.config(text="")
+                
+        except Exception as e:
+            messagebox.showerror("错误", f"分类操作失败:{str(e)}")
+            
+    def undo_operation(self):
+        """撤回最后一次分类操作"""
+        if not self.operation_history:
+            messagebox.showinfo("提示", "没有可撤回的操作")
+            return
+            
+        # 获取最后一次操作
+        last_operation = self.operation_history.pop()
+        
+        if last_operation['action'] == 'classify' and last_operation.get('moved', False):
+            try:
+                # 将原图移回原位置
+                if os.path.exists(last_operation['original_target']):
+                    shutil.move(last_operation['original_target'], last_operation['original_source'])
+                        
+                # 重新加载图片列表
+                self.load_images()
+                
+                # 尝试定位到撤回的图片
+                original_path = last_operation['original_source']
+                        
+                if original_path and original_path in self.image_list:
+                    self.current_index = self.image_list.index(original_path)
+                else:
+                    self.current_index = 0
+                    
+                if self.image_list:
+                    self.show_current_image()
+                    messagebox.showinfo("成功", f"已撤回对 '{last_operation['original_filename']}' 的分类操作")
+                else:
+                    messagebox.showinfo("提示", "撤回成功,但当前目录中没有图片")
+                    
+            except Exception as e:
+                messagebox.showerror("错误", f"撤回操作失败:{str(e)}")
+                # 如果撤回失败,将操作重新加入历史记录
+                self.operation_history.append(last_operation)
+
+if __name__ == "__main__":
+    root = tk.Tk()
+    app = TargetBasedClassifier(root)
+    root.mainloop()