split_dataset.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import os
  2. import re
  3. import random
  4. import shutil
  5. from collections import defaultdict
  6. # 配置路径
  7. SRC_DIR = os.path.join(os.path.dirname(__file__), 'UAV')
  8. TEST_DIR = os.path.join(os.path.dirname(__file__), 'TEST')
  9. # 创建目标目录
  10. os.makedirs(TEST_DIR, exist_ok=True)
  11. def split_dataset():
  12. # 1. 扫描源目录
  13. if not os.path.exists(SRC_DIR):
  14. print(f"源目录不存在: {SRC_DIR}")
  15. return
  16. all_files = [f for f in os.listdir(SRC_DIR) if f.lower().endswith('.jpg')]
  17. if not all_files:
  18. print("未找到JPG文件")
  19. return
  20. # 2. 按前缀分组
  21. pattern = re.compile(r'(\d{2}_\d{2})')
  22. groups = defaultdict(list)
  23. for filename in all_files:
  24. match = pattern.match(filename)
  25. if match:
  26. group_key = match.group(1)
  27. groups[group_key].append(filename)
  28. # 3. 统计并复制文件
  29. total_copied = 0
  30. for group_key, files in groups.items():
  31. group_size = len(files)
  32. test_count = min(round(group_size * 0.1), 10)
  33. test_count = max(1, test_count)
  34. # 随机抽样
  35. selected = random.sample(files, test_count)
  36. # 直接使用TEST目录
  37. for f in selected:
  38. src = os.path.join(SRC_DIR, f)
  39. dst = os.path.join(TEST_DIR, f)
  40. if not os.path.exists(dst):
  41. shutil.copy(src, dst)
  42. total_copied += 1
  43. # 输出统计信息
  44. print(f"处理完成:\n"
  45. f"- 共发现 {len(groups)} 个分组\n"
  46. f"- 总计复制 {total_copied} 张测试图片\n"
  47. f"- 输出目录: {TEST_DIR}")
  48. if __name__ == '__main__':
  49. split_dataset()