|
@@ -0,0 +1,679 @@
|
|
|
+#!/usr/bin/env python3
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+"""
|
|
|
+模型性能评估工具
|
|
|
+用于评估UAV和鸟类检测模型的各项性能指标
|
|
|
+包括:误报率、漏报率、分类准确率、推理速度等
|
|
|
+
|
|
|
+作者: AI Assistant
|
|
|
+创建时间: 2024
|
|
|
+"""
|
|
|
+
|
|
|
+import os
|
|
|
+import cv2
|
|
|
+import numpy as np
|
|
|
+import json
|
|
|
+import time
|
|
|
+import random
|
|
|
+from datetime import datetime
|
|
|
+from typing import List, Dict, Tuple, Optional
|
|
|
+from glob import glob
|
|
|
+from tqdm import tqdm
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import seaborn as sns
|
|
|
+from collections import defaultdict
|
|
|
+import pandas as pd
|
|
|
+
|
|
|
+# 导入现有的推理模块
|
|
|
+from src.core.inference import ONNXDetector
|
|
|
+
|
|
|
+class ModelEvaluator:
|
|
|
+ """
|
|
|
+ 模型性能评估器
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, model_path: str = None, model_type: str = 'uav_and_bird',
|
|
|
+ threshold: float = 0.5, iou_threshold: float = 0.5):
|
|
|
+ """
|
|
|
+ 初始化评估器
|
|
|
+
|
|
|
+ Args:
|
|
|
+ model_path: 模型文件路径
|
|
|
+ model_type: 模型类型
|
|
|
+ threshold: 置信度阈值
|
|
|
+ iou_threshold: IoU阈值
|
|
|
+ """
|
|
|
+ self.model_type = model_type
|
|
|
+ self.threshold = threshold
|
|
|
+ self.iou_threshold = iou_threshold
|
|
|
+
|
|
|
+ # 初始化检测器
|
|
|
+ self.detector = ONNXDetector(
|
|
|
+ threshold=threshold,
|
|
|
+ model_type=model_type,
|
|
|
+ model_path=model_path,
|
|
|
+ save_empty=False
|
|
|
+ )
|
|
|
+
|
|
|
+ # 评估结果存储
|
|
|
+ self.evaluation_results = {
|
|
|
+ 'true_positives': 0,
|
|
|
+ 'false_positives': 0,
|
|
|
+ 'false_negatives': 0,
|
|
|
+ 'true_negatives': 0,
|
|
|
+ 'inference_times': [],
|
|
|
+ 'detailed_results': [],
|
|
|
+ 'class_metrics': defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})
|
|
|
+ }
|
|
|
+
|
|
|
+ # 类别映射
|
|
|
+ if model_type == 'uav_and_bird':
|
|
|
+ self.class_names = {0: 'Bird', 1: 'Drone'}
|
|
|
+ self.class_mapping = {'bird': 0, 'drone': 1, 'none': -1}
|
|
|
+ else:
|
|
|
+ self.class_names = {0: 'Drone'}
|
|
|
+ self.class_mapping = {'drone': 0, 'none': -1}
|
|
|
+
|
|
|
+ def load_ground_truth_from_folder_structure(self, data_dir: str) -> List[Dict]:
|
|
|
+ """
|
|
|
+ 从文件夹结构加载ground truth
|
|
|
+ 假设文件夹结构为:
|
|
|
+ data_dir/
|
|
|
+ ├── bird/
|
|
|
+ ├── drone/
|
|
|
+ └── none/
|
|
|
+
|
|
|
+ Args:
|
|
|
+ data_dir: 数据目录路径
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 包含图像路径和标签的字典列表
|
|
|
+ """
|
|
|
+ ground_truth = []
|
|
|
+
|
|
|
+ # 支持的图像格式
|
|
|
+ image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']
|
|
|
+
|
|
|
+ # 遍历各个类别文件夹
|
|
|
+ for class_folder in ['bird', 'drone', 'none']:
|
|
|
+ folder_path = os.path.join(data_dir, class_folder)
|
|
|
+ if not os.path.exists(folder_path):
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 获取该文件夹下的所有图像
|
|
|
+ images = set() # 使用集合避免重复计算
|
|
|
+ for ext in image_extensions:
|
|
|
+ # 搜索小写扩展名
|
|
|
+ images.update(glob(os.path.join(folder_path, ext)))
|
|
|
+ # 搜索大写扩展名
|
|
|
+ images.update(glob(os.path.join(folder_path, ext.upper())))
|
|
|
+
|
|
|
+ # 确定类别标签
|
|
|
+ if class_folder == 'bird':
|
|
|
+ label = 'bird'
|
|
|
+ elif class_folder == 'drone':
|
|
|
+ label = 'drone'
|
|
|
+ else:
|
|
|
+ label = 'none'
|
|
|
+
|
|
|
+ # 添加到ground truth
|
|
|
+ for img_path in images:
|
|
|
+ ground_truth.append({
|
|
|
+ 'image_path': img_path,
|
|
|
+ 'label': label,
|
|
|
+ 'class_id': self.class_mapping.get(label, -1),
|
|
|
+ 'has_target': label != 'none'
|
|
|
+ })
|
|
|
+
|
|
|
+ return ground_truth
|
|
|
+
|
|
|
+ def load_ground_truth_from_annotation(self, annotation_file: str) -> List[Dict]:
|
|
|
+ """
|
|
|
+ 从标注文件加载ground truth
|
|
|
+ 支持COCO格式或自定义JSON格式
|
|
|
+
|
|
|
+ Args:
|
|
|
+ annotation_file: 标注文件路径
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 包含图像路径和标签的字典列表
|
|
|
+ """
|
|
|
+ with open(annotation_file, 'r', encoding='utf-8') as f:
|
|
|
+ annotations = json.load(f)
|
|
|
+
|
|
|
+ ground_truth = []
|
|
|
+
|
|
|
+ # 处理自定义格式
|
|
|
+ if 'images' in annotations:
|
|
|
+ for img_info in annotations['images']:
|
|
|
+ ground_truth.append({
|
|
|
+ 'image_path': img_info['path'],
|
|
|
+ 'label': img_info.get('label', 'none'),
|
|
|
+ 'class_id': self.class_mapping.get(img_info.get('label', 'none'), -1),
|
|
|
+ 'has_target': img_info.get('has_target', False),
|
|
|
+ 'bboxes': img_info.get('bboxes', [])
|
|
|
+ })
|
|
|
+
|
|
|
+ return ground_truth
|
|
|
+
|
|
|
+ def calculate_iou(self, box1: List[float], box2: List[float]) -> float:
|
|
|
+ """
|
|
|
+ 计算两个边界框的IoU
|
|
|
+
|
|
|
+ Args:
|
|
|
+ box1: [x1, y1, x2, y2]
|
|
|
+ box2: [x1, y1, x2, y2]
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ IoU值
|
|
|
+ """
|
|
|
+ x1 = max(box1[0], box2[0])
|
|
|
+ y1 = max(box1[1], box2[1])
|
|
|
+ x2 = min(box1[2], box2[2])
|
|
|
+ y2 = min(box1[3], box2[3])
|
|
|
+
|
|
|
+ if x2 <= x1 or y2 <= y1:
|
|
|
+ return 0.0
|
|
|
+
|
|
|
+ intersection = (x2 - x1) * (y2 - y1)
|
|
|
+ area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
|
|
+ area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
|
|
+ union = area1 + area2 - intersection
|
|
|
+
|
|
|
+ return intersection / union if union > 0 else 0.0
|
|
|
+
|
|
|
+ def evaluate_single_image(self, image_info: Dict) -> Dict:
|
|
|
+ """
|
|
|
+ 评估单张图像
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image_info: 包含图像路径和ground truth的字典
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 评估结果字典
|
|
|
+ """
|
|
|
+ image_path = image_info['image_path']
|
|
|
+ gt_label = image_info['label']
|
|
|
+ gt_class_id = image_info['class_id']
|
|
|
+ gt_has_target = image_info['has_target']
|
|
|
+ gt_bboxes = image_info.get('bboxes', [])
|
|
|
+
|
|
|
+ # 记录推理时间
|
|
|
+ start_time = time.time()
|
|
|
+
|
|
|
+ # 进行推理
|
|
|
+ try:
|
|
|
+ # 使用detector的方法进行推理
|
|
|
+ input_data = self.detector.preprocess(image_path)
|
|
|
+ if input_data is None:
|
|
|
+ raise Exception("图像预处理失败")
|
|
|
+
|
|
|
+ # 执行推理
|
|
|
+ detections_raw = self.detector.inference(input_data)
|
|
|
+
|
|
|
+ # 后处理获取检测结果
|
|
|
+ image_orig = cv2.imread(image_path)
|
|
|
+ detection_count, processed_img, detections_list = self.detector.postprocess(
|
|
|
+ detections_raw, image_orig, image_path
|
|
|
+ )
|
|
|
+
|
|
|
+ inference_time = time.time() - start_time
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"推理失败: {image_path}, 错误: {e}")
|
|
|
+ return {
|
|
|
+ 'image_path': image_path,
|
|
|
+ 'gt_label': gt_label,
|
|
|
+ 'prediction': 'error',
|
|
|
+ 'inference_time': 0,
|
|
|
+ 'tp': 0, 'fp': 0, 'fn': 0, 'tn': 0
|
|
|
+ }
|
|
|
+
|
|
|
+ # 分析结果
|
|
|
+ has_detection = detection_count > 0
|
|
|
+
|
|
|
+ # 计算混淆矩阵元素
|
|
|
+ tp = fp = fn = tn = 0
|
|
|
+
|
|
|
+ if gt_has_target and has_detection:
|
|
|
+ tp = 1 # 有目标且检测到
|
|
|
+ elif gt_has_target and not has_detection:
|
|
|
+ fn = 1 # 有目标但未检测到(漏报)
|
|
|
+ elif not gt_has_target and has_detection:
|
|
|
+ fp = 1 # 无目标但检测到(误报)
|
|
|
+ elif not gt_has_target and not has_detection:
|
|
|
+ tn = 1 # 无目标且未检测到
|
|
|
+
|
|
|
+ # 详细结果
|
|
|
+ result = {
|
|
|
+ 'image_path': image_path,
|
|
|
+ 'gt_label': gt_label,
|
|
|
+ 'gt_has_target': gt_has_target,
|
|
|
+ 'detection_count': detection_count,
|
|
|
+ 'has_detection': has_detection,
|
|
|
+ 'inference_time': inference_time,
|
|
|
+ 'tp': tp, 'fp': fp, 'fn': fn, 'tn': tn,
|
|
|
+ 'detections': detections_list if 'detections_list' in locals() else []
|
|
|
+ }
|
|
|
+
|
|
|
+ return result
|
|
|
+
|
|
|
+ def balanced_sampling(self, ground_truth: List[Dict], total_samples: int,
|
|
|
+ bird_ratio: float = 0.2, drone_ratio: float = 0.2,
|
|
|
+ none_ratio: float = 0.6) -> List[Dict]:
|
|
|
+ """
|
|
|
+ 按比例平衡采样
|
|
|
+
|
|
|
+ Args:
|
|
|
+ ground_truth: 原始ground truth列表
|
|
|
+ total_samples: 总采样数量
|
|
|
+ bird_ratio: 鸟类样本比例
|
|
|
+ drone_ratio: 无人机样本比例
|
|
|
+ none_ratio: 无目标样本比例
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 平衡采样后的ground truth列表
|
|
|
+ """
|
|
|
+ # 按类别分组
|
|
|
+ birds = [item for item in ground_truth if item['label'] == 'bird']
|
|
|
+ drones = [item for item in ground_truth if item['label'] == 'drone']
|
|
|
+ nones = [item for item in ground_truth if item['label'] == 'none']
|
|
|
+
|
|
|
+ # 计算各类别需要的样本数
|
|
|
+ bird_count = int(total_samples * bird_ratio)
|
|
|
+ drone_count = int(total_samples * drone_ratio)
|
|
|
+ none_count = int(total_samples * none_ratio)
|
|
|
+
|
|
|
+ # 确保不超过实际可用样本数
|
|
|
+ bird_count = min(bird_count, len(birds))
|
|
|
+ drone_count = min(drone_count, len(drones))
|
|
|
+ none_count = min(none_count, len(nones))
|
|
|
+
|
|
|
+ print(f"平衡采样策略:")
|
|
|
+ print(f" 鸟类: {bird_count}/{len(birds)} ({bird_ratio*100:.1f}%)")
|
|
|
+ print(f" 无人机: {drone_count}/{len(drones)} ({drone_ratio*100:.1f}%)")
|
|
|
+ print(f" 无目标: {none_count}/{len(nones)} ({none_ratio*100:.1f}%)")
|
|
|
+ print(f" 总计: {bird_count + drone_count + none_count} 张图像")
|
|
|
+
|
|
|
+ # 随机采样各类别
|
|
|
+ sampled_data = []
|
|
|
+ if bird_count > 0:
|
|
|
+ sampled_data.extend(random.sample(birds, bird_count))
|
|
|
+ if drone_count > 0:
|
|
|
+ sampled_data.extend(random.sample(drones, drone_count))
|
|
|
+ if none_count > 0:
|
|
|
+ sampled_data.extend(random.sample(nones, none_count))
|
|
|
+
|
|
|
+ # 打乱顺序
|
|
|
+ random.shuffle(sampled_data)
|
|
|
+
|
|
|
+ return sampled_data
|
|
|
+
|
|
|
+ def evaluate_dataset(self, ground_truth: List[Dict], sample_size: Optional[int] = None,
|
|
|
+ balanced_sampling: bool = False, bird_ratio: float = 0.2,
|
|
|
+ drone_ratio: float = 0.2, none_ratio: float = 0.6) -> Dict:
|
|
|
+ """
|
|
|
+ 评估整个数据集
|
|
|
+
|
|
|
+ Args:
|
|
|
+ ground_truth: ground truth列表
|
|
|
+ sample_size: 采样大小,None表示使用全部数据
|
|
|
+ balanced_sampling: 是否使用平衡采样
|
|
|
+ bird_ratio: 鸟类样本比例(仅在balanced_sampling=True时有效)
|
|
|
+ drone_ratio: 无人机样本比例(仅在balanced_sampling=True时有效)
|
|
|
+ none_ratio: 无目标样本比例(仅在balanced_sampling=True时有效)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 评估结果字典
|
|
|
+ """
|
|
|
+ # 采样策略
|
|
|
+ if sample_size and sample_size < len(ground_truth):
|
|
|
+ if balanced_sampling:
|
|
|
+ # 平衡采样
|
|
|
+ ground_truth = self.balanced_sampling(
|
|
|
+ ground_truth, sample_size, bird_ratio, drone_ratio, none_ratio
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ # 随机采样
|
|
|
+ ground_truth = random.sample(ground_truth, sample_size)
|
|
|
+ print(f"随机采样 {sample_size} 张图像进行评估")
|
|
|
+ elif balanced_sampling:
|
|
|
+ print("警告: 启用了平衡采样但未指定sample_size,将使用全部数据")
|
|
|
+
|
|
|
+ print(f"开始评估 {len(ground_truth)} 张图像...")
|
|
|
+
|
|
|
+ # 重置结果
|
|
|
+ self.evaluation_results = {
|
|
|
+ 'true_positives': 0,
|
|
|
+ 'false_positives': 0,
|
|
|
+ 'false_negatives': 0,
|
|
|
+ 'true_negatives': 0,
|
|
|
+ 'inference_times': [],
|
|
|
+ 'detailed_results': [],
|
|
|
+ 'class_metrics': defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})
|
|
|
+ }
|
|
|
+
|
|
|
+ # 逐张评估
|
|
|
+ for img_info in tqdm(ground_truth, desc="评估进度"):
|
|
|
+ result = self.evaluate_single_image(img_info)
|
|
|
+
|
|
|
+ # 累计统计
|
|
|
+ self.evaluation_results['true_positives'] += result['tp']
|
|
|
+ self.evaluation_results['false_positives'] += result['fp']
|
|
|
+ self.evaluation_results['false_negatives'] += result['fn']
|
|
|
+ self.evaluation_results['true_negatives'] += result['tn']
|
|
|
+ self.evaluation_results['inference_times'].append(result['inference_time'])
|
|
|
+ self.evaluation_results['detailed_results'].append(result)
|
|
|
+
|
|
|
+ # 按类别统计
|
|
|
+ gt_label = result['gt_label']
|
|
|
+ # 统计所有类别的指标,包括none类别
|
|
|
+ self.evaluation_results['class_metrics'][gt_label]['tp'] += result['tp']
|
|
|
+ self.evaluation_results['class_metrics'][gt_label]['fp'] += result['fp']
|
|
|
+ self.evaluation_results['class_metrics'][gt_label]['fn'] += result['fn']
|
|
|
+
|
|
|
+ return self.calculate_metrics()
|
|
|
+
|
|
|
+ def calculate_metrics(self) -> Dict:
|
|
|
+ """
|
|
|
+ 计算各种性能指标
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 性能指标字典
|
|
|
+ """
|
|
|
+ tp = self.evaluation_results['true_positives']
|
|
|
+ fp = self.evaluation_results['false_positives']
|
|
|
+ fn = self.evaluation_results['false_negatives']
|
|
|
+ tn = self.evaluation_results['true_negatives']
|
|
|
+
|
|
|
+ # 基本指标
|
|
|
+ precision = tp / (tp + fp) if (tp + fp) > 0 else 0
|
|
|
+ recall = tp / (tp + fn) if (tp + fn) > 0 else 0
|
|
|
+ f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
|
|
+ accuracy = (tp + tn) / (tp + fp + fn + tn) if (tp + fp + fn + tn) > 0 else 0
|
|
|
+
|
|
|
+ # 误报率和漏报率
|
|
|
+ false_positive_rate = fp / (fp + tn) if (fp + tn) > 0 else 0
|
|
|
+ false_negative_rate = fn / (fn + tp) if (fn + tp) > 0 else 0
|
|
|
+
|
|
|
+ # 推理速度统计
|
|
|
+ inference_times = self.evaluation_results['inference_times']
|
|
|
+ avg_inference_time = np.mean(inference_times) if inference_times else 0
|
|
|
+ fps = 1 / avg_inference_time if avg_inference_time > 0 else 0
|
|
|
+
|
|
|
+ metrics = {
|
|
|
+ 'confusion_matrix': {
|
|
|
+ 'true_positives': tp,
|
|
|
+ 'false_positives': fp,
|
|
|
+ 'false_negatives': fn,
|
|
|
+ 'true_negatives': tn
|
|
|
+ },
|
|
|
+ 'performance_metrics': {
|
|
|
+ 'precision': precision,
|
|
|
+ 'recall': recall,
|
|
|
+ 'f1_score': f1_score,
|
|
|
+ 'accuracy': accuracy,
|
|
|
+ 'false_positive_rate': false_positive_rate,
|
|
|
+ 'false_negative_rate': false_negative_rate
|
|
|
+ },
|
|
|
+ 'speed_metrics': {
|
|
|
+ 'avg_inference_time': avg_inference_time,
|
|
|
+ 'fps': fps,
|
|
|
+ 'total_images': len(inference_times),
|
|
|
+ 'total_time': sum(inference_times)
|
|
|
+ },
|
|
|
+ 'class_metrics': {}
|
|
|
+ }
|
|
|
+
|
|
|
+ # 按类别计算指标
|
|
|
+ for class_name, class_stats in self.evaluation_results['class_metrics'].items():
|
|
|
+ class_tp = class_stats['tp']
|
|
|
+ class_fp = class_stats['fp']
|
|
|
+ class_fn = class_stats['fn']
|
|
|
+
|
|
|
+ class_precision = class_tp / (class_tp + class_fp) if (class_tp + class_fp) > 0 else 0
|
|
|
+ class_recall = class_tp / (class_tp + class_fn) if (class_tp + class_fn) > 0 else 0
|
|
|
+ class_f1 = 2 * (class_precision * class_recall) / (class_precision + class_recall) if (class_precision + class_recall) > 0 else 0
|
|
|
+
|
|
|
+ metrics['class_metrics'][class_name] = {
|
|
|
+ 'precision': class_precision,
|
|
|
+ 'recall': class_recall,
|
|
|
+ 'f1_score': class_f1,
|
|
|
+ 'tp': class_tp,
|
|
|
+ 'fp': class_fp,
|
|
|
+ 'fn': class_fn
|
|
|
+ }
|
|
|
+
|
|
|
+ return metrics
|
|
|
+
|
|
|
+ def generate_report(self, metrics: Dict, output_dir: str = None) -> str:
|
|
|
+ """
|
|
|
+ 生成评估报告
|
|
|
+
|
|
|
+ Args:
|
|
|
+ metrics: 性能指标字典
|
|
|
+ output_dir: 输出目录
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 报告文件路径
|
|
|
+ """
|
|
|
+ if output_dir is None:
|
|
|
+ output_dir = os.path.join(os.path.dirname(__file__), 'evaluation_results')
|
|
|
+ os.makedirs(output_dir, exist_ok=True)
|
|
|
+
|
|
|
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
|
|
+ report_file = os.path.join(output_dir, f'evaluation_report_{timestamp}.txt')
|
|
|
+
|
|
|
+ with open(report_file, 'w', encoding='utf-8') as f:
|
|
|
+ f.write("=" * 60 + "\n")
|
|
|
+ f.write("模型性能评估报告\n")
|
|
|
+ f.write("=" * 60 + "\n")
|
|
|
+ f.write(f"评估时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
|
|
+ f.write(f"模型类型: {self.model_type}\n")
|
|
|
+ f.write(f"置信度阈值: {self.threshold}\n")
|
|
|
+ f.write(f"IoU阈值: {self.iou_threshold}\n\n")
|
|
|
+
|
|
|
+ # 混淆矩阵
|
|
|
+ cm = metrics['confusion_matrix']
|
|
|
+ f.write("混淆矩阵:\n")
|
|
|
+ f.write("-" * 30 + "\n")
|
|
|
+ f.write(f"真正例 (TP): {cm['true_positives']}\n")
|
|
|
+ f.write(f"假正例 (FP): {cm['false_positives']}\n")
|
|
|
+ f.write(f"假负例 (FN): {cm['false_negatives']}\n")
|
|
|
+ f.write(f"真负例 (TN): {cm['true_negatives']}\n\n")
|
|
|
+
|
|
|
+ # 性能指标
|
|
|
+ pm = metrics['performance_metrics']
|
|
|
+ f.write("性能指标:\n")
|
|
|
+ f.write("-" * 30 + "\n")
|
|
|
+ f.write(f"精确率 (Precision): {pm['precision']:.4f}\n")
|
|
|
+ f.write(f"召回率 (Recall): {pm['recall']:.4f}\n")
|
|
|
+ f.write(f"F1分数: {pm['f1_score']:.4f}\n")
|
|
|
+ f.write(f"准确率 (Accuracy): {pm['accuracy']:.4f}\n")
|
|
|
+ f.write(f"误报率 (FPR): {pm['false_positive_rate']:.4f}\n")
|
|
|
+ f.write(f"漏报率 (FNR): {pm['false_negative_rate']:.4f}\n\n")
|
|
|
+
|
|
|
+ # 速度指标
|
|
|
+ sm = metrics['speed_metrics']
|
|
|
+ f.write("速度指标:\n")
|
|
|
+ f.write("-" * 30 + "\n")
|
|
|
+ f.write(f"平均推理时间: {sm['avg_inference_time']:.4f} 秒\n")
|
|
|
+ f.write(f"推理速度 (FPS): {sm['fps']:.2f}\n")
|
|
|
+ f.write(f"总图像数: {sm['total_images']}\n")
|
|
|
+ f.write(f"总耗时: {sm['total_time']:.2f} 秒\n\n")
|
|
|
+
|
|
|
+ # 按类别指标
|
|
|
+ if metrics['class_metrics']:
|
|
|
+ f.write("按类别性能指标:\n")
|
|
|
+ f.write("-" * 30 + "\n")
|
|
|
+ for class_name, class_metrics in metrics['class_metrics'].items():
|
|
|
+ f.write(f"\n{class_name.upper()}类别:\n")
|
|
|
+ f.write(f" 精确率: {class_metrics['precision']:.4f}\n")
|
|
|
+ f.write(f" 召回率: {class_metrics['recall']:.4f}\n")
|
|
|
+ f.write(f" F1分数: {class_metrics['f1_score']:.4f}\n")
|
|
|
+ f.write(f" TP: {class_metrics['tp']}, FP: {class_metrics['fp']}, FN: {class_metrics['fn']}\n")
|
|
|
+
|
|
|
+ print(f"评估报告已保存到: {report_file}")
|
|
|
+ return report_file
|
|
|
+
|
|
|
+ def plot_results(self, metrics: Dict, output_dir: str = None):
|
|
|
+ """
|
|
|
+ 绘制评估结果图表
|
|
|
+
|
|
|
+ Args:
|
|
|
+ metrics: 性能指标字典
|
|
|
+ output_dir: 输出目录
|
|
|
+ """
|
|
|
+ if output_dir is None:
|
|
|
+ output_dir = os.path.join(os.path.dirname(__file__), 'evaluation_results')
|
|
|
+ os.makedirs(output_dir, exist_ok=True)
|
|
|
+
|
|
|
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
|
|
+
|
|
|
+ # 设置中文字体
|
|
|
+ plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
|
|
|
+ plt.rcParams['axes.unicode_minus'] = False
|
|
|
+
|
|
|
+ # 1. 混淆矩阵热力图
|
|
|
+ fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
|
|
|
+
|
|
|
+ cm = metrics['confusion_matrix']
|
|
|
+ confusion_data = np.array([[cm['true_positives'], cm['false_negatives']],
|
|
|
+ [cm['false_positives'], cm['true_negatives']]])
|
|
|
+
|
|
|
+ sns.heatmap(confusion_data, annot=True, fmt='d', cmap='Blues',
|
|
|
+ xticklabels=['预测正例', '预测负例'],
|
|
|
+ yticklabels=['实际正例', '实际负例'], ax=ax1)
|
|
|
+ ax1.set_title('混淆矩阵')
|
|
|
+
|
|
|
+ # 2. 性能指标柱状图
|
|
|
+ pm = metrics['performance_metrics']
|
|
|
+ metrics_names = ['精确率', '召回率', 'F1分数', '准确率']
|
|
|
+ metrics_values = [pm['precision'], pm['recall'], pm['f1_score'], pm['accuracy']]
|
|
|
+
|
|
|
+ bars = ax2.bar(metrics_names, metrics_values, color=['skyblue', 'lightgreen', 'orange', 'pink'])
|
|
|
+ ax2.set_title('性能指标')
|
|
|
+ ax2.set_ylim(0, 1)
|
|
|
+ ax2.set_ylabel('分数')
|
|
|
+
|
|
|
+ # 在柱状图上添加数值标签
|
|
|
+ for bar, value in zip(bars, metrics_values):
|
|
|
+ ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
|
|
|
+ f'{value:.3f}', ha='center', va='bottom')
|
|
|
+
|
|
|
+ # 3. 推理时间分布直方图
|
|
|
+ inference_times = self.evaluation_results['inference_times']
|
|
|
+ if inference_times:
|
|
|
+ ax3.hist(inference_times, bins=30, color='lightcoral', alpha=0.7)
|
|
|
+ ax3.set_title('推理时间分布')
|
|
|
+ ax3.set_xlabel('推理时间 (秒)')
|
|
|
+ ax3.set_ylabel('频次')
|
|
|
+ ax3.axvline(np.mean(inference_times), color='red', linestyle='--',
|
|
|
+ label=f'平均值: {np.mean(inference_times):.3f}s')
|
|
|
+ ax3.legend()
|
|
|
+
|
|
|
+ # 4. 按类别性能对比
|
|
|
+ if metrics['class_metrics']:
|
|
|
+ class_names = list(metrics['class_metrics'].keys())
|
|
|
+ precisions = [metrics['class_metrics'][name]['precision'] for name in class_names]
|
|
|
+ recalls = [metrics['class_metrics'][name]['recall'] for name in class_names]
|
|
|
+ f1_scores = [metrics['class_metrics'][name]['f1_score'] for name in class_names]
|
|
|
+
|
|
|
+ x = np.arange(len(class_names))
|
|
|
+ width = 0.25
|
|
|
+
|
|
|
+ ax4.bar(x - width, precisions, width, label='精确率', color='skyblue')
|
|
|
+ ax4.bar(x, recalls, width, label='召回率', color='lightgreen')
|
|
|
+ ax4.bar(x + width, f1_scores, width, label='F1分数', color='orange')
|
|
|
+
|
|
|
+ ax4.set_title('按类别性能对比')
|
|
|
+ ax4.set_xlabel('类别')
|
|
|
+ ax4.set_ylabel('分数')
|
|
|
+ ax4.set_xticks(x)
|
|
|
+ ax4.set_xticklabels(class_names)
|
|
|
+ ax4.legend()
|
|
|
+ ax4.set_ylim(0, 1)
|
|
|
+
|
|
|
+ plt.tight_layout()
|
|
|
+ plot_file = os.path.join(output_dir, f'evaluation_plots_{timestamp}.png')
|
|
|
+ plt.savefig(plot_file, dpi=300, bbox_inches='tight')
|
|
|
+ plt.close()
|
|
|
+
|
|
|
+ print(f"评估图表已保存到: {plot_file}")
|
|
|
+ return plot_file
|
|
|
+
|
|
|
+def main():
|
|
|
+ """
|
|
|
+ 主函数 - 演示如何使用评估器
|
|
|
+ """
|
|
|
+ print("模型性能评估工具")
|
|
|
+ print("=" * 50)
|
|
|
+
|
|
|
+ # 配置参数
|
|
|
+ data_dir = r"d:\PythonProject\R360-UAVmodelTool\Data"
|
|
|
+ model_type = 'uav_and_bird'
|
|
|
+ threshold = 0.5
|
|
|
+ sample_size = None # None表示使用全部数据,可以设置为具体数字进行采样
|
|
|
+
|
|
|
+ # 初始化评估器
|
|
|
+ evaluator = ModelEvaluator(
|
|
|
+ model_type=model_type,
|
|
|
+ threshold=threshold
|
|
|
+ )
|
|
|
+
|
|
|
+ # 加载ground truth
|
|
|
+ print("加载ground truth数据...")
|
|
|
+ ground_truth = evaluator.load_ground_truth_from_folder_structure(data_dir)
|
|
|
+
|
|
|
+ if not ground_truth:
|
|
|
+ print(f"警告: 在 {data_dir} 中未找到测试数据")
|
|
|
+ print("请确保数据目录结构如下:")
|
|
|
+ print("Data/")
|
|
|
+ print("├── BIRDS/ # 鸟类图像")
|
|
|
+ print("├── UAV/ # 无人机图像")
|
|
|
+ print("└── OTHERS/ # 其他图像(负样本)")
|
|
|
+ return
|
|
|
+
|
|
|
+ print(f"找到 {len(ground_truth)} 张测试图像")
|
|
|
+
|
|
|
+ # 显示数据分布
|
|
|
+ label_counts = {}
|
|
|
+ for item in ground_truth:
|
|
|
+ label = item['label']
|
|
|
+ label_counts[label] = label_counts.get(label, 0) + 1
|
|
|
+
|
|
|
+ print("数据分布:")
|
|
|
+ for label, count in label_counts.items():
|
|
|
+ print(f" {label}: {count} 张")
|
|
|
+
|
|
|
+ # 开始评估
|
|
|
+ print("\n开始模型评估...")
|
|
|
+ metrics = evaluator.evaluate_dataset(ground_truth, sample_size)
|
|
|
+
|
|
|
+ # 生成报告
|
|
|
+ print("\n生成评估报告...")
|
|
|
+ report_file = evaluator.generate_report(metrics)
|
|
|
+
|
|
|
+ # 绘制图表
|
|
|
+ print("生成评估图表...")
|
|
|
+ plot_file = evaluator.plot_results(metrics)
|
|
|
+
|
|
|
+ # 打印关键指标
|
|
|
+ print("\n=== 评估结果摘要 ===")
|
|
|
+ pm = metrics['performance_metrics']
|
|
|
+ sm = metrics['speed_metrics']
|
|
|
+
|
|
|
+ print(f"精确率: {pm['precision']:.4f}")
|
|
|
+ print(f"召回率: {pm['recall']:.4f}")
|
|
|
+ print(f"F1分数: {pm['f1_score']:.4f}")
|
|
|
+ print(f"准确率: {pm['accuracy']:.4f}")
|
|
|
+ print(f"误报率: {pm['false_positive_rate']:.4f}")
|
|
|
+ print(f"漏报率: {pm['false_negative_rate']:.4f}")
|
|
|
+ print(f"平均推理时间: {sm['avg_inference_time']:.4f} 秒")
|
|
|
+ print(f"推理速度: {sm['fps']:.2f} FPS")
|
|
|
+
|
|
|
+ print(f"\n详细报告: {report_file}")
|
|
|
+ print(f"评估图表: {plot_file}")
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ main()
|