# report_generator.py
import csv
from datetime import datetime
from typing import Dict, List, Optional

import xlsxwriter
from openpyxl import load_workbook
  6. class ReportGenerator:
  7. def __init__(self, detector):
  8. self.metadata = {
  9. 'model_name': detector.model_path.split('/')[-1],
  10. 'model_version': detector.model_version,
  11. 'input_size': detector.input_size,
  12. 'test_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
  13. 'total_images': detector.image_count,
  14. 'confidence_threshold': detector.threshold
  15. }
  16. self.detections = detector.detection_records
  17. def generate_csv(self, output_path: str):
  18. with open(output_path, 'w', newline='', encoding='utf-8') as f:
  19. writer = csv.writer(f)
  20. writer.writerow(['Image File', 'Object Count', 'Max Confidence', 'BBox Center X', 'BBox Center Y', 'BBox Width', 'BBox Height', 'Normalized Coordinates'])
  21. for record in self.detections:
  22. if record['has_detection']:
  23. for detection in record['detections']:
  24. x1, y1, x2, y2 = detection['bbox']
  25. orig_w = detection.get('orig_w', 0)
  26. orig_h = detection.get('orig_h', 0)
  27. center_x = round((x1 + x2) / 2, 2)
  28. center_y = round((y1 + y2) / 2, 2)
  29. width = round(x2 - x1, 2)
  30. height = round(y2 - y1, 2)
  31. norm_cx = round(center_x / orig_w, 4) if orig_w > 0 else 0
  32. norm_cy = round(center_y / orig_h, 4) if orig_h > 0 else 0
  33. writer.writerow([
  34. record['image_path'],
  35. len(record['detections']),
  36. detection['confidence'],
  37. center_x,
  38. center_y,
  39. width,
  40. height,
  41. f"{norm_cx},{norm_cy},{width/orig_w},{height/orig_h}"
  42. ])
  43. else:
  44. writer.writerow([record['image_path'], 0, 0])
  45. def calculate_metrics(self, ground_truth: Dict[str, bool]) -> Dict[str, float]:
  46. """
  47. 计算模型评估指标
  48. :param ground_truth: 字典形式的地面真实值 {图像路径: 是否存在目标}
  49. :return: 包含各项指标的字典
  50. """
  51. tp = fp = tn = fn = 0
  52. for record in self.detections:
  53. actual = ground_truth.get(record['image_path'], False)
  54. predicted = record['has_detection']
  55. if actual and predicted:
  56. tp += 1
  57. elif actual and not predicted:
  58. fn += 1
  59. elif not actual and predicted:
  60. fp += 1
  61. else:
  62. tn += 1
  63. # 防止除零错误
  64. fpr = fp / (fp + tn) if (fp + tn) > 0 else 0
  65. fnr = fn / (fn + tp) if (fn + tp) > 0 else 0
  66. return {
  67. 'true_positive': tp,
  68. 'false_positive': fp,
  69. 'true_negative': tn,
  70. 'false_negative': fn,
  71. 'fpr': round(fpr, 4),
  72. 'fnr': round(fnr, 4)
  73. }
  74. def generate_excel(self, output_path: str, calculate_metrics: bool = False, ground_truth_file: str = None):
  75. workbook = xlsxwriter.Workbook(output_path)
  76. meta_ws = workbook.add_worksheet('Model Info')
  77. detail_ws = workbook.add_worksheet('Detection Results')
  78. # 写入元数据
  79. meta_headers = ['Model Name', 'Version', 'Input Size', 'Test Time', 'Image Count', 'Confidence Threshold']
  80. meta_ws.write_row(0, 0, meta_headers)
  81. meta_ws.write_row(1, 0, [
  82. self.metadata['model_name'],
  83. self.metadata['model_version'],
  84. self.metadata['input_size'],
  85. self.metadata['test_time'],
  86. self.metadata['total_images'],
  87. self.metadata['confidence_threshold']
  88. ])
  89. # 写入检测明细
  90. detail_headers = ['Image File', 'Object Count', 'Max Confidence', 'Avg Confidence', 'BBox Center X', 'BBox Center Y', 'BBox Width', 'BBox Height', 'Normalized Coordinates', 'Review Result', 'Remarks']
  91. detail_ws.write_row(0, 0, detail_headers)
  92. for row_idx, record in enumerate(self.detections, start=1):
  93. if not record['has_detection']:
  94. detail_ws.write_row(row_idx, 0, [
  95. record['image_path'],
  96. 0, 0, 0, '', '', '', '', '', '', ''
  97. ])
  98. continue
  99. # 每个检测框单独生成一行数据
  100. for detection in record['detections']:
  101. x1, y1, x2, y2 = detection['bbox']
  102. orig_w = detection.get('orig_w', 0)
  103. orig_h = detection.get('orig_h', 0)
  104. # 计算绝对坐标
  105. center_x = round((x1 + x2) / 2, 2)
  106. center_y = round((y1 + y2) / 2, 2)
  107. width = round(x2 - x1, 2)
  108. height = round(y2 - y1, 2)
  109. # 计算归一化坐标
  110. norm_cx = round(center_x / orig_w, 4) if orig_w > 0 else 0
  111. norm_cy = round(center_y / orig_h, 4) if orig_h > 0 else 0
  112. norm_w = round(width / orig_w, 4) if orig_w > 0 else 0
  113. norm_h = round(height / orig_h, 4) if orig_h > 0 else 0
  114. detail_ws.write_row(row_idx, 0, [
  115. record['image_path'],
  116. len(record['detections']),
  117. detection['confidence'],
  118. '', # Avg Confidence占位
  119. orig_w, # 新增原始宽度
  120. orig_h, # 新增原始高度
  121. center_x,
  122. center_y,
  123. width,
  124. height,
  125. f"{norm_cx},{norm_cy},{norm_w},{norm_h}",
  126. '',
  127. ''
  128. ])
  129. row_idx += 1
  130. # 仅在需要时计算指标
  131. if calculate_metrics and ground_truth_file:
  132. metrics = self.calculate_metrics(
  133. self.parse_excel_ground_truth(ground_truth_file)
  134. )
  135. metrics_ws = workbook.add_worksheet('Model Evaluation')
  136. metrics_headers = ['True Positive', 'False Positive', 'True Negative', 'False Negative', 'False Positive Rate', 'False Negative Rate']
  137. metrics_ws.write_row(0, 0, metrics_headers)
  138. metrics_ws.write_row(1, 0, [
  139. metrics['true_positive'],
  140. metrics['false_positive'],
  141. metrics['true_negative'],
  142. metrics['false_negative'],
  143. metrics['fpr'],
  144. metrics['fnr']
  145. ])
  146. workbook.close()
  147. def parse_excel_ground_truth(self, excel_path: str) -> Dict[str, bool]:
  148. """
  149. 从Excel文件解析地面真实值
  150. :param excel_path: detection_report_0.3.xlsx文件路径
  151. :return: 包含图像路径和真实检测结果的字典
  152. """
  153. wb = load_workbook(excel_path)
  154. ws = wb.active
  155. ground_truth = {}
  156. for row in ws.iter_rows(min_row=2, values_only=True):
  157. image_path = row[0]
  158. # 假设第二列是真实标签(例如1表示存在目标,0表示不存在)
  159. actual_value = bool(row[1]) if len(row) > 1 else False
  160. ground_truth[image_path] = actual_value
  161. return ground_truth