소스 검색

refactor: 重构项目结构并优化代码组织

将原有代码按照功能模块重新组织到src目录下,包括core、utils和analysis模块。同时更新了.gitignore和README.md文件,完善了项目结构说明和开发指南。删除旧文件并将其移动到新目录中,以提升代码的可维护性和可读性。
Hannnk 3 주 전
부모
커밋
d23087c3ed

+ 45 - 2
.gitignore

@@ -1,3 +1,46 @@
-*.jpg
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual Environment
+venv/
+ENV/
+env/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# Project specific
+outputs/
 *.onnx
-/venv
+*.pkl
+*.log
+*.csv
+*.xlsx
+*.jpg
+
+
+# OS
+.DS_Store
+Thumbs.db

+ 41 - 12
README.md

@@ -7,28 +7,37 @@
 - 误报过滤机制(ORB特征匹配)
 - 多格式检测报告生成(CSV/Excel)
 
+## 项目结构
+```
+Model/
+├── src/                    # 源代码目录
+│   ├── core/              # 核心功能模块
+│   ├── utils/             # 工具函数
+│   └── analysis/          # 分析工具
+├── configs/               # 配置文件目录
+├── data/                  # 数据目录
+│   ├── false_positive/   # 误报样本
+│   └── models/           # 模型文件
+├── tests/                # 测试目录
+├── outputs/              # 输出目录
+└── docs/                 # 文档目录
+```
+
 ## 环境要求
 - Python 3.8+ 
 - 依赖库:
-```
-onnxruntime-gpu==1.16.0
-opencv-python==4.8.0
-numpy==1.24.3
-tqdm==4.66.1
-openpyxl==3.1.2
+```bash
+pip install -r requirements.txt
 ```
 
 ## 快速开始
 ```python
-# 安装依赖
-pip install -r requirements.txt
-
 # 单张图片检测
-python inference.py --input test.jpg --threshold 0.5 \
+python src/core/inference.py --input test.jpg --threshold 0.5 \
     --output results --report excel
 
 # 批量检测目录
