#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Model performance evaluation tool.

Evaluates UAV and bird detection models on key performance metrics,
including false positive rate, false negative rate, classification
accuracy, and inference speed.

Author: AI Assistant
Created: 2024
"""
import os
import cv2
import numpy as np
import json
import time
import random
from datetime import datetime
from typing import List, Dict, Tuple, Optional
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import pandas as pd

# Import the existing inference module
from src.core.inference import ONNXDetector


class ModelEvaluator:
    """
    Model performance evaluator.
    """

    def __init__(self, model_path: str = None, model_type: str = 'uav_and_bird',
                 threshold: float = 0.5, iou_threshold: float = 0.5):
        """
        Initialize the evaluator.

        Args:
            model_path: Path to the model file
            model_type: Model type
            threshold: Confidence threshold
            iou_threshold: IoU threshold
        """
        self.model_type = model_type
        self.threshold = threshold
        self.iou_threshold = iou_threshold
        # Initialize the detector
        self.detector = ONNXDetector(
            threshold=threshold,
            model_type=model_type,
            model_path=model_path,
            save_empty=False
        )
        # Storage for evaluation results
        self.evaluation_results = {
            'true_positives': 0,
            'false_positives': 0,
            'false_negatives': 0,
            'true_negatives': 0,
            'inference_times': [],
            'detailed_results': [],
            'class_metrics': defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})
        }
        # Class name and label mappings
        if model_type == 'uav_and_bird':
            self.class_names = {0: 'Bird', 1: 'Drone'}
            self.class_mapping = {'bird': 0, 'drone': 1, 'none': -1}
        else:
            self.class_names = {0: 'Drone'}
            self.class_mapping = {'drone': 0, 'none': -1}

    def load_ground_truth_from_folder_structure(self, data_dir: str) -> List[Dict]:
        """
        Load ground truth from a folder structure.

        Assumes the following layout:
        data_dir/
        ├── bird/
        ├── drone/
        └── none/

        Args:
            data_dir: Path to the data directory

        Returns:
            A list of dicts containing image paths and labels
        """
        ground_truth = []
        # Supported image formats
        image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff']
        # Walk through each class folder
        for class_folder in ['bird', 'drone', 'none']:
            folder_path = os.path.join(data_dir, class_folder)
            if not os.path.exists(folder_path):
                continue
            # Collect all images in this folder
            images = set()  # a set avoids counting duplicates
            for ext in image_extensions:
                # Match lowercase extensions
                images.update(glob(os.path.join(folder_path, ext)))
                # Match uppercase extensions
                images.update(glob(os.path.join(folder_path, ext.upper())))
            # The folder name doubles as the class label
            label = class_folder
            # Append to the ground truth
            for img_path in images:
                ground_truth.append({
                    'image_path': img_path,
                    'label': label,
                    'class_id': self.class_mapping.get(label, -1),
                    'has_target': label != 'none'
                })
        return ground_truth
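
    # A sketch of what one returned entry looks like (hypothetical paths,
    # assuming the bird/drone/none layout above):
    #   {'image_path': 'Data/bird/IMG_0001.jpg', 'label': 'bird',
    #    'class_id': 0, 'has_target': True}
    #   {'image_path': 'Data/none/IMG_0042.jpg', 'label': 'none',
    #    'class_id': -1, 'has_target': False}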

    def load_ground_truth_from_annotation(self, annotation_file: str) -> List[Dict]:
        """
        Load ground truth from an annotation file.

        Supports a custom JSON format with a top-level 'images' list;
        COCO-style annotation files are not yet parsed.

        Args:
            annotation_file: Path to the annotation file

        Returns:
            A list of dicts containing image paths and labels
        """
        with open(annotation_file, 'r', encoding='utf-8') as f:
            annotations = json.load(f)
        ground_truth = []
        # Handle the custom format
        if 'images' in annotations:
            for img_info in annotations['images']:
                ground_truth.append({
                    'image_path': img_info['path'],
                    'label': img_info.get('label', 'none'),
                    'class_id': self.class_mapping.get(img_info.get('label', 'none'), -1),
                    'has_target': img_info.get('has_target', False),
                    'bboxes': img_info.get('bboxes', [])
                })
        return ground_truth
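
    # A hedged example of the custom annotation format the parser above expects
    # (field names taken from the code; paths and values are hypothetical):
    # {
    #   "images": [
    #     {"path": "Data/drone/IMG_0007.jpg", "label": "drone",
    #      "has_target": true, "bboxes": [[120.0, 80.0, 260.0, 190.0]]},
    #     {"path": "Data/none/IMG_0042.jpg", "label": "none", "has_target": false}
    #   ]
    # }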

    def calculate_iou(self, box1: List[float], box2: List[float]) -> float:
        """
        Compute the IoU of two bounding boxes.

        Args:
            box1: [x1, y1, x2, y2]
            box2: [x1, y1, x2, y2]

        Returns:
            The IoU value
        """
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        if x2 <= x1 or y2 <= y1:
            return 0.0
        intersection = (x2 - x1) * (y2 - y1)
        area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
        area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union = area1 + area2 - intersection
        return intersection / union if union > 0 else 0.0
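
    # Worked example (hypothetical boxes): for box1 = [0, 0, 10, 10] and
    # box2 = [5, 5, 15, 15], the intersection is the 5x5 square [5, 5, 10, 10]
    # (area 25) and each box has area 100, so
    #   IoU = 25 / (100 + 100 - 25) = 25 / 175 ≈ 0.143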

    def evaluate_single_image(self, image_info: Dict) -> Dict:
        """
        Evaluate a single image.

        Args:
            image_info: Dict containing the image path and ground truth

        Returns:
            A dict with the evaluation result
        """
        image_path = image_info['image_path']
        gt_label = image_info['label']
        gt_class_id = image_info['class_id']
        gt_has_target = image_info['has_target']
        gt_bboxes = image_info.get('bboxes', [])
        # Time the inference
        start_time = time.time()
        # Run inference
        try:
            # Use the detector's own pipeline
            input_data = self.detector.preprocess(image_path)
            if input_data is None:
                raise Exception("Image preprocessing failed")
            # Execute inference
            detections_raw = self.detector.inference(input_data)
            # Post-process to obtain the detection results
            image_orig = cv2.imread(image_path)
            detection_count, processed_img, detections_list = self.detector.postprocess(
                detections_raw, image_orig, image_path
            )
            inference_time = time.time() - start_time
        except Exception as e:
            print(f"Inference failed: {image_path}, error: {e}")
            return {
                'image_path': image_path,
                'gt_label': gt_label,
                'prediction': 'error',
                'inference_time': 0,
                'tp': 0, 'fp': 0, 'fn': 0, 'tn': 0
            }
        # Analyze the result
        has_detection = detection_count > 0
        # Confusion-matrix bookkeeping
        tp = fp = fn = tn = 0
        if gt_has_target and has_detection:
            tp = 1  # target present and detected
        elif gt_has_target and not has_detection:
            fn = 1  # target present but missed (false negative)
        elif not gt_has_target and has_detection:
            fp = 1  # no target but something detected (false positive)
        else:
            tn = 1  # no target and nothing detected
        # Detailed result; detections_list is always bound here because the
        # except branch above returns early
        result = {
            'image_path': image_path,
            'gt_label': gt_label,
            'gt_has_target': gt_has_target,
            'detection_count': detection_count,
            'has_detection': has_detection,
            'inference_time': inference_time,
            'tp': tp, 'fp': fp, 'fn': fn, 'tn': tn,
            'detections': detections_list
        }
        return result

    def balanced_sampling(self, ground_truth: List[Dict], total_samples: int,
                          bird_ratio: float = 0.2, drone_ratio: float = 0.2,
                          none_ratio: float = 0.6) -> List[Dict]:
        """
        Sample the dataset with fixed per-class ratios.

        Args:
            ground_truth: Original ground truth list
            total_samples: Total number of samples to draw
            bird_ratio: Fraction of bird samples
            drone_ratio: Fraction of drone samples
            none_ratio: Fraction of target-free samples

        Returns:
            The balanced, sampled ground truth list
        """
        # Group by class
        birds = [item for item in ground_truth if item['label'] == 'bird']
        drones = [item for item in ground_truth if item['label'] == 'drone']
        nones = [item for item in ground_truth if item['label'] == 'none']
        # Compute the per-class sample counts
        bird_count = int(total_samples * bird_ratio)
        drone_count = int(total_samples * drone_ratio)
        none_count = int(total_samples * none_ratio)
        # Never exceed the number of available samples
        bird_count = min(bird_count, len(birds))
        drone_count = min(drone_count, len(drones))
        none_count = min(none_count, len(nones))
        print("Balanced sampling strategy:")
        print(f"  Birds: {bird_count}/{len(birds)} ({bird_ratio*100:.1f}%)")
        print(f"  Drones: {drone_count}/{len(drones)} ({drone_ratio*100:.1f}%)")
        print(f"  No target: {none_count}/{len(nones)} ({none_ratio*100:.1f}%)")
        print(f"  Total: {bird_count + drone_count + none_count} images")
        # Randomly sample each class
        sampled_data = []
        if bird_count > 0:
            sampled_data.extend(random.sample(birds, bird_count))
        if drone_count > 0:
            sampled_data.extend(random.sample(drones, drone_count))
        if none_count > 0:
            sampled_data.extend(random.sample(nones, none_count))
        # Shuffle the combined sample
        random.shuffle(sampled_data)
        return sampled_data
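
    # Example (hypothetical numbers): with total_samples=1000 and the default
    # 0.2/0.2/0.6 ratios, the method tries to draw 200 bird, 200 drone, and
    # 600 target-free images; any class with fewer images than requested is
    # capped at what is available, so the result can be smaller than
    # total_samples.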

    def evaluate_dataset(self, ground_truth: List[Dict], sample_size: Optional[int] = None,
                         balanced_sampling: bool = False, bird_ratio: float = 0.2,
                         drone_ratio: float = 0.2, none_ratio: float = 0.6) -> Dict:
        """
        Evaluate the whole dataset.

        Args:
            ground_truth: Ground truth list
            sample_size: Sample size; None means use all data
            balanced_sampling: Whether to use balanced sampling
            bird_ratio: Fraction of bird samples (only used when balanced_sampling=True)
            drone_ratio: Fraction of drone samples (only used when balanced_sampling=True)
            none_ratio: Fraction of target-free samples (only used when balanced_sampling=True)

        Returns:
            A dict of evaluation metrics
        """
        # Sampling strategy
        if sample_size and sample_size < len(ground_truth):
            if balanced_sampling:
                # Balanced sampling
                ground_truth = self.balanced_sampling(
                    ground_truth, sample_size, bird_ratio, drone_ratio, none_ratio
                )
            else:
                # Plain random sampling
                ground_truth = random.sample(ground_truth, sample_size)
                print(f"Randomly sampled {sample_size} images for evaluation")
        elif balanced_sampling:
            print("Warning: balanced sampling was enabled without a sample_size; using all data")
        print(f"Starting evaluation of {len(ground_truth)} images...")
        # Reset the results
        self.evaluation_results = {
            'true_positives': 0,
            'false_positives': 0,
            'false_negatives': 0,
            'true_negatives': 0,
            'inference_times': [],
            'detailed_results': [],
            'class_metrics': defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})
        }
        # Evaluate image by image
        for img_info in tqdm(ground_truth, desc="Evaluating"):
            result = self.evaluate_single_image(img_info)
            # Accumulate totals
            self.evaluation_results['true_positives'] += result['tp']
            self.evaluation_results['false_positives'] += result['fp']
            self.evaluation_results['false_negatives'] += result['fn']
            self.evaluation_results['true_negatives'] += result['tn']
            self.evaluation_results['inference_times'].append(result['inference_time'])
            self.evaluation_results['detailed_results'].append(result)
            # Per-class bookkeeping
            gt_label = result['gt_label']
            # Track metrics for every class, including 'none'
            self.evaluation_results['class_metrics'][gt_label]['tp'] += result['tp']
            self.evaluation_results['class_metrics'][gt_label]['fp'] += result['fp']
            self.evaluation_results['class_metrics'][gt_label]['fn'] += result['fn']
        return self.calculate_metrics()
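
    # A minimal usage sketch (assumed call pattern; 'Data' is a hypothetical path):
    #   evaluator = ModelEvaluator(model_type='uav_and_bird', threshold=0.5)
    #   gt = evaluator.load_ground_truth_from_folder_structure('Data')
    #   metrics = evaluator.evaluate_dataset(gt, sample_size=500,
    #                                        balanced_sampling=True)
    #   print(metrics['performance_metrics']['f1_score'])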

    def calculate_metrics(self) -> Dict:
        """
        Compute the performance metrics.

        Returns:
            A dict of performance metrics
        """
        tp = self.evaluation_results['true_positives']
        fp = self.evaluation_results['false_positives']
        fn = self.evaluation_results['false_negatives']
        tn = self.evaluation_results['true_negatives']
        # Basic metrics
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        accuracy = (tp + tn) / (tp + fp + fn + tn) if (tp + fp + fn + tn) > 0 else 0
        # False positive and false negative rates
        false_positive_rate = fp / (fp + tn) if (fp + tn) > 0 else 0
        false_negative_rate = fn / (fn + tp) if (fn + tp) > 0 else 0
        # Inference speed statistics
        inference_times = self.evaluation_results['inference_times']
        avg_inference_time = np.mean(inference_times) if inference_times else 0
        fps = 1 / avg_inference_time if avg_inference_time > 0 else 0
        metrics = {
            'confusion_matrix': {
                'true_positives': tp,
                'false_positives': fp,
                'false_negatives': fn,
                'true_negatives': tn
            },
            'performance_metrics': {
                'precision': precision,
                'recall': recall,
                'f1_score': f1_score,
                'accuracy': accuracy,
                'false_positive_rate': false_positive_rate,
                'false_negative_rate': false_negative_rate
            },
            'speed_metrics': {
                'avg_inference_time': avg_inference_time,
                'fps': fps,
                'total_images': len(inference_times),
                'total_time': sum(inference_times)
            },
            'class_metrics': {}
        }
        # Per-class metrics
        for class_name, class_stats in self.evaluation_results['class_metrics'].items():
            class_tp = class_stats['tp']
            class_fp = class_stats['fp']
            class_fn = class_stats['fn']
            class_precision = class_tp / (class_tp + class_fp) if (class_tp + class_fp) > 0 else 0
            class_recall = class_tp / (class_tp + class_fn) if (class_tp + class_fn) > 0 else 0
            class_f1 = 2 * (class_precision * class_recall) / (class_precision + class_recall) if (class_precision + class_recall) > 0 else 0
            metrics['class_metrics'][class_name] = {
                'precision': class_precision,
                'recall': class_recall,
                'f1_score': class_f1,
                'tp': class_tp,
                'fp': class_fp,
                'fn': class_fn
            }
        return metrics

    def generate_report(self, metrics: Dict, output_dir: str = None) -> str:
        """
        Generate the evaluation report.

        Args:
            metrics: Dict of performance metrics
            output_dir: Output directory

        Returns:
            Path to the report file
        """
        if output_dir is None:
            output_dir = os.path.join(os.path.dirname(__file__), 'evaluation_results')
        os.makedirs(output_dir, exist_ok=True)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        report_file = os.path.join(output_dir, f'evaluation_report_{timestamp}.txt')
        with open(report_file, 'w', encoding='utf-8') as f:
            f.write("=" * 60 + "\n")
            f.write("Model Performance Evaluation Report\n")
            f.write("=" * 60 + "\n")
            f.write(f"Evaluated at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"Model type: {self.model_type}\n")
            f.write(f"Confidence threshold: {self.threshold}\n")
            f.write(f"IoU threshold: {self.iou_threshold}\n\n")
            # Confusion matrix
            cm = metrics['confusion_matrix']
            f.write("Confusion matrix:\n")
            f.write("-" * 30 + "\n")
            f.write(f"True positives (TP): {cm['true_positives']}\n")
            f.write(f"False positives (FP): {cm['false_positives']}\n")
            f.write(f"False negatives (FN): {cm['false_negatives']}\n")
            f.write(f"True negatives (TN): {cm['true_negatives']}\n\n")
            # Performance metrics
            pm = metrics['performance_metrics']
            f.write("Performance metrics:\n")
            f.write("-" * 30 + "\n")
            f.write(f"Precision: {pm['precision']:.4f}\n")
            f.write(f"Recall: {pm['recall']:.4f}\n")
            f.write(f"F1 score: {pm['f1_score']:.4f}\n")
            f.write(f"Accuracy: {pm['accuracy']:.4f}\n")
            f.write(f"False positive rate (FPR): {pm['false_positive_rate']:.4f}\n")
            f.write(f"False negative rate (FNR): {pm['false_negative_rate']:.4f}\n\n")
            # Speed metrics
            sm = metrics['speed_metrics']
            f.write("Speed metrics:\n")
            f.write("-" * 30 + "\n")
            f.write(f"Average inference time: {sm['avg_inference_time']:.4f} s\n")
            f.write(f"Inference speed (FPS): {sm['fps']:.2f}\n")
            f.write(f"Total images: {sm['total_images']}\n")
            f.write(f"Total time: {sm['total_time']:.2f} s\n\n")
            # Per-class metrics
            if metrics['class_metrics']:
                f.write("Per-class performance metrics:\n")
                f.write("-" * 30 + "\n")
                for class_name, class_metrics in metrics['class_metrics'].items():
                    f.write(f"\n{class_name.upper()} class:\n")
                    f.write(f"  Precision: {class_metrics['precision']:.4f}\n")
                    f.write(f"  Recall: {class_metrics['recall']:.4f}\n")
                    f.write(f"  F1 score: {class_metrics['f1_score']:.4f}\n")
                    f.write(f"  TP: {class_metrics['tp']}, FP: {class_metrics['fp']}, FN: {class_metrics['fn']}\n")
        print(f"Evaluation report saved to: {report_file}")
        return report_file

    def plot_results(self, metrics: Dict, output_dir: str = None):
        """
        Plot the evaluation results.

        Args:
            metrics: Dict of performance metrics
            output_dir: Output directory
        """
        if output_dir is None:
            output_dir = os.path.join(os.path.dirname(__file__), 'evaluation_results')
        os.makedirs(output_dir, exist_ok=True)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        # Configure fonts (SimHei covers CJK glyphs, DejaVu Sans is the fallback)
        plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
        plt.rcParams['axes.unicode_minus'] = False
        # 1. Confusion matrix heatmap
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
        cm = metrics['confusion_matrix']
        confusion_data = np.array([[cm['true_positives'], cm['false_negatives']],
                                   [cm['false_positives'], cm['true_negatives']]])
        sns.heatmap(confusion_data, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Predicted positive', 'Predicted negative'],
                    yticklabels=['Actual positive', 'Actual negative'], ax=ax1)
        ax1.set_title('Confusion Matrix')
        # 2. Performance metrics bar chart
        pm = metrics['performance_metrics']
        metrics_names = ['Precision', 'Recall', 'F1 score', 'Accuracy']
        metrics_values = [pm['precision'], pm['recall'], pm['f1_score'], pm['accuracy']]
        bars = ax2.bar(metrics_names, metrics_values, color=['skyblue', 'lightgreen', 'orange', 'pink'])
        ax2.set_title('Performance Metrics')
        ax2.set_ylim(0, 1)
        ax2.set_ylabel('Score')
        # Annotate each bar with its value
        for bar, value in zip(bars, metrics_values):
            ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                     f'{value:.3f}', ha='center', va='bottom')
        # 3. Inference time histogram
        inference_times = self.evaluation_results['inference_times']
        if inference_times:
            ax3.hist(inference_times, bins=30, color='lightcoral', alpha=0.7)
            ax3.set_title('Inference Time Distribution')
            ax3.set_xlabel('Inference time (s)')
            ax3.set_ylabel('Frequency')
            ax3.axvline(np.mean(inference_times), color='red', linestyle='--',
                        label=f'Mean: {np.mean(inference_times):.3f}s')
            ax3.legend()
        # 4. Per-class performance comparison
        if metrics['class_metrics']:
            class_names = list(metrics['class_metrics'].keys())
            precisions = [metrics['class_metrics'][name]['precision'] for name in class_names]
            recalls = [metrics['class_metrics'][name]['recall'] for name in class_names]
            f1_scores = [metrics['class_metrics'][name]['f1_score'] for name in class_names]
            x = np.arange(len(class_names))
            width = 0.25
            ax4.bar(x - width, precisions, width, label='Precision', color='skyblue')
            ax4.bar(x, recalls, width, label='Recall', color='lightgreen')
            ax4.bar(x + width, f1_scores, width, label='F1 score', color='orange')
            ax4.set_title('Per-class Performance Comparison')
            ax4.set_xlabel('Class')
            ax4.set_ylabel('Score')
            ax4.set_xticks(x)
            ax4.set_xticklabels(class_names)
            ax4.legend()
            ax4.set_ylim(0, 1)
        plt.tight_layout()
        plot_file = os.path.join(output_dir, f'evaluation_plots_{timestamp}.png')
        plt.savefig(plot_file, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Evaluation plots saved to: {plot_file}")
        return plot_file


def main():
    """
    Entry point - demonstrates how to use the evaluator.
    """
    print("Model Performance Evaluation Tool")
    print("=" * 50)
    # Configuration
    data_dir = r"d:\PythonProject\R360-UAVmodelTool\Data"
    model_type = 'uav_and_bird'
    threshold = 0.5
    sample_size = None  # None means use all data; set a number to sample
    # Initialize the evaluator
    evaluator = ModelEvaluator(
        model_type=model_type,
        threshold=threshold
    )
    # Load the ground truth
    print("Loading ground truth data...")
    ground_truth = evaluator.load_ground_truth_from_folder_structure(data_dir)
    if not ground_truth:
        print(f"Warning: no test data found in {data_dir}")
        print("Make sure the data directory is laid out as follows:")
        print("Data/")
        print("├── bird/   # bird images")
        print("├── drone/  # drone images")
        print("└── none/   # other images (negative samples)")
        return
    print(f"Found {len(ground_truth)} test images")
    # Show the class distribution
    label_counts = {}
    for item in ground_truth:
        label = item['label']
        label_counts[label] = label_counts.get(label, 0) + 1
    print("Class distribution:")
    for label, count in label_counts.items():
        print(f"  {label}: {count} images")
    # Run the evaluation
    print("\nStarting model evaluation...")
    metrics = evaluator.evaluate_dataset(ground_truth, sample_size)
    # Generate the report
    print("\nGenerating evaluation report...")
    report_file = evaluator.generate_report(metrics)
    # Plot the results
    print("Generating evaluation plots...")
    plot_file = evaluator.plot_results(metrics)
    # Print the key metrics
    print("\n=== Evaluation Summary ===")
    pm = metrics['performance_metrics']
    sm = metrics['speed_metrics']
    print(f"Precision: {pm['precision']:.4f}")
    print(f"Recall: {pm['recall']:.4f}")
    print(f"F1 score: {pm['f1_score']:.4f}")
    print(f"Accuracy: {pm['accuracy']:.4f}")
    print(f"False positive rate: {pm['false_positive_rate']:.4f}")
    print(f"False negative rate: {pm['false_negative_rate']:.4f}")
    print(f"Average inference time: {sm['avg_inference_time']:.4f} s")
    print(f"Inference speed: {sm['fps']:.2f} FPS")
    print(f"\nDetailed report: {report_file}")
    print(f"Evaluation plots: {plot_file}")


if __name__ == "__main__":
    main()