# run_evaluation.py
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 模型评估运行脚本
  5. 简化的入口点,用于快速开始模型评估
  6. 使用方法:
  7. 1. 如果没有测试数据,先运行: python generate_test_data.py
  8. 2. 然后运行: python run_evaluation.py
  9. 作者: AI Assistant
  10. 创建时间: 2024
  11. """
  12. import os
  13. import sys
  14. import json
  15. import argparse
  16. import random
  17. import numpy as np
  18. from datetime import datetime
  19. # 添加项目根目录到Python路径
  20. sys.path.append(os.path.dirname(os.path.abspath(__file__)))
  21. from model_evaluator import ModelEvaluator
  22. from generate_test_data import TestDataGenerator
  23. def check_data_availability(data_dir: str) -> bool:
  24. """
  25. 检查测试数据是否可用
  26. Args:
  27. data_dir: 数据目录路径
  28. Returns:
  29. 数据是否可用
  30. """
  31. if not os.path.exists(data_dir):
  32. return False
  33. # 检查各个子目录
  34. subdirs = ['bird', 'drone', 'none']
  35. total_images = 0
  36. for subdir in subdirs:
  37. subdir_path = os.path.join(data_dir, subdir)
  38. if os.path.exists(subdir_path):
  39. # 计算图像数量
  40. image_files = set() # 使用集合避免重复计算
  41. for ext in ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']:
  42. import glob
  43. # 搜索小写扩展名
  44. image_files.update(glob.glob(os.path.join(subdir_path, ext)))
  45. # 搜索大写扩展名
  46. image_files.update(glob.glob(os.path.join(subdir_path, ext.upper())))
  47. total_images += len(image_files)
  48. return total_images > 0
  49. def load_config(config_file: str = 'evaluation_config.json') -> dict:
  50. """
  51. 加载配置文件
  52. Args:
  53. config_file: 配置文件路径
  54. Returns:
  55. 配置字典
  56. """
  57. if os.path.exists(config_file):
  58. with open(config_file, 'r', encoding='utf-8') as f:
  59. return json.load(f)
  60. else:
  61. # 返回默认配置
  62. return {
  63. "evaluation_settings": {
  64. "model_type": "uav_and_bird",
  65. "confidence_threshold": 0.5,
  66. "iou_threshold": 0.5,
  67. "sample_size": None
  68. },
  69. "data_settings": {
  70. "data_directory": "Data",
  71. "use_folder_structure": True
  72. },
  73. "output_settings": {
  74. "output_directory": "evaluation_results",
  75. "generate_plots": True
  76. }
  77. }
  78. def generate_test_data_if_needed(data_dir: str, force_generate: bool = False) -> bool:
  79. """
  80. 如果需要,生成测试数据
  81. Args:
  82. data_dir: 数据目录
  83. force_generate: 是否强制重新生成
  84. Returns:
  85. 是否成功生成或已存在数据
  86. """
  87. if not check_data_availability(data_dir) or force_generate:
  88. print("未找到测试数据,正在生成模拟数据...")
  89. generator = TestDataGenerator(data_dir)
  90. dataset_info = generator.generate_dataset(
  91. num_birds=20,
  92. num_uavs=20,
  93. num_others=40
  94. )
  95. generator.create_annotation_file(dataset_info)
  96. print("测试数据生成完成!")
  97. return True
  98. else:
  99. print("找到现有测试数据")
  100. return True
  101. def main():
  102. """
  103. 主函数
  104. """
  105. parser = argparse.ArgumentParser(description='模型性能评估工具')
  106. parser.add_argument('--config', '-c', default='evaluation_config.json',
  107. help='配置文件路径')
  108. parser.add_argument('--data-dir', '-d', default='Data',
  109. help='测试数据目录')
  110. parser.add_argument('--model-type', '-m', default='uav_and_bird',
  111. choices=['uav_and_bird', 'Anti_UAV', 'UAV-250411'],
  112. help='模型类型')
  113. parser.add_argument('--threshold', '-t', type=float, default=0.5,
  114. help='置信度阈值')
  115. parser.add_argument('--sample-size', '-s', type=int, default=None,
  116. help='采样大小(None表示使用全部数据)')
  117. parser.add_argument('--balanced-sampling', action='store_true',
  118. help='启用平衡采样(按比例抽取样本)')
  119. parser.add_argument('--bird-ratio', type=float, default=0.2,
  120. help='鸟类样本比例(默认0.2,即20%)')
  121. parser.add_argument('--drone-ratio', type=float, default=0.2,
  122. help='无人机样本比例(默认0.2,即20%)')
  123. parser.add_argument('--none-ratio', type=float, default=0.6,
  124. help='其他样本比例(默认0.6,即60%)')
  125. parser.add_argument('--generate-data', '-g', action='store_true',
  126. help='强制重新生成测试数据')
  127. parser.add_argument('--output-dir', '-o', default='evaluation_results',
  128. help='输出目录')
  129. args = parser.parse_args()
  130. # 加载配置
  131. config = load_config(args.config)
  132. # 如果命令行没有指定,则从配置文件读取平衡采样设置
  133. if not args.balanced_sampling and config['evaluation_settings'].get('balanced_sampling', {}).get('enabled', False):
  134. args.balanced_sampling = True
  135. balanced_config = config['evaluation_settings']['balanced_sampling']
  136. args.bird_ratio = balanced_config.get('bird_ratio', 0.2)
  137. args.drone_ratio = balanced_config.get('drone_ratio', 0.2)
  138. args.none_ratio = balanced_config.get('none_ratio', 0.6)
  139. print(f"从配置文件加载平衡采样设置")
  140. print("=" * 60)
  141. print("UAV和鸟类检测模型性能评估工具")
  142. print("=" * 60)
  143. print(f"启动时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
  144. print(f"模型类型: {args.model_type}")
  145. print(f"置信度阈值: {args.threshold}")
  146. print(f"数据目录: {args.data_dir}")
  147. print(f"输出目录: {args.output_dir}")
  148. print(f"配置文件: {args.config}")
  149. # 检查并生成测试数据
  150. data_dir = os.path.abspath(args.data_dir)
  151. if not generate_test_data_if_needed(data_dir, args.generate_data):
  152. print("错误: 无法生成或找到测试数据")
  153. return 1
  154. # 初始化评估器
  155. print("\n初始化模型评估器...")
  156. try:
  157. evaluator = ModelEvaluator(
  158. model_type=args.model_type,
  159. threshold=args.threshold,
  160. iou_threshold=config['evaluation_settings'].get('iou_threshold', 0.5)
  161. )
  162. print("模型加载成功")
  163. except Exception as e:
  164. print(f"错误: 模型加载失败 - {e}")
  165. return 1
  166. # 加载ground truth
  167. print("\n加载测试数据...")
  168. try:
  169. ground_truth = evaluator.load_ground_truth_from_folder_structure(data_dir)
  170. if not ground_truth:
  171. print("错误: 未找到有效的测试数据")
  172. return 1
  173. print(f"找到 {len(ground_truth)} 张测试图像")
  174. # 显示数据分布
  175. label_counts = {}
  176. for item in ground_truth:
  177. label = item['label']
  178. label_counts[label] = label_counts.get(label, 0) + 1
  179. print("数据分布:")
  180. for label, count in label_counts.items():
  181. print(f" {label.upper()}: {count} 张")
  182. except Exception as e:
  183. print(f"错误: 加载测试数据失败 - {e}")
  184. return 1
  185. # 验证比例参数
  186. if args.balanced_sampling:
  187. total_ratio = args.bird_ratio + args.drone_ratio + args.none_ratio
  188. if abs(total_ratio - 1.0) > 0.01: # 允许小的浮点误差
  189. print(f"警告: 比例总和为 {total_ratio:.3f},不等于1.0,将自动归一化")
  190. # 归一化比例
  191. args.bird_ratio /= total_ratio
  192. args.drone_ratio /= total_ratio
  193. args.none_ratio /= total_ratio
  194. print(f"归一化后比例: 鸟类={args.bird_ratio:.3f}, 无人机={args.drone_ratio:.3f}, 其他={args.none_ratio:.3f}")
  195. # 开始评估
  196. print("\n" + "=" * 40)
  197. print("开始模型性能评估")
  198. if args.balanced_sampling:
  199. print(f"采样策略: 平衡采样 (鸟类:{args.bird_ratio*100:.1f}%, 无人机:{args.drone_ratio*100:.1f}%, 其他:{args.none_ratio*100:.1f}%)")
  200. else:
  201. print("采样策略: 随机采样")
  202. print("=" * 40)
  203. try:
  204. metrics = evaluator.evaluate_dataset(
  205. ground_truth,
  206. sample_size=args.sample_size,
  207. balanced_sampling=args.balanced_sampling,
  208. bird_ratio=args.bird_ratio,
  209. drone_ratio=args.drone_ratio,
  210. none_ratio=args.none_ratio
  211. )
  212. print("\n评估完成!")
  213. except Exception as e:
  214. print(f"错误: 评估过程失败 - {e}")
  215. return 1
  216. # 生成报告和图表
  217. print("\n生成评估报告...")
  218. try:
  219. report_file = evaluator.generate_report(metrics, args.output_dir)
  220. if config['output_settings'].get('generate_plots', True):
  221. print("生成评估图表...")
  222. plot_file = evaluator.plot_results(metrics, args.output_dir)
  223. except Exception as e:
  224. print(f"警告: 生成报告或图表时出错 - {e}")
  225. # 显示关键结果
  226. print("\n" + "=" * 40)
  227. print("评估结果摘要")
  228. print("=" * 40)
  229. pm = metrics['performance_metrics']
  230. sm = metrics['speed_metrics']
  231. cm = metrics['confusion_matrix']
  232. print(f"\n混淆矩阵:")
  233. print(f" 真正例 (TP): {cm['true_positives']}")
  234. print(f" 假正例 (FP): {cm['false_positives']}")
  235. print(f" 假负例 (FN): {cm['false_negatives']}")
  236. print(f" 真负例 (TN): {cm['true_negatives']}")
  237. print(f"\n性能指标:")
  238. print(f" 精确率 (Precision): {pm['precision']:.4f}")
  239. print(f" 召回率 (Recall): {pm['recall']:.4f}")
  240. print(f" F1分数: {pm['f1_score']:.4f}")
  241. print(f" 准确率 (Accuracy): {pm['accuracy']:.4f}")
  242. print(f" 误报率 (FPR): {pm['false_positive_rate']:.4f}")
  243. print(f" 漏报率 (FNR): {pm['false_negative_rate']:.4f}")
  244. print(f"\n速度指标:")
  245. print(f" 平均推理时间: {sm['avg_inference_time']:.4f} 秒")
  246. print(f" 推理速度: {sm['fps']:.2f} FPS")
  247. print(f" 总处理图像: {sm['total_images']} 张")
  248. print(f" 总耗时: {sm['total_time']:.2f} 秒")
  249. # 按类别显示结果
  250. if metrics['class_metrics']:
  251. print(f"\n按类别性能:")
  252. for class_name, class_metrics in metrics['class_metrics'].items():
  253. print(f" {class_name.upper()}类别:")
  254. print(f" 精确率: {class_metrics['precision']:.4f}")
  255. print(f" 召回率: {class_metrics['recall']:.4f}")
  256. print(f" F1分数: {class_metrics['f1_score']:.4f}")
  257. print(f"\n详细报告已保存到: {report_file}")
  258. if 'plot_file' in locals():
  259. print(f"评估图表已保存到: {plot_file}")
  260. print("\n评估完成!")
  261. return 0
  262. if __name__ == "__main__":
  263. exit_code = main()
  264. sys.exit(exit_code)