123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 |
- import os
- import pandas as pd
- import numpy as np
- from typing import List, Dict
- import shutil
- class PerformanceAnalyzer:
- def __init__(self, result_dir: str):
- self.result_dir = result_dir
- self.csv_path = os.path.join(result_dir, 'detection_report.csv')
- self.targets_dir = os.path.join(result_dir, 'targets_all')
- self.true_targets_dir = os.path.join(result_dir, 'targets_all_True')
-
- # 确保目录存在
- os.makedirs(self.true_targets_dir, exist_ok=True)
-
- def analyze(self) -> Dict:
- """分析检测性能"""
- # 读取CSV文件
- df = pd.read_csv(self.csv_path)
-
- # 获取所有检测到的目标
- all_targets = set(os.listdir(self.targets_dir))
- true_targets = set(os.listdir(self.true_targets_dir))
-
- # 获取总样本数(CSV中唯一的图像文件数)
- total_samples = len(df['Image File'].unique())
-
- # 初始化统计
- stats = {
- 'TP': 0, # 真阳性
- 'FP': 0, # 假阳性
- 'FN': 0, # 假阴性
- 'TN': 0 # 真阴性
- }
-
- # 计算真阳性(TP):targets_all_True文件夹中的所有文件
- stats['TP'] = len(true_targets)
-
- # 计算假阳性(FP):在targets_all中但不在targets_all_True中的文件
- stats['FP'] = len(all_targets - true_targets)
-
- # 计算真阴性(TN):总样本数减去TP和FP
- stats['TN'] = total_samples - (stats['TP'] + stats['FP'])
-
- # 计算性能指标
- total = stats['TP'] + stats['FP'] + stats['TN']
- accuracy = (stats['TP'] + stats['TN']) / total if total > 0 else 0
- precision = stats['TP'] / (stats['TP'] + stats['FP']) if (stats['TP'] + stats['FP']) > 0 else 0
- recall = stats['TP'] / (stats['TP'] + stats['FN']) if (stats['TP'] + stats['FN']) > 0 else 0
- f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
- # 计算误报率(虚警率):FP / (TN + FP)
- false_alarm_rate = stats['FP'] / (stats['TN'] + stats['FP']) if (stats['TN'] + stats['FP']) > 0 else 0
-
- # 打印详细统计信息
- print("\n详细统计信息:")
- print(f"总样本数(CSV中唯一图像数): {total_samples}")
- print(f"检测到的目标总数: {len(all_targets)}")
- print(f"真实目标总数(TP): {len(true_targets)}")
- print(f"误报目标数(FP): {len(all_targets - true_targets)}")
- print(f"真阴性数量(TN): {stats['TN']}")
-
- # 保存统计结果到文件
- stats_file = os.path.join(self.result_dir, 'performance_stats.txt')
- with open(stats_file, 'w', encoding='utf-8') as f:
- f.write("检测性能统计报告\n")
- f.write("=" * 50 + "\n\n")
-
- f.write("详细统计信息:\n")
- f.write(f"总样本数(CSV中唯一图像数): {total_samples}\n")
- f.write(f"检测到的目标总数: {len(all_targets)}\n")
- f.write(f"真实目标总数(TP): {len(true_targets)}\n")
- f.write(f"误报目标数(FP): {len(all_targets - true_targets)}\n")
- f.write(f"真阴性数量(TN): {stats['TN']}\n\n")
-
- f.write("性能指标:\n")
- f.write(f"准确率: {accuracy:.4f}\n")
- f.write(f"精确率: {precision:.4f}\n")
- f.write(f"召回率: {recall:.4f}\n")
- f.write(f"F1分数: {f1_score:.4f}\n")
- f.write(f"误报率(虚警率): {false_alarm_rate:.4f}\n")
-
- print(f"\n统计结果已保存到: {stats_file}")
-
- return {
- 'stats': stats,
- 'metrics': {
- 'accuracy': accuracy,
- 'precision': precision,
- 'recall': recall,
- 'f1_score': f1_score,
- 'false_alarm_rate': false_alarm_rate
- }
- }
-
- def mark_true_targets(self, true_target_files: List[str]):
- """标记真实目标"""
- for target_file in true_target_files:
- src = os.path.join(self.targets_dir, target_file)
- dst = os.path.join(self.true_targets_dir, target_file)
- if os.path.exists(src):
- shutil.copy2(src, dst)
- print(f'已标记真实目标: {target_file}')
- def main():
- # 示例使用
- result_dir = r"D:\PythonProject\Model\output_20250329_130338_results"
- analyzer = PerformanceAnalyzer(result_dir)
-
- # 分析性能
- results = analyzer.analyze()
-
- # 打印结果
- print("\n检测性能统计:")
- print(f"真阳性(TP): {results['stats']['TP']}")
- print(f"假阳性(FP): {results['stats']['FP']}")
- print(f"假阴性(FN): {results['stats']['FN']}")
- print(f"真阴性(TN): {results['stats']['TN']}")
-
- print("\n性能指标:")
- print(f"准确率: {results['metrics']['accuracy']:.4f}")
- print(f"精确率: {results['metrics']['precision']:.4f}")
- print(f"召回率: {results['metrics']['recall']:.4f}")
- print(f"F1分数: {results['metrics']['f1_score']:.4f}")
- print(f"误报率(虚警率): {results['metrics']['false_alarm_rate']:.4f}")
- if __name__ == '__main__':
- main()
|