12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
- import os
- import re
- import random
- import shutil
- from collections import defaultdict
# Path configuration: the source images and the sampled test images both live
# in folders that sit next to this script.
_SCRIPT_DIR = os.path.dirname(__file__)
SRC_DIR = os.path.join(_SCRIPT_DIR, 'UAV')
TEST_DIR = os.path.join(_SCRIPT_DIR, 'TEST')

# Create the destination folder up front (no error if it already exists).
os.makedirs(TEST_DIR, exist_ok=True)
def split_dataset(src_dir=None, test_dir=None, ratio=0.1, max_per_group=10,
                  seed=None):
    """Copy a random sample of every image group into the test directory.

    JPG files found directly inside the source directory are grouped by the
    ``NN_NN`` prefix at the start of their name (two digits, an underscore,
    two digits). From each group about ``ratio`` of the files are picked at
    random, never fewer than 1 and never more than ``max_per_group``, and
    copied into the test directory. Files whose name does not start with such
    a prefix are ignored. A file that already exists in the test directory is
    neither overwritten nor counted.

    Args:
        src_dir: Directory to scan. Defaults to the module-level ``SRC_DIR``.
        test_dir: Directory that receives the copies. Defaults to the
            module-level ``TEST_DIR``. It is created if missing.
        ratio: Fraction of each group to sample (default 0.1).
        max_per_group: Upper bound on files sampled per group (default 10).
        seed: Optional seed for a reproducible selection. When ``None`` the
            global ``random`` module state is used, as before.

    Returns:
        None. A short summary (or the reason nothing was done) is printed.
    """
    # Resolve the defaults lazily so explicit arguments never touch the
    # module-level constants.
    src_dir = SRC_DIR if src_dir is None else src_dir
    test_dir = TEST_DIR if test_dir is None else test_dir

    # 1. Scan the source directory. isdir() rather than exists(): a regular
    #    file at that path would otherwise make os.listdir() raise.
    if not os.path.isdir(src_dir):
        print(f"源目录不存在: {src_dir}")
        return

    # os.listdir() returns names in arbitrary order; sorting makes the
    # sampling below depend only on the RNG state, not on the filesystem.
    all_files = sorted(
        f for f in os.listdir(src_dir) if f.lower().endswith('.jpg')
    )
    if not all_files:
        print("未找到JPG文件")
        return

    # 2. Group the files by their leading NN_NN prefix.
    pattern = re.compile(r'(\d{2}_\d{2})')
    groups = defaultdict(list)
    for filename in all_files:
        match = pattern.match(filename)
        if match:
            groups[match.group(1)].append(filename)

    # A private generator is used only when a seed is requested, so callers
    # that seed the global `random` module keep their existing behavior.
    rng = random if seed is None else random.Random(seed)

    # 3. Sample each group and copy the picks.
    os.makedirs(test_dir, exist_ok=True)
    total_copied = 0
    for files in groups.values():
        group_size = len(files)
        test_count = min(round(group_size * ratio), max_per_group)
        # At least one file per group, but never more than the group holds
        # (random.sample raises ValueError when asked for too many).
        test_count = min(max(1, test_count), group_size)

        for name in rng.sample(files, test_count):
            dst = os.path.join(test_dir, name)
            # Keep whatever is already in the test directory.
            if not os.path.exists(dst):
                shutil.copy(os.path.join(src_dir, name), dst)
                total_copied += 1

    # Print the summary.
    print(f"处理完成:\n"
          f"- 共发现 {len(groups)} 个分组\n"
          f"- 总计复制 {total_copied} 张测试图片\n"
          f"- 输出目录: {test_dir}")
# Script entry point: perform the split only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    split_dataset()
|