main.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. #!/usr/bin/env python3
  2. """UAV检测系统主入口"""
  3. import argparse
  4. import sys
  5. import os
  6. from typing import Optional
  7. from config.config import AppConfig, ModelConfig, DetectionConfig, OutputConfig
  8. from core.detector import UAVDetector
  9. from utils.file_manager import FileManager
  10. def parse_arguments() -> argparse.Namespace:
  11. """解析命令行参数"""
  12. parser = argparse.ArgumentParser(
  13. description='UAV Detection using ONNX model',
  14. formatter_class=argparse.RawDescriptionHelpFormatter,
  15. epilog="""
  16. 示例用法:
  17. python main.py --input /path/to/images --threshold 0.6
  18. python main.py --gui
  19. python main.py --input single_image.jpg --output custom_output
  20. """
  21. )
  22. # 输入输出参数
  23. parser.add_argument('--input', type=str, help='输入图像路径或目录')
  24. parser.add_argument('--output', type=str, default=None, help='输出目录路径')
  25. parser.add_argument('--gui', action='store_true', help='启用图形界面选择输入目录')
  26. # 模型参数
  27. parser.add_argument('--model', type=str,
  28. default='D:/PythonProject/Model/Data/models/250411_Anti_UAV.onnx',
  29. help='ONNX模型文件路径')
  30. # 检测参数
  31. parser.add_argument('--threshold', type=float, default=0.5,
  32. help='检测置信度阈值 (默认: 0.5)')
  33. parser.add_argument('--iou-threshold', type=float, default=0.4,
  34. help='NMS IoU阈值 (默认: 0.4)')
  35. parser.add_argument('--max-bbox-ratio', type=float, default=0.5,
  36. help='检测框最大面积比例阈值 (默认: 0.5)')
  37. # 输出选项
  38. parser.add_argument('--save-empty', action='store_true',
  39. help='是否保存未检测到目标的图片')
  40. parser.add_argument('--no-timestamp', action='store_true',
  41. help='输出目录不添加时间戳')
  42. # 其他选项
  43. parser.add_argument('--verbose', '-v', action='store_true',
  44. help='显示详细信息')
  45. return parser.parse_args()
  46. def create_config_from_args(args: argparse.Namespace) -> AppConfig:
  47. """从命令行参数创建配置"""
  48. model_config = ModelConfig(
  49. model_path=args.model
  50. )
  51. detection_config = DetectionConfig(
  52. confidence_threshold=args.threshold,
  53. iou_threshold=args.iou_threshold,
  54. max_bbox_ratio=args.max_bbox_ratio
  55. )
  56. output_config = OutputConfig(
  57. output_dir=args.output,
  58. save_empty=args.save_empty,
  59. create_timestamp_dir=not args.no_timestamp
  60. )
  61. return AppConfig(
  62. model=model_config,
  63. detection=detection_config,
  64. output=output_config
  65. )
  66. def validate_input(input_path: str) -> bool:
  67. """验证输入路径"""
  68. if not input_path:
  69. return False
  70. if not os.path.exists(input_path):
  71. print(f"错误:输入路径不存在: {input_path}")
  72. return False
  73. if os.path.isfile(input_path):
  74. supported_formats = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')
  75. if not input_path.lower().endswith(supported_formats):
  76. print(f"错误:不支持的图像格式: {input_path}")
  77. return False
  78. return True
  79. def main():
  80. """主函数"""
  81. try:
  82. # 解析命令行参数
  83. args = parse_arguments()
  84. # GUI模式处理
  85. if args.gui:
  86. input_path = FileManager.select_input_directory()
  87. if not input_path:
  88. print("未选择目录,程序退出")
  89. return 1
  90. args.input = input_path
  91. # 验证输入参数
  92. if not args.input:
  93. print("错误:请提供输入路径或使用 --gui 参数选择目录")
  94. return 1
  95. if not validate_input(args.input):
  96. return 1
  97. # 创建配置
  98. config = create_config_from_args(args)
  99. # 确定输入目录(用于输出目录命名)
  100. input_dir = args.input if os.path.isdir(args.input) else None
  101. # 初始化检测器
  102. if args.verbose:
  103. print("正在初始化检测器...")
  104. detector = UAVDetector(config, input_dir)
  105. if args.verbose:
  106. model_info = detector.get_model_info()
  107. print(f"模型信息:")
  108. print(f" 路径: {model_info['model_path']}")
  109. print(f" 输入形状: {model_info['input_shape']}")
  110. print(f" 执行提供者: {model_info['providers']}")
  111. print(f" 置信度阈值: {model_info['confidence_threshold']}")
  112. print()
  113. # 处理图像
  114. if os.path.isfile(args.input):
  115. # 单张图像处理
  116. print(f"处理单张图像: {args.input}")
  117. detections = detector.process_image(args.input)
  118. print(f"检测完成,发现 {detections} 个目标")
  119. # 生成报告
  120. detector._generate_reports()
  121. else:
  122. # 批量处理
  123. print(f"开始批量处理目录: {args.input}")
  124. results = detector.process_batch(args.input)
  125. print(f"\n批量处理完成!")
  126. print(f" 处理图像数: {results['processed_images']}")
  127. print(f" 检测目标总数: {results['total_detections']}")
  128. print(f" 有检测结果的图像: {results['images_with_detections']}")
  129. # 显示统计信息
  130. if args.verbose:
  131. stats = detector.get_statistics()
  132. print(f"\n统计信息:")
  133. print(f" 检测率: {stats['detection_rate']:.2%}")
  134. print(f" 平均每张图像检测数: {stats['avg_detections_per_image']:.2f}")
  135. print(f"\n结果保存在: {detector.output_dir}")
  136. return 0
  137. except KeyboardInterrupt:
  138. print("\n用户中断操作")
  139. return 1
  140. except Exception as e:
  141. print(f"程序执行出错: {e}")
  142. if args.verbose if 'args' in locals() else False:
  143. import traceback
  144. traceback.print_exc()
  145. return 1
  146. if __name__ == '__main__':
  147. sys.exit(main())