#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 模型评估运行脚本 简化的入口点,用于快速开始模型评估 使用方法: 1. 如果没有测试数据,先运行: python generate_test_data.py 2. 然后运行: python run_evaluation.py 作者: AI Assistant 创建时间: 2024 """ import os import sys import json import argparse import random import numpy as np from datetime import datetime # 添加项目根目录到Python路径 sys.path.append(os.path.dirname(os.path.abspath(__file__))) from model_evaluator import ModelEvaluator from generate_test_data import TestDataGenerator def check_data_availability(data_dir: str) -> bool: """ 检查测试数据是否可用 Args: data_dir: 数据目录路径 Returns: 数据是否可用 """ if not os.path.exists(data_dir): return False # 检查各个子目录 subdirs = ['bird', 'drone', 'none'] total_images = 0 for subdir in subdirs: subdir_path = os.path.join(data_dir, subdir) if os.path.exists(subdir_path): # 计算图像数量 image_files = set() # 使用集合避免重复计算 for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']: import glob # 搜索小写扩展名 image_files.update(glob.glob(os.path.join(subdir_path, ext))) # 搜索大写扩展名 image_files.update(glob.glob(os.path.join(subdir_path, ext.upper()))) total_images += len(image_files) return total_images > 0 def load_config(config_file: str = 'evaluation_config.json') -> dict: """ 加载配置文件 Args: config_file: 配置文件路径 Returns: 配置字典 """ if os.path.exists(config_file): with open(config_file, 'r', encoding='utf-8') as f: return json.load(f) else: # 返回默认配置 return { "evaluation_settings": { "model_type": "uav_and_bird", "confidence_threshold": 0.5, "iou_threshold": 0.5, "sample_size": None }, "data_settings": { "data_directory": "Data", "use_folder_structure": True }, "output_settings": { "output_directory": "evaluation_results", "generate_plots": True } } def generate_test_data_if_needed(data_dir: str, force_generate: bool = False) -> bool: """ 如果需要,生成测试数据 Args: data_dir: 数据目录 force_generate: 是否强制重新生成 Returns: 是否成功生成或已存在数据 """ if not check_data_availability(data_dir) or force_generate: print("未找到测试数据,正在生成模拟数据...") generator = TestDataGenerator(data_dir) dataset_info = generator.generate_dataset( num_birds=20, num_uavs=20, num_others=40 ) generator.create_annotation_file(dataset_info) print("测试数据生成完成!") return True else: print("找到现有测试数据") return True def main(): """ 主函数 """ parser = argparse.ArgumentParser(description='模型性能评估工具') parser.add_argument('--config', '-c', default='evaluation_config.json', help='配置文件路径') parser.add_argument('--data-dir', '-d', default='Data', help='测试数据目录') parser.add_argument('--model-type', '-m', default='uav_and_bird', choices=['uav_and_bird', 'Anti_UAV', 'UAV-250411'], help='模型类型') parser.add_argument('--threshold', '-t', type=float, default=0.5, help='置信度阈值') parser.add_argument('--sample-size', '-s', type=int, default=None, help='采样大小(None表示使用全部数据)') parser.add_argument('--balanced-sampling', action='store_true', help='启用平衡采样(按比例抽取样本)') parser.add_argument('--bird-ratio', type=float, default=0.2, help='鸟类样本比例(默认0.2,即20%)') parser.add_argument('--drone-ratio', type=float, default=0.2, help='无人机样本比例(默认0.2,即20%)') parser.add_argument('--none-ratio', type=float, default=0.6, help='其他样本比例(默认0.6,即60%)') parser.add_argument('--generate-data', '-g', action='store_true', help='强制重新生成测试数据') parser.add_argument('--output-dir', '-o', default='evaluation_results', help='输出目录') args = parser.parse_args() # 加载配置 config = load_config(args.config) # 如果命令行没有指定,则从配置文件读取平衡采样设置 if not args.balanced_sampling and config['evaluation_settings'].get('balanced_sampling', {}).get('enabled', False): args.balanced_sampling = True balanced_config = config['evaluation_settings']['balanced_sampling'] args.bird_ratio = balanced_config.get('bird_ratio', 0.2) args.drone_ratio = balanced_config.get('drone_ratio', 0.2) args.none_ratio = balanced_config.get('none_ratio', 0.6) print(f"从配置文件加载平衡采样设置") print("=" * 60) print("UAV和鸟类检测模型性能评估工具") print("=" * 60) print(f"启动时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") print(f"模型类型: {args.model_type}") print(f"置信度阈值: {args.threshold}") print(f"数据目录: {args.data_dir}") print(f"输出目录: {args.output_dir}") print(f"配置文件: {args.config}") # 检查并生成测试数据 data_dir = os.path.abspath(args.data_dir) if not generate_test_data_if_needed(data_dir, args.generate_data): print("错误: 无法生成或找到测试数据") return 1 # 初始化评估器 print("\n初始化模型评估器...") try: evaluator = ModelEvaluator( model_type=args.model_type, threshold=args.threshold, iou_threshold=config['evaluation_settings'].get('iou_threshold', 0.5) ) print("模型加载成功") except Exception as e: print(f"错误: 模型加载失败 - {e}") return 1 # 加载ground truth print("\n加载测试数据...") try: ground_truth = evaluator.load_ground_truth_from_folder_structure(data_dir) if not ground_truth: print("错误: 未找到有效的测试数据") return 1 print(f"找到 {len(ground_truth)} 张测试图像") # 显示数据分布 label_counts = {} for item in ground_truth: label = item['label'] label_counts[label] = label_counts.get(label, 0) + 1 print("数据分布:") for label, count in label_counts.items(): print(f" {label.upper()}: {count} 张") except Exception as e: print(f"错误: 加载测试数据失败 - {e}") return 1 # 验证比例参数 if args.balanced_sampling: total_ratio = args.bird_ratio + args.drone_ratio + args.none_ratio if abs(total_ratio - 1.0) > 0.01: # 允许小的浮点误差 print(f"警告: 比例总和为 {total_ratio:.3f},不等于1.0,将自动归一化") # 归一化比例 args.bird_ratio /= total_ratio args.drone_ratio /= total_ratio args.none_ratio /= total_ratio print(f"归一化后比例: 鸟类={args.bird_ratio:.3f}, 无人机={args.drone_ratio:.3f}, 其他={args.none_ratio:.3f}") # 开始评估 print("\n" + "=" * 40) print("开始模型性能评估") if args.balanced_sampling: print(f"采样策略: 平衡采样 (鸟类:{args.bird_ratio*100:.1f}%, 无人机:{args.drone_ratio*100:.1f}%, 其他:{args.none_ratio*100:.1f}%)") else: print("采样策略: 随机采样") print("=" * 40) try: metrics = evaluator.evaluate_dataset( ground_truth, sample_size=args.sample_size, balanced_sampling=args.balanced_sampling, bird_ratio=args.bird_ratio, drone_ratio=args.drone_ratio, none_ratio=args.none_ratio ) print("\n评估完成!") except Exception as e: print(f"错误: 评估过程失败 - {e}") return 1 # 生成报告和图表 print("\n生成评估报告...") try: report_file = evaluator.generate_report(metrics, args.output_dir) if config['output_settings'].get('generate_plots', True): print("生成评估图表...") plot_file = evaluator.plot_results(metrics, args.output_dir) except Exception as e: print(f"警告: 生成报告或图表时出错 - {e}") # 显示关键结果 print("\n" + "=" * 40) print("评估结果摘要") print("=" * 40) pm = metrics['performance_metrics'] sm = metrics['speed_metrics'] cm = metrics['confusion_matrix'] print(f"\n混淆矩阵:") print(f" 真正例 (TP): {cm['true_positives']}") print(f" 假正例 (FP): {cm['false_positives']}") print(f" 假负例 (FN): {cm['false_negatives']}") print(f" 真负例 (TN): {cm['true_negatives']}") print(f"\n性能指标:") print(f" 精确率 (Precision): {pm['precision']:.4f}") print(f" 召回率 (Recall): {pm['recall']:.4f}") print(f" F1分数: {pm['f1_score']:.4f}") print(f" 准确率 (Accuracy): {pm['accuracy']:.4f}") print(f" 误报率 (FPR): {pm['false_positive_rate']:.4f}") print(f" 漏报率 (FNR): {pm['false_negative_rate']:.4f}") print(f"\n速度指标:") print(f" 平均推理时间: {sm['avg_inference_time']:.4f} 秒") print(f" 推理速度: {sm['fps']:.2f} FPS") print(f" 总处理图像: {sm['total_images']} 张") print(f" 总耗时: {sm['total_time']:.2f} 秒") # 按类别显示结果 if metrics['class_metrics']: print(f"\n按类别性能:") for class_name, class_metrics in metrics['class_metrics'].items(): print(f" {class_name.upper()}类别:") print(f" 精确率: {class_metrics['precision']:.4f}") print(f" 召回率: {class_metrics['recall']:.4f}") print(f" F1分数: {class_metrics['f1_score']:.4f}") print(f"\n详细报告已保存到: {report_file}") if 'plot_file' in locals(): print(f"评估图表已保存到: {plot_file}") print("\n评估完成!") return 0 if __name__ == "__main__": exit_code = main() sys.exit(exit_code)