-python inference.py --input UAV/ --max-bbox-ratio 0.1 \
+python src/core/inference.py --input UAV/ --max-bbox-ratio 0.1 \
     --output batch_results --report all
 ```
 
@@ -56,4 +65,24 @@ Q: 如何启用GPU加速?
 A: 确保安装onnxruntime-gpu并配置CUDA环境,程序会自动优先使用GPU
 
 Q: 误报特征库如何更新?
-A: 将误报样本放入false_positive目录,程序启动时会自动加载
+A: 将误报样本放入data/false_positive目录,程序启动时会自动加载
+
+## 开发指南
+1. 代码规范
+   - 遵循PEP 8编码规范
+   - 使用类型注解
+   - 编写单元测试
+
+2. 测试
+```bash
+# 运行测试
+pytest tests/
+
+# 生成测试覆盖率报告
+pytest --cov=src tests/
+```
+
+3. 文档
+   - 使用Google风格文档字符串
+   - 保持README.md更新
+   - 重要功能添加使用示例

+ 0 - 6
onnx_check.py

@@ -1,6 +0,0 @@
-import onnxruntime
-print(onnxruntime.__version__)
-print(onnxruntime.get_device() ) # 如果得到的输出结果是GPU,所以按理说是找到了GPU的
-ort_session = onnxruntime.InferenceSession("fly.onnx",
-providers=['CUDAExecutionProvider'])
-print(ort_session.get_providers())

+ 28 - 0
reorganize.py

@@ -0,0 +1,28 @@
+import os
+import shutil
+
+# 定义目录结构
+dirs = {
+    'src/core': ['inference.py', 'feature_extractor.py', 'image_validation.py', 'ali_image_validation.py'],
+    'src/utils': ['bbox_visualizer.py', 'detection_visualizer.py', 'split_dataset.py', 'update_labels.py'],
+    'src/analysis': ['hotspot_analyzer.py', 'report_generator.py', 'analysis_report.py'],
+    'data/models': ['UAV-250321.onnx'],
+    'data/false_positive': [],
+    'tests': [],
+    'outputs/archive': []
+}
+
+# 创建目录结构
+for dir_path in dirs.keys():
+    os.makedirs(dir_path, exist_ok=True)
+
+# 移动文件
+for target_dir, files in dirs.items():
+    for file in files:
+        if os.path.exists(file):
+            shutil.move(file, os.path.join(target_dir, file))
+            print(f'Moved {file} to {target_dir}/')
+        else:
+            print(f'Warning: {file} not found')
+
+print('Project reorganization completed!') 

+ 7 - 0
requirements.txt

@@ -0,0 +1,7 @@
+onnxruntime-gpu==1.16.0
+opencv-python==4.8.0
+numpy==1.24.3
+tqdm==4.66.1
+openpyxl==3.1.2
+pytest==7.4.0
+pytest-cov==4.1.0 

+ 1 - 0
src/__init__.py

@@ -0,0 +1 @@
+# 使src成为一个Python包 

+ 1 - 0
src/analysis/__init__.py

@@ -0,0 +1 @@
+# 使analysis成为一个Python包 

+ 0 - 0
analysis_report.py → src/analysis/analysis_report.py


+ 0 - 0
hotspot_analyzer.py → src/analysis/hotspot_analyzer.py


+ 0 - 0
report_generator.py → src/analysis/report_generator.py


+ 0 - 0
ali_image_validation.py → src/core/ali_image_validation.py


+ 0 - 0
feature_extractor.py → src/core/feature_extractor.py


+ 0 - 0
image_validation.py → src/core/image_validation.py


+ 2 - 2
inference.py → src/core/inference.py

@@ -9,7 +9,7 @@ from typing import List, Dict, Optional
 import time
 from datetime import datetime
 from tqdm import tqdm
-from report_generator import ReportGenerator
+from ..analysis.report_generator import ReportGenerator
 class ONNXDetector:
     def select_input_directory(self) -> Optional[str]:
         """图形界面选择输入目录"""
@@ -20,7 +20,7 @@ class ONNXDetector:
         folder_path = filedialog.askdirectory(title='选择输入目录')
         return folder_path if folder_path else None
 
-    def __init__(self, model_path: str = 'UAV-250321.onnx', threshold: float = 0.5, 
+    def __init__(self, model_path: str = 'data/models/UAV-250321.onnx', threshold: float = 0.5, 
                  output_dir: str = "None", save_empty: bool = False, 
                  max_bbox_ratio: float = 0.5,  # 已有该参数
                  input_dir: str = "None"):

+ 0 - 0
bbox_visualizer.py → src/utils/bbox_visualizer.py


+ 0 - 0
detection_visualizer.py → src/utils/detection_visualizer.py


+ 0 - 0
split_dataset.py → src/utils/split_dataset.py


+ 0 - 0
update_labels.py → src/utils/update_labels.py


+ 47 - 0
tests/test_inference.py

@@ -0,0 +1,47 @@
+import pytest
+import os
+import cv2
+import numpy as np
+from src.core.inference import InferenceEngine
+
+class TestInference:
+    @pytest.fixture
+    def inference_engine(self):
+        model_path = os.path.join('data', 'models', 'UAV-250321.onnx')
+        return InferenceEngine(model_path)
+
+    def test_single_image_inference(self, inference_engine):
+        # 创建一个测试图像
+        test_image = np.zeros((640, 640, 3), dtype=np.uint8)
+        test_image[100:200, 100:200] = 255  # 添加一个白色方块
+
+        # 执行推理
+        results = inference_engine.infer(test_image)
+        
+        # 验证结果格式
+        assert isinstance(results, list)
+        for result in results:
+            assert 'bbox' in result
+            assert 'confidence' in result
+            assert 'class_id' in result
+
+    def test_batch_inference(self, inference_engine):
+        # 创建测试图像列表
+        test_images = [
+            np.zeros((640, 640, 3), dtype=np.uint8) for _ in range(3)
+        ]
+        
+        # 执行批量推理
+        results = inference_engine.batch_infer(test_images)
+        
+        # 验证结果
+        assert len(results) == len(test_images)
+        assert all(isinstance(r, list) for r in results)
+
+    def test_invalid_input(self, inference_engine):
+        # 测试无效输入
+        with pytest.raises(ValueError):
+            inference_engine.infer(None)
+        
+        with pytest.raises(ValueError):
+            inference_engine.infer(np.zeros((100, 100), dtype=np.uint8))  # 灰度图像