|
@@ -0,0 +1,128 @@
|
|
|
+import os
|
|
|
+import pandas as pd
|
|
|
+import numpy as np
|
|
|
+from typing import List, Dict
|
|
|
+import shutil
|
|
|
+
|
|
|
+class PerformanceAnalyzer:
|
|
|
+ def __init__(self, result_dir: str):
|
|
|
+ self.result_dir = result_dir
|
|
|
+ self.csv_path = os.path.join(result_dir, 'detection_report.csv')
|
|
|
+ self.targets_dir = os.path.join(result_dir, 'targets_all')
|
|
|
+ self.true_targets_dir = os.path.join(result_dir, 'targets_all_True')
|
|
|
+
|
|
|
+ # 确保目录存在
|
|
|
+ os.makedirs(self.true_targets_dir, exist_ok=True)
|
|
|
+
|
|
|
+ def analyze(self) -> Dict:
|
|
|
+ """分析检测性能"""
|
|
|
+ # 读取CSV文件
|
|
|
+ df = pd.read_csv(self.csv_path)
|
|
|
+
|
|
|
+ # 获取所有检测到的目标
|
|
|
+ all_targets = set(os.listdir(self.targets_dir))
|
|
|
+ true_targets = set(os.listdir(self.true_targets_dir))
|
|
|
+
|
|
|
+ # 获取总样本数(CSV中唯一的图像文件数)
|
|
|
+ total_samples = len(df['Image File'].unique())
|
|
|
+
|
|
|
+ # 初始化统计
|
|
|
+ stats = {
|
|
|
+ 'TP': 0, # 真阳性
|
|
|
+ 'FP': 0, # 假阳性
|
|
|
+ 'FN': 0, # 假阴性
|
|
|
+ 'TN': 0 # 真阴性
|
|
|
+ }
|
|
|
+
|
|
|
+ # 计算真阳性(TP):targets_all_True文件夹中的所有文件
|
|
|
+ stats['TP'] = len(true_targets)
|
|
|
+
|
|
|
+ # 计算假阳性(FP):在targets_all中但不在targets_all_True中的文件
|
|
|
+ stats['FP'] = len(all_targets - true_targets)
|
|
|
+
|
|
|
+ # 计算真阴性(TN):总样本数减去TP和FP
|
|
|
+ stats['TN'] = total_samples - (stats['TP'] + stats['FP'])
|
|
|
+
|
|
|
+ # 计算性能指标
|
|
|
+ total = stats['TP'] + stats['FP'] + stats['TN']
|
|
|
+ accuracy = (stats['TP'] + stats['TN']) / total if total > 0 else 0
|
|
|
+ precision = stats['TP'] / (stats['TP'] + stats['FP']) if (stats['TP'] + stats['FP']) > 0 else 0
|
|
|
+ recall = stats['TP'] / (stats['TP'] + stats['FN']) if (stats['TP'] + stats['FN']) > 0 else 0
|
|
|
+ f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
|
|
+ # 计算误报率(虚警率):FP / (TN + FP)
|
|
|
+ false_alarm_rate = stats['FP'] / (stats['TN'] + stats['FP']) if (stats['TN'] + stats['FP']) > 0 else 0
|
|
|
+
|
|
|
+ # 打印详细统计信息
|
|
|
+ print("\n详细统计信息:")
|
|
|
+ print(f"总样本数(CSV中唯一图像数): {total_samples}")
|
|
|
+ print(f"检测到的目标总数: {len(all_targets)}")
|
|
|
+ print(f"真实目标总数(TP): {len(true_targets)}")
|
|
|
+ print(f"误报目标数(FP): {len(all_targets - true_targets)}")
|
|
|
+ print(f"真阴性数量(TN): {stats['TN']}")
|
|
|
+
|
|
|
+ # 保存统计结果到文件
|
|
|
+ stats_file = os.path.join(self.result_dir, 'performance_stats.txt')
|
|
|
+ with open(stats_file, 'w', encoding='utf-8') as f:
|
|
|
+ f.write("检测性能统计报告\n")
|
|
|
+ f.write("=" * 50 + "\n\n")
|
|
|
+
|
|
|
+ f.write("详细统计信息:\n")
|
|
|
+ f.write(f"总样本数(CSV中唯一图像数): {total_samples}\n")
|
|
|
+ f.write(f"检测到的目标总数: {len(all_targets)}\n")
|
|
|
+ f.write(f"真实目标总数(TP): {len(true_targets)}\n")
|
|
|
+ f.write(f"误报目标数(FP): {len(all_targets - true_targets)}\n")
|
|
|
+ f.write(f"真阴性数量(TN): {stats['TN']}\n\n")
|
|
|
+
|
|
|
+ f.write("性能指标:\n")
|
|
|
+ f.write(f"准确率: {accuracy:.4f}\n")
|
|
|
+ f.write(f"精确率: {precision:.4f}\n")
|
|
|
+ f.write(f"召回率: {recall:.4f}\n")
|
|
|
+ f.write(f"F1分数: {f1_score:.4f}\n")
|
|
|
+ f.write(f"误报率(虚警率): {false_alarm_rate:.4f}\n")
|
|
|
+
|
|
|
+ print(f"\n统计结果已保存到: {stats_file}")
|
|
|
+
|
|
|
+ return {
|
|
|
+ 'stats': stats,
|
|
|
+ 'metrics': {
|
|
|
+ 'accuracy': accuracy,
|
|
|
+ 'precision': precision,
|
|
|
+ 'recall': recall,
|
|
|
+ 'f1_score': f1_score,
|
|
|
+ 'false_alarm_rate': false_alarm_rate
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ def mark_true_targets(self, true_target_files: List[str]):
|
|
|
+ """标记真实目标"""
|
|
|
+ for target_file in true_target_files:
|
|
|
+ src = os.path.join(self.targets_dir, target_file)
|
|
|
+ dst = os.path.join(self.true_targets_dir, target_file)
|
|
|
+ if os.path.exists(src):
|
|
|
+ shutil.copy2(src, dst)
|
|
|
+ print(f'已标记真实目标: {target_file}')
|
|
|
+
|
|
|
+def main():
|
|
|
+ # 示例使用
|
|
|
+ result_dir = r"D:\PythonProject\Model\output_20250329_130338_results"
|
|
|
+ analyzer = PerformanceAnalyzer(result_dir)
|
|
|
+
|
|
|
+ # 分析性能
|
|
|
+ results = analyzer.analyze()
|
|
|
+
|
|
|
+ # 打印结果
|
|
|
+ print("\n检测性能统计:")
|
|
|
+ print(f"真阳性(TP): {results['stats']['TP']}")
|
|
|
+ print(f"假阳性(FP): {results['stats']['FP']}")
|
|
|
+ print(f"假阴性(FN): {results['stats']['FN']}")
|
|
|
+ print(f"真阴性(TN): {results['stats']['TN']}")
|
|
|
+
|
|
|
+ print("\n性能指标:")
|
|
|
+ print(f"准确率: {results['metrics']['accuracy']:.4f}")
|
|
|
+ print(f"精确率: {results['metrics']['precision']:.4f}")
|
|
|
+ print(f"召回率: {results['metrics']['recall']:.4f}")
|
|
|
+ print(f"F1分数: {results['metrics']['f1_score']:.4f}")
|
|
|
+ print(f"误报率(虚警率): {results['metrics']['false_alarm_rate']:.4f}")
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ main()
|