generate_test_data.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 测试数据生成器
  5. 用于生成模拟的测试数据,方便测试模型评估功能
  6. 作者: AI Assistant
  7. 创建时间: 2024
  8. """
import json
import os
import random
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
  16. class TestDataGenerator:
  17. """
  18. 测试数据生成器
  19. 生成模拟的UAV和鸟类图像用于测试
  20. """
  21. def __init__(self, output_dir: str = None):
  22. """
  23. 初始化生成器
  24. Args:
  25. output_dir: 输出目录
  26. """
  27. if output_dir is None:
  28. output_dir = os.path.join(os.path.dirname(__file__), 'Data')
  29. self.output_dir = output_dir
  30. self.image_size = (640, 480) # 默认图像尺寸
  31. # 创建目录结构
  32. self.birds_dir = os.path.join(output_dir, 'BIRDS')
  33. self.uav_dir = os.path.join(output_dir, 'UAV')
  34. self.others_dir = os.path.join(output_dir, 'OTHERS')
  35. os.makedirs(self.birds_dir, exist_ok=True)
  36. os.makedirs(self.uav_dir, exist_ok=True)
  37. os.makedirs(self.others_dir, exist_ok=True)
  38. def generate_bird_image(self, image_id: int) -> str:
  39. """
  40. 生成鸟类图像(模拟)
  41. Args:
  42. image_id: 图像ID
  43. Returns:
  44. 生成的图像路径
  45. """
  46. # 创建蓝天背景
  47. img = np.ones((self.image_size[1], self.image_size[0], 3), dtype=np.uint8)
  48. img[:, :] = [135, 206, 235] # 天蓝色背景
  49. # 添加一些云朵
  50. for _ in range(random.randint(2, 5)):
  51. center = (random.randint(50, self.image_size[0]-50),
  52. random.randint(50, self.image_size[1]-50))
  53. radius = random.randint(20, 60)
  54. cv2.circle(img, center, radius, (255, 255, 255), -1)
  55. # 绘制鸟类形状(简化的V形)
  56. bird_x = random.randint(100, self.image_size[0]-100)
  57. bird_y = random.randint(100, self.image_size[1]-100)
  58. bird_size = random.randint(15, 40)
  59. # 绘制鸟的翅膀(V形)
  60. pts1 = np.array([[bird_x-bird_size, bird_y],
  61. [bird_x, bird_y-bird_size//2],
  62. [bird_x, bird_y+bird_size//2]], np.int32)
  63. pts2 = np.array([[bird_x+bird_size, bird_y],
  64. [bird_x, bird_y-bird_size//2],
  65. [bird_x, bird_y+bird_size//2]], np.int32)
  66. cv2.fillPoly(img, [pts1], (50, 50, 50)) # 深灰色鸟
  67. cv2.fillPoly(img, [pts2], (50, 50, 50))
  68. # 添加噪声
  69. noise = np.random.normal(0, 10, img.shape).astype(np.uint8)
  70. img = cv2.add(img, noise)
  71. # 保存图像
  72. filename = f"bird_{image_id:04d}.jpg"
  73. filepath = os.path.join(self.birds_dir, filename)
  74. cv2.imwrite(filepath, img)
  75. return filepath
  76. def generate_uav_image(self, image_id: int) -> str:
  77. """
  78. 生成无人机图像(模拟)
  79. Args:
  80. image_id: 图像ID
  81. Returns:
  82. 生成的图像路径
  83. """
  84. # 创建蓝天背景
  85. img = np.ones((self.image_size[1], self.image_size[0], 3), dtype=np.uint8)
  86. img[:, :] = [135, 206, 235] # 天蓝色背景
  87. # 添加一些云朵
  88. for _ in range(random.randint(1, 3)):
  89. center = (random.randint(50, self.image_size[0]-50),
  90. random.randint(50, self.image_size[1]-50))
  91. radius = random.randint(30, 80)
  92. cv2.circle(img, center, radius, (255, 255, 255), -1)
  93. # 绘制无人机形状(简化的十字形)
  94. uav_x = random.randint(100, self.image_size[0]-100)
  95. uav_y = random.randint(100, self.image_size[1]-100)
  96. uav_size = random.randint(20, 50)
  97. # 绘制无人机机身(矩形)
  98. cv2.rectangle(img,
  99. (uav_x-uav_size//4, uav_y-uav_size//8),
  100. (uav_x+uav_size//4, uav_y+uav_size//8),
  101. (80, 80, 80), -1)
  102. # 绘制螺旋桨臂
  103. cv2.line(img, (uav_x-uav_size, uav_y), (uav_x+uav_size, uav_y), (60, 60, 60), 3)
  104. cv2.line(img, (uav_x, uav_y-uav_size), (uav_x, uav_y+uav_size), (60, 60, 60), 3)
  105. # 绘制螺旋桨(圆形)
  106. for dx, dy in [(-uav_size, 0), (uav_size, 0), (0, -uav_size), (0, uav_size)]:
  107. cv2.circle(img, (uav_x+dx, uav_y+dy), uav_size//6, (40, 40, 40), -1)
  108. # 添加噪声
  109. noise = np.random.normal(0, 8, img.shape).astype(np.uint8)
  110. img = cv2.add(img, noise)
  111. # 保存图像
  112. filename = f"uav_{image_id:04d}.jpg"
  113. filepath = os.path.join(self.uav_dir, filename)
  114. cv2.imwrite(filepath, img)
  115. return filepath
  116. def generate_other_image(self, image_id: int) -> str:
  117. """
  118. 生成其他类型图像(负样本)
  119. Args:
  120. image_id: 图像ID
  121. Returns:
  122. 生成的图像路径
  123. """
  124. # 随机选择背景类型
  125. bg_type = random.choice(['sky', 'landscape', 'urban'])
  126. if bg_type == 'sky':
  127. # 纯天空背景
  128. img = np.ones((self.image_size[1], self.image_size[0], 3), dtype=np.uint8)
  129. img[:, :] = [135, 206, 235] # 天蓝色
  130. # 添加云朵
  131. for _ in range(random.randint(3, 8)):
  132. center = (random.randint(0, self.image_size[0]),
  133. random.randint(0, self.image_size[1]))
  134. radius = random.randint(40, 100)
  135. cv2.circle(img, center, radius, (255, 255, 255), -1)
  136. elif bg_type == 'landscape':
  137. # 风景背景
  138. img = np.ones((self.image_size[1], self.image_size[0], 3), dtype=np.uint8)
  139. # 天空部分
  140. img[:self.image_size[1]//2, :] = [135, 206, 235]
  141. # 地面部分
  142. img[self.image_size[1]//2:, :] = [34, 139, 34] # 森林绿
  143. # 添加一些树木形状
  144. for _ in range(random.randint(5, 10)):
  145. tree_x = random.randint(0, self.image_size[0])
  146. tree_y = random.randint(self.image_size[1]//2, self.image_size[1])
  147. tree_height = random.randint(30, 80)
  148. cv2.rectangle(img,
  149. (tree_x-5, tree_y),
  150. (tree_x+5, tree_y-tree_height),
  151. (139, 69, 19), -1) # 棕色树干
  152. cv2.circle(img, (tree_x, tree_y-tree_height), 15, (0, 100, 0), -1) # 绿色树冠
  153. else: # urban
  154. # 城市背景
  155. img = np.ones((self.image_size[1], self.image_size[0], 3), dtype=np.uint8)
  156. img[:, :] = [169, 169, 169] # 灰色背景
  157. # 添加建筑物
  158. for _ in range(random.randint(3, 6)):
  159. building_x = random.randint(0, self.image_size[0]-50)
  160. building_y = random.randint(self.image_size[1]//3, self.image_size[1])
  161. building_w = random.randint(30, 80)
  162. building_h = random.randint(50, 150)
  163. cv2.rectangle(img,
  164. (building_x, building_y),
  165. (building_x+building_w, building_y-building_h),
  166. (105, 105, 105), -1) # 深灰色建筑
  167. # 添加噪声
  168. noise = np.random.normal(0, 12, img.shape).astype(np.uint8)
  169. img = cv2.add(img, noise)
  170. # 保存图像
  171. filename = f"other_{image_id:04d}.jpg"
  172. filepath = os.path.join(self.others_dir, filename)
  173. cv2.imwrite(filepath, img)
  174. return filepath
  175. def generate_dataset(self, num_birds: int = 50, num_uavs: int = 50, num_others: int = 100) -> Dict:
  176. """
  177. 生成完整的测试数据集
  178. Args:
  179. num_birds: 鸟类图像数量
  180. num_uavs: 无人机图像数量
  181. num_others: 其他图像数量
  182. Returns:
  183. 数据集信息字典
  184. """
  185. print(f"开始生成测试数据集...")
  186. print(f"鸟类图像: {num_birds} 张")
  187. print(f"无人机图像: {num_uavs} 张")
  188. print(f"其他图像: {num_others} 张")
  189. print(f"输出目录: {self.output_dir}")
  190. dataset_info = {
  191. 'creation_time': datetime.now().isoformat(),
  192. 'total_images': num_birds + num_uavs + num_others,
  193. 'categories': {
  194. 'birds': {'count': num_birds, 'files': []},
  195. 'uavs': {'count': num_uavs, 'files': []},
  196. 'others': {'count': num_others, 'files': []}
  197. }
  198. }
  199. # 生成鸟类图像
  200. print("\n生成鸟类图像...")
  201. for i in range(num_birds):
  202. filepath = self.generate_bird_image(i)
  203. dataset_info['categories']['birds']['files'].append(filepath)
  204. if (i + 1) % 10 == 0:
  205. print(f"已生成 {i + 1}/{num_birds} 张鸟类图像")
  206. # 生成无人机图像
  207. print("\n生成无人机图像...")
  208. for i in range(num_uavs):
  209. filepath = self.generate_uav_image(i)
  210. dataset_info['categories']['uavs']['files'].append(filepath)
  211. if (i + 1) % 10 == 0:
  212. print(f"已生成 {i + 1}/{num_uavs} 张无人机图像")
  213. # 生成其他图像
  214. print("\n生成其他图像...")
  215. for i in range(num_others):
  216. filepath = self.generate_other_image(i)
  217. dataset_info['categories']['others']['files'].append(filepath)
  218. if (i + 1) % 20 == 0:
  219. print(f"已生成 {i + 1}/{num_others} 张其他图像")
  220. # 保存数据集信息
  221. info_file = os.path.join(self.output_dir, 'dataset_info.json')
  222. with open(info_file, 'w', encoding='utf-8') as f:
  223. json.dump(dataset_info, f, ensure_ascii=False, indent=2)
  224. print(f"\n数据集生成完成!")
  225. print(f"总计生成 {dataset_info['total_images']} 张图像")
  226. print(f"数据集信息已保存到: {info_file}")
  227. return dataset_info
  228. def create_annotation_file(self, dataset_info: Dict) -> str:
  229. """
  230. 创建标注文件
  231. Args:
  232. dataset_info: 数据集信息
  233. Returns:
  234. 标注文件路径
  235. """
  236. annotations = {
  237. 'info': {
  238. 'description': 'Generated test dataset for UAV and bird detection',
  239. 'version': '1.0',
  240. 'creation_date': dataset_info['creation_time']
  241. },
  242. 'images': []
  243. }
  244. # 添加鸟类图像标注
  245. for filepath in dataset_info['categories']['birds']['files']:
  246. annotations['images'].append({
  247. 'path': filepath,
  248. 'label': 'bird',
  249. 'has_target': True,
  250. 'bboxes': [] # 简化版本,不包含具体边界框
  251. })
  252. # 添加无人机图像标注
  253. for filepath in dataset_info['categories']['uavs']['files']:
  254. annotations['images'].append({
  255. 'path': filepath,
  256. 'label': 'uav',
  257. 'has_target': True,
  258. 'bboxes': []
  259. })
  260. # 添加其他图像标注
  261. for filepath in dataset_info['categories']['others']['files']:
  262. annotations['images'].append({
  263. 'path': filepath,
  264. 'label': 'others',
  265. 'has_target': False,
  266. 'bboxes': []
  267. })
  268. # 保存标注文件
  269. annotation_file = os.path.join(self.output_dir, 'annotations.json')
  270. with open(annotation_file, 'w', encoding='utf-8') as f:
  271. json.dump(annotations, f, ensure_ascii=False, indent=2)
  272. print(f"标注文件已保存到: {annotation_file}")
  273. return annotation_file
  274. def main():
  275. """
  276. 主函数 - 生成测试数据
  277. """
  278. print("测试数据生成器")
  279. print("=" * 50)
  280. # 配置参数
  281. output_dir = r"d:\PythonProject\R360-UAVmodelTool\Data"
  282. num_birds = 30
  283. num_uavs = 30
  284. num_others = 60
  285. # 初始化生成器
  286. generator = TestDataGenerator(output_dir)
  287. # 生成数据集
  288. dataset_info = generator.generate_dataset(num_birds, num_uavs, num_others)
  289. # 创建标注文件
  290. annotation_file = generator.create_annotation_file(dataset_info)
  291. print("\n数据集结构:")
  292. print(f"├── BIRDS/ ({num_birds} 张图像)")
  293. print(f"├── UAV/ ({num_uavs} 张图像)")
  294. print(f"├── OTHERS/ ({num_others} 张图像)")
  295. print(f"├── dataset_info.json")
  296. print(f"└── annotations.json")
  297. print("\n现在可以运行 model_evaluator.py 进行模型评估!")
  298. if __name__ == "__main__":
  299. main()