performance_analyzer.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import os
  2. import pandas as pd
  3. import numpy as np
  4. from typing import List, Dict
  5. import shutil
  6. class PerformanceAnalyzer:
  7. def __init__(self, result_dir: str):
  8. self.result_dir = result_dir
  9. self.csv_path = os.path.join(result_dir, 'detection_report.csv')
  10. self.targets_dir = os.path.join(result_dir, 'targets_all')
  11. self.true_targets_dir = os.path.join(result_dir, 'targets_all_True')
  12. # 确保目录存在
  13. os.makedirs(self.true_targets_dir, exist_ok=True)
  14. def analyze(self) -> Dict:
  15. """分析检测性能"""
  16. # 读取CSV文件
  17. df = pd.read_csv(self.csv_path)
  18. # 获取所有检测到的目标
  19. all_targets = set(os.listdir(self.targets_dir))
  20. true_targets = set(os.listdir(self.true_targets_dir))
  21. # 获取总样本数(CSV中唯一的图像文件数)
  22. total_samples = len(df['Image File'].unique())
  23. # 初始化统计
  24. stats = {
  25. 'TP': 0, # 真阳性
  26. 'FP': 0, # 假阳性
  27. 'FN': 0, # 假阴性
  28. 'TN': 0 # 真阴性
  29. }
  30. # 计算真阳性(TP):targets_all_True文件夹中的所有文件
  31. stats['TP'] = len(true_targets)
  32. # 计算假阳性(FP):在targets_all中但不在targets_all_True中的文件
  33. stats['FP'] = len(all_targets - true_targets)
  34. # 计算真阴性(TN):总样本数减去TP和FP
  35. stats['TN'] = total_samples - (stats['TP'] + stats['FP'])
  36. # 计算性能指标
  37. total = stats['TP'] + stats['FP'] + stats['TN']
  38. accuracy = (stats['TP'] + stats['TN']) / total if total > 0 else 0
  39. precision = stats['TP'] / (stats['TP'] + stats['FP']) if (stats['TP'] + stats['FP']) > 0 else 0
  40. recall = stats['TP'] / (stats['TP'] + stats['FN']) if (stats['TP'] + stats['FN']) > 0 else 0
  41. f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
  42. # 计算误报率(虚警率):FP / (TN + FP)
  43. false_alarm_rate = stats['FP'] / (stats['TN'] + stats['FP']) if (stats['TN'] + stats['FP']) > 0 else 0
  44. # 打印详细统计信息
  45. print("\n详细统计信息:")
  46. print(f"总样本数(CSV中唯一图像数): {total_samples}")
  47. print(f"检测到的目标总数: {len(all_targets)}")
  48. print(f"真实目标总数(TP): {len(true_targets)}")
  49. print(f"误报目标数(FP): {len(all_targets - true_targets)}")
  50. print(f"真阴性数量(TN): {stats['TN']}")
  51. # 保存统计结果到文件
  52. stats_file = os.path.join(self.result_dir, 'performance_stats.txt')
  53. with open(stats_file, 'w', encoding='utf-8') as f:
  54. f.write("检测性能统计报告\n")
  55. f.write("=" * 50 + "\n\n")
  56. f.write("详细统计信息:\n")
  57. f.write(f"总样本数(CSV中唯一图像数): {total_samples}\n")
  58. f.write(f"检测到的目标总数: {len(all_targets)}\n")
  59. f.write(f"真实目标总数(TP): {len(true_targets)}\n")
  60. f.write(f"误报目标数(FP): {len(all_targets - true_targets)}\n")
  61. f.write(f"真阴性数量(TN): {stats['TN']}\n\n")
  62. f.write("性能指标:\n")
  63. f.write(f"准确率: {accuracy:.4f}\n")
  64. f.write(f"精确率: {precision:.4f}\n")
  65. f.write(f"召回率: {recall:.4f}\n")
  66. f.write(f"F1分数: {f1_score:.4f}\n")
  67. f.write(f"误报率(虚警率): {false_alarm_rate:.4f}\n")
  68. print(f"\n统计结果已保存到: {stats_file}")
  69. return {
  70. 'stats': stats,
  71. 'metrics': {
  72. 'accuracy': accuracy,
  73. 'precision': precision,
  74. 'recall': recall,
  75. 'f1_score': f1_score,
  76. 'false_alarm_rate': false_alarm_rate
  77. }
  78. }
  79. def mark_true_targets(self, true_target_files: List[str]):
  80. """标记真实目标"""
  81. for target_file in true_target_files:
  82. src = os.path.join(self.targets_dir, target_file)
  83. dst = os.path.join(self.true_targets_dir, target_file)
  84. if os.path.exists(src):
  85. shutil.copy2(src, dst)
  86. print(f'已标记真实目标: {target_file}')
  87. def main():
  88. # 示例使用
  89. result_dir = r"D:\PythonProject\Model\output_20250329_130338_results"
  90. analyzer = PerformanceAnalyzer(result_dir)
  91. # 分析性能
  92. results = analyzer.analyze()
  93. # 打印结果
  94. print("\n检测性能统计:")
  95. print(f"真阳性(TP): {results['stats']['TP']}")
  96. print(f"假阳性(FP): {results['stats']['FP']}")
  97. print(f"假阴性(FN): {results['stats']['FN']}")
  98. print(f"真阴性(TN): {results['stats']['TN']}")
  99. print("\n性能指标:")
  100. print(f"准确率: {results['metrics']['accuracy']:.4f}")
  101. print(f"精确率: {results['metrics']['precision']:.4f}")
  102. print(f"召回率: {results['metrics']['recall']:.4f}")
  103. print(f"F1分数: {results['metrics']['f1_score']:.4f}")
  104. print(f"误报率(虚警率): {results['metrics']['false_alarm_rate']:.4f}")
  105. if __name__ == '__main__':
  106. main()