123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185 |
- #!/usr/bin/env python3
- """UAV检测系统主入口"""
- import argparse
- import sys
- import os
- from typing import Optional
- from config.config import AppConfig, ModelConfig, DetectionConfig, OutputConfig
- from core.detector import UAVDetector
- from utils.file_manager import FileManager
- def parse_arguments() -> argparse.Namespace:
- """解析命令行参数"""
- parser = argparse.ArgumentParser(
- description='UAV Detection using ONNX model',
- formatter_class=argparse.RawDescriptionHelpFormatter,
- epilog="""
- 示例用法:
- python main.py --input /path/to/images --threshold 0.6
- python main.py --gui
- python main.py --input single_image.jpg --output custom_output
- """
- )
-
- # 输入输出参数
- parser.add_argument('--input', type=str, help='输入图像路径或目录')
- parser.add_argument('--output', type=str, default=None, help='输出目录路径')
- parser.add_argument('--gui', action='store_true', help='启用图形界面选择输入目录')
-
- # 模型参数
- parser.add_argument('--model', type=str,
- default='D:/PythonProject/Model/Data/models/250411_Anti_UAV.onnx',
- help='ONNX模型文件路径')
-
- # 检测参数
- parser.add_argument('--threshold', type=float, default=0.5,
- help='检测置信度阈值 (默认: 0.5)')
- parser.add_argument('--iou-threshold', type=float, default=0.4,
- help='NMS IoU阈值 (默认: 0.4)')
- parser.add_argument('--max-bbox-ratio', type=float, default=0.5,
- help='检测框最大面积比例阈值 (默认: 0.5)')
-
- # 输出选项
- parser.add_argument('--save-empty', action='store_true',
- help='是否保存未检测到目标的图片')
- parser.add_argument('--no-timestamp', action='store_true',
- help='输出目录不添加时间戳')
-
- # 其他选项
- parser.add_argument('--verbose', '-v', action='store_true',
- help='显示详细信息')
-
- return parser.parse_args()
- def create_config_from_args(args: argparse.Namespace) -> AppConfig:
- """从命令行参数创建配置"""
- model_config = ModelConfig(
- model_path=args.model
- )
-
- detection_config = DetectionConfig(
- confidence_threshold=args.threshold,
- iou_threshold=args.iou_threshold,
- max_bbox_ratio=args.max_bbox_ratio
- )
-
- output_config = OutputConfig(
- output_dir=args.output,
- save_empty=args.save_empty,
- create_timestamp_dir=not args.no_timestamp
- )
-
- return AppConfig(
- model=model_config,
- detection=detection_config,
- output=output_config
- )
- def validate_input(input_path: str) -> bool:
- """验证输入路径"""
- if not input_path:
- return False
-
- if not os.path.exists(input_path):
- print(f"错误:输入路径不存在: {input_path}")
- return False
-
- if os.path.isfile(input_path):
- supported_formats = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')
- if not input_path.lower().endswith(supported_formats):
- print(f"错误:不支持的图像格式: {input_path}")
- return False
-
- return True
- def main():
- """主函数"""
- try:
- # 解析命令行参数
- args = parse_arguments()
-
- # GUI模式处理
- if args.gui:
- input_path = FileManager.select_input_directory()
- if not input_path:
- print("未选择目录,程序退出")
- return 1
- args.input = input_path
-
- # 验证输入参数
- if not args.input:
- print("错误:请提供输入路径或使用 --gui 参数选择目录")
- return 1
-
- if not validate_input(args.input):
- return 1
-
- # 创建配置
- config = create_config_from_args(args)
-
- # 确定输入目录(用于输出目录命名)
- input_dir = args.input if os.path.isdir(args.input) else None
-
- # 初始化检测器
- if args.verbose:
- print("正在初始化检测器...")
-
- detector = UAVDetector(config, input_dir)
-
- if args.verbose:
- model_info = detector.get_model_info()
- print(f"模型信息:")
- print(f" 路径: {model_info['model_path']}")
- print(f" 输入形状: {model_info['input_shape']}")
- print(f" 执行提供者: {model_info['providers']}")
- print(f" 置信度阈值: {model_info['confidence_threshold']}")
- print()
-
- # 处理图像
- if os.path.isfile(args.input):
- # 单张图像处理
- print(f"处理单张图像: {args.input}")
- detections = detector.process_image(args.input)
- print(f"检测完成,发现 {detections} 个目标")
-
- # 生成报告
- detector._generate_reports()
-
- else:
- # 批量处理
- print(f"开始批量处理目录: {args.input}")
- results = detector.process_batch(args.input)
-
- print(f"\n批量处理完成!")
- print(f" 处理图像数: {results['processed_images']}")
- print(f" 检测目标总数: {results['total_detections']}")
- print(f" 有检测结果的图像: {results['images_with_detections']}")
-
- # 显示统计信息
- if args.verbose:
- stats = detector.get_statistics()
- print(f"\n统计信息:")
- print(f" 检测率: {stats['detection_rate']:.2%}")
- print(f" 平均每张图像检测数: {stats['avg_detections_per_image']:.2f}")
-
- print(f"\n结果保存在: {detector.output_dir}")
- return 0
-
- except KeyboardInterrupt:
- print("\n用户中断操作")
- return 1
- except Exception as e:
- print(f"程序执行出错: {e}")
- if args.verbose if 'args' in locals() else False:
- import traceback
- traceback.print_exc()
- return 1
- if __name__ == '__main__':
- sys.exit(main())
|