target_based_classifier.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. import tkinter as tk
  2. from tkinter import ttk, filedialog, messagebox
  3. from PIL import Image, ImageTk, ImageEnhance
  4. import os
  5. import shutil
  6. from pathlib import Path
  7. import numpy as np
  8. from collections import Counter
  9. import cv2
  10. class TargetBasedClassifier:
  11. def __init__(self, root):
  12. self.root = root
  13. self.root.title("基于目标的图片分类器")
  14. self.root.geometry("1400x900")
  15. # 当前工作目录(包含original、targets、annotated三个文件夹)
  16. self.work_directory = None
  17. # 三个子目录路径
  18. self.original_dir = None
  19. self.targets_dir = None
  20. self.annotated_dir = None
  21. # 图片列表和索引
  22. self.image_list = [] # 存储original文件夹中的图片路径
  23. self.current_index = 0
  24. # 当前显示的图片信息
  25. self.current_original_path = None
  26. self.current_targets = [] # 当前图片对应的目标切片列表
  27. # 图片处理参数
  28. self.zoom_factor = 1.0
  29. self.rotation_angle = 0
  30. self.brightness_factor = 1.0
  31. # 操作历史记录
  32. self.operation_history = []
  33. self.setup_ui()
  34. self.setup_keyboard_shortcuts()
  35. def setup_ui(self):
  36. # 创建主框架
  37. main_frame = ttk.Frame(self.root)
  38. main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
  39. # 左侧控制面板
  40. control_frame = ttk.Frame(main_frame)
  41. control_frame.pack(side=tk.LEFT, fill=tk.Y, padx=5)
  42. # 选择工作目录按钮
  43. ttk.Button(control_frame, text="选择工作目录 (F)", command=self.select_work_directory).pack(pady=5)
  44. # 分类按钮
  45. ttk.Button(control_frame, text="无人机 (1)", command=lambda: self.classify_image("drone")).pack(pady=5)
  46. ttk.Button(control_frame, text="鸟类 (2)", command=lambda: self.classify_image("bird")).pack(pady=5)
  47. ttk.Button(control_frame, text="有人机 (3)", command=lambda: self.classify_image("manned")).pack(pady=5)
  48. ttk.Button(control_frame, text="其他 (4)", command=lambda: self.classify_image("other")).pack(pady=5)
  49. ttk.Button(control_frame, text="无目标 (5)", command=lambda: self.classify_image("none")).pack(pady=5)
  50. # 导航按钮
  51. nav_frame = ttk.Frame(control_frame)
  52. nav_frame.pack(pady=10)
  53. ttk.Button(nav_frame, text="上一张 (←)", command=self.prev_image).pack(side=tk.LEFT, padx=5)
  54. ttk.Button(nav_frame, text="下一张 (→)", command=self.next_image).pack(side=tk.LEFT, padx=5)
  55. # 撤回按钮
  56. ttk.Button(control_frame, text="撤回操作 (6)", command=self.undo_operation).pack(pady=5)
  57. # 图片处理控制
  58. process_frame = ttk.LabelFrame(control_frame, text="图片处理")
  59. process_frame.pack(pady=10, fill=tk.X)
  60. # 缩放控制
  61. ttk.Label(process_frame, text="缩放:").pack()
  62. self.zoom_scale = ttk.Scale(process_frame, from_=0.1, to=3.0, orient=tk.HORIZONTAL,
  63. command=self.update_zoom)
  64. self.zoom_scale.set(1.0)
  65. self.zoom_scale.pack(fill=tk.X, padx=5)
  66. # 旋转控制
  67. ttk.Label(process_frame, text="旋转:").pack()
  68. self.rotation_scale = ttk.Scale(process_frame, from_=0, to=360, orient=tk.HORIZONTAL,
  69. command=self.update_rotation)
  70. self.rotation_scale.set(0)
  71. self.rotation_scale.pack(fill=tk.X, padx=5)
  72. # 亮度控制
  73. ttk.Label(process_frame, text="亮度:").pack()
  74. self.brightness_scale = ttk.Scale(process_frame, from_=0.1, to=2.0, orient=tk.HORIZONTAL,
  75. command=self.update_brightness)
  76. self.brightness_scale.set(1.0)
  77. self.brightness_scale.pack(fill=tk.X, padx=5)
  78. # 重置按钮
  79. ttk.Button(process_frame, text="重置图片 (R)", command=self.reset_image).pack(pady=5)
  80. # 右侧显示区域
  81. display_frame = ttk.Frame(main_frame)
  82. display_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)
  83. # 上方:标注图片显示
  84. original_frame = ttk.LabelFrame(display_frame, text="标注图片(带检测框)")
  85. original_frame.pack(fill=tk.BOTH, expand=True, pady=(0, 5))
  86. self.original_label = ttk.Label(original_frame)
  87. self.original_label.pack(fill=tk.BOTH, expand=True)
  88. # 下方:目标切片显示(仅作参考)
  89. targets_frame = ttk.LabelFrame(display_frame, text="检测到的目标(仅供参考)")
  90. targets_frame.pack(fill=tk.X, pady=(5, 0))
  91. # 创建可滚动的目标显示区域
  92. targets_canvas = tk.Canvas(targets_frame, height=200)
  93. targets_scrollbar = ttk.Scrollbar(targets_frame, orient="horizontal", command=targets_canvas.xview)
  94. self.targets_scroll_frame = ttk.Frame(targets_canvas)
  95. self.targets_scroll_frame.bind(
  96. "<Configure>",
  97. lambda e: targets_canvas.configure(scrollregion=targets_canvas.bbox("all"))
  98. )
  99. targets_canvas.create_window((0, 0), window=self.targets_scroll_frame, anchor="nw")
  100. targets_canvas.configure(xscrollcommand=targets_scrollbar.set)
  101. targets_canvas.pack(side="top", fill="both", expand=True)
  102. targets_scrollbar.pack(side="bottom", fill="x")
  103. # 状态标签
  104. self.status_label = ttk.Label(self.root, text="请选择包含original、targets、annotated文件夹的工作目录(显示标注图片,分类原图)")
  105. self.status_label.pack(side=tk.BOTTOM, pady=5)
  106. # 图片信息标签
  107. self.info_label = ttk.Label(self.root, text="")
  108. self.info_label.pack(side=tk.BOTTOM, pady=5)
  109. def setup_keyboard_shortcuts(self):
  110. self.root.bind('<F5>', lambda e: self.select_work_directory())
  111. self.root.bind('<Left>', lambda e: self.prev_image())
  112. self.root.bind('<Right>', lambda e: self.next_image())
  113. self.root.bind('1', lambda e: self.classify_image("drone"))
  114. self.root.bind('2', lambda e: self.classify_image("bird"))
  115. self.root.bind('3', lambda e: self.classify_image("manned"))
  116. self.root.bind('4', lambda e: self.classify_image("other"))
  117. self.root.bind('5', lambda e: self.classify_image("none"))
  118. self.root.bind('6', lambda e: self.undo_operation())
  119. self.root.bind('r', lambda e: self.reset_image())
  120. def update_zoom(self, value):
  121. self.zoom_factor = float(value)
  122. self.show_current_image()
  123. def update_rotation(self, value):
  124. self.rotation_angle = float(value)
  125. self.show_current_image()
  126. def update_brightness(self, value):
  127. self.brightness_factor = float(value)
  128. self.show_current_image()
  129. def reset_image(self):
  130. self.zoom_factor = 1.0
  131. self.rotation_angle = 0
  132. self.brightness_factor = 1.0
  133. self.zoom_scale.set(1.0)
  134. self.rotation_scale.set(0)
  135. self.brightness_scale.set(1.0)
  136. self.show_current_image()
  137. def select_work_directory(self):
  138. """选择包含original、targets、annotated三个文件夹的工作目录"""
  139. folder_path = filedialog.askdirectory(title="选择包含original、targets、annotated文件夹的工作目录")
  140. if folder_path:
  141. # 检查是否包含必要的子文件夹
  142. original_path = os.path.join(folder_path, 'original')
  143. targets_path = os.path.join(folder_path, 'targets')
  144. annotated_path = os.path.join(folder_path, 'annotated')
  145. if not all(os.path.exists(path) for path in [original_path, targets_path, annotated_path]):
  146. messagebox.showerror("错误", "所选目录必须包含original、targets、annotated三个文件夹")
  147. return
  148. self.work_directory = folder_path
  149. self.original_dir = original_path
  150. self.targets_dir = targets_path
  151. self.annotated_dir = annotated_path
  152. self.load_images()
  153. def load_images(self):
  154. """加载original文件夹中的图片"""
  155. self.image_list = []
  156. if not self.original_dir:
  157. return
  158. for file in os.listdir(self.original_dir):
  159. if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
  160. self.image_list.append(os.path.join(self.original_dir, file))
  161. if self.image_list:
  162. self.current_index = 0
  163. self.show_current_image()
  164. self.status_label.config(text=f"已加载 {len(self.image_list)} 张图片")
  165. else:
  166. messagebox.showinfo("提示", "original文件夹中没有图片")
  167. def find_target_images(self, original_filename):
  168. """查找与原图对应的目标切片"""
  169. if not self.targets_dir:
  170. return []
  171. base_name = os.path.splitext(original_filename)[0]
  172. target_files = []
  173. for file in os.listdir(self.targets_dir):
  174. if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
  175. if file.startswith(base_name + '_'):
  176. target_files.append(os.path.join(self.targets_dir, file))
  177. return sorted(target_files)
  178. def find_annotated_image(self, original_filename):
  179. """查找与原图对应的标注图片(带检测框的图片)"""
  180. if not self.annotated_dir:
  181. return None
  182. base_name = os.path.splitext(original_filename)[0]
  183. # 尝试多种可能的文件扩展名
  184. for ext in ['.png', '.jpg', '.jpeg', '.bmp', '.gif']:
  185. annotated_path = os.path.join(self.annotated_dir, base_name + ext)
  186. if os.path.exists(annotated_path):
  187. return annotated_path
  188. return None
  189. def show_current_image(self):
  190. """显示当前图片和对应的目标切片"""
  191. if 0 <= self.current_index < len(self.image_list):
  192. self.current_original_path = self.image_list[self.current_index]
  193. # 显示标注图片(带检测框)
  194. self.display_original_image()
  195. # 显示目标切片
  196. original_filename = os.path.basename(self.current_original_path)
  197. self.current_targets = self.find_target_images(original_filename)
  198. self.display_target_images()
  199. # 更新状态信息
  200. self.status_label.config(text=f"图片 {self.current_index + 1}/{len(self.image_list)} | 检测到 {len(self.current_targets)} 个目标")
  201. # 显示图片信息
  202. if os.path.exists(self.current_original_path):
  203. file_size = os.path.getsize(self.current_original_path) / 1024
  204. image = Image.open(self.current_original_path)
  205. width, height = image.size
  206. self.info_label.config(text=f"文件名: {original_filename} | 尺寸: {width}x{height} | 大小: {file_size:.1f}KB")
  207. def display_original_image(self):
  208. """显示标注图片(带检测框),如果没有标注图片则显示原图"""
  209. if not self.current_original_path or not os.path.exists(self.current_original_path):
  210. self.original_label.config(image='', text="图片不存在")
  211. return
  212. try:
  213. # 首先尝试查找标注图片
  214. original_filename = os.path.basename(self.current_original_path)
  215. annotated_path = self.find_annotated_image(original_filename)
  216. # 如果有标注图片就显示标注图片,否则显示原图
  217. if annotated_path and os.path.exists(annotated_path):
  218. image = Image.open(annotated_path)
  219. display_text = "(显示标注图片)"
  220. else:
  221. image = Image.open(self.current_original_path)
  222. display_text = "(无标注图片,显示原图)"
  223. # 应用亮度调整
  224. enhancer = ImageEnhance.Brightness(image)
  225. image = enhancer.enhance(self.brightness_factor)
  226. # 应用旋转
  227. image = image.rotate(self.rotation_angle, expand=True)
  228. # 计算调整后的尺寸,保持宽高比
  229. width, height = image.size
  230. max_size = (600, 400)
  231. ratio = min(max_size[0]/width, max_size[1]/height)
  232. new_size = (int(width*ratio*self.zoom_factor), int(height*ratio*self.zoom_factor))
  233. image = image.resize(new_size, Image.Resampling.LANCZOS)
  234. # 转换为PhotoImage
  235. photo = ImageTk.PhotoImage(image)
  236. # 更新图片显示
  237. self.original_label.config(image=photo, text="")
  238. self.original_label.image = photo # 保持引用
  239. # 更新框架标题以显示当前显示的图片类型
  240. parent_frame = self.original_label.master
  241. if hasattr(parent_frame, 'config'):
  242. parent_frame.config(text=f"标注图片(带检测框){display_text}")
  243. except Exception as e:
  244. self.original_label.config(image='', text=f"图片加载失败: {str(e)}")
  245. def display_target_images(self):
  246. """显示目标切片"""
  247. # 清空之前的目标显示
  248. for widget in self.targets_scroll_frame.winfo_children():
  249. widget.destroy()
  250. if not self.current_targets:
  251. no_target_label = ttk.Label(self.targets_scroll_frame, text="未检测到目标")
  252. no_target_label.pack(side=tk.LEFT, padx=10)
  253. return
  254. # 显示每个目标切片
  255. for i, target_path in enumerate(self.current_targets):
  256. try:
  257. # 加载目标图片
  258. target_image = Image.open(target_path)
  259. # 调整大小以适应显示
  260. target_image.thumbnail((150, 150), Image.Resampling.LANCZOS)
  261. target_photo = ImageTk.PhotoImage(target_image)
  262. # 创建目标显示框架
  263. target_frame = ttk.Frame(self.targets_scroll_frame)
  264. target_frame.pack(side=tk.LEFT, padx=5, pady=5)
  265. # 显示目标图片
  266. target_label = ttk.Label(target_frame, image=target_photo)
  267. target_label.image = target_photo # 保持引用
  268. target_label.pack()
  269. # 显示目标文件名
  270. filename_label = ttk.Label(target_frame, text=os.path.basename(target_path),
  271. font=('Arial', 8))
  272. filename_label.pack()
  273. except Exception as e:
  274. error_label = ttk.Label(self.targets_scroll_frame, text=f"目标{i+1}加载失败")
  275. error_label.pack(side=tk.LEFT, padx=5)
  276. def next_image(self):
  277. if self.image_list:
  278. self.current_index = (self.current_index + 1) % len(self.image_list)
  279. self.show_current_image()
  280. def prev_image(self):
  281. if self.image_list:
  282. self.current_index = (self.current_index - 1) % len(self.image_list)
  283. self.show_current_image()
  284. def classify_image(self, category):
  285. """分类当前图片(仅移动原图)"""
  286. if not self.current_original_path:
  287. return
  288. # 创建分类文件夹
  289. category_folders = {
  290. "drone": "drone",
  291. "bird": "bird",
  292. "manned": "manned",
  293. "other": "others",
  294. "none": "none"
  295. }
  296. target_folder = os.path.join(self.work_directory, category_folders[category])
  297. os.makedirs(target_folder, exist_ok=True)
  298. # 只移动原图
  299. original_filename = os.path.basename(self.current_original_path)
  300. target_path = os.path.join(target_folder, original_filename)
  301. try:
  302. # 记录操作历史(在移动之前)
  303. operation_record = {
  304. 'action': 'classify',
  305. 'category': category,
  306. 'original_source': self.current_original_path,
  307. 'original_target': target_path,
  308. 'index': self.current_index,
  309. 'original_filename': original_filename
  310. }
  311. # 执行文件移动(仅移动原图)
  312. if os.path.exists(self.current_original_path):
  313. shutil.move(self.current_original_path, target_path)
  314. operation_record['moved'] = True
  315. else:
  316. operation_record['moved'] = False
  317. # 添加到操作历史
  318. self.operation_history.append(operation_record)
  319. # 限制历史记录数量
  320. if len(self.operation_history) > 50:
  321. self.operation_history.pop(0)
  322. # 从列表中移除当前图片
  323. self.image_list.pop(self.current_index)
  324. if self.image_list:
  325. self.current_index = self.current_index % len(self.image_list)
  326. self.show_current_image()
  327. else:
  328. self.original_label.config(image='', text="所有图片已分类完成")
  329. # 清空目标显示
  330. for widget in self.targets_scroll_frame.winfo_children():
  331. widget.destroy()
  332. self.status_label.config(text="所有图片已分类完成")
  333. self.info_label.config(text="")
  334. except Exception as e:
  335. messagebox.showerror("错误", f"分类操作失败:{str(e)}")
  336. def undo_operation(self):
  337. """撤回最后一次分类操作"""
  338. if not self.operation_history:
  339. messagebox.showinfo("提示", "没有可撤回的操作")
  340. return
  341. # 获取最后一次操作
  342. last_operation = self.operation_history.pop()
  343. if last_operation['action'] == 'classify' and last_operation.get('moved', False):
  344. try:
  345. # 将原图移回原位置
  346. if os.path.exists(last_operation['original_target']):
  347. shutil.move(last_operation['original_target'], last_operation['original_source'])
  348. # 重新加载图片列表
  349. self.load_images()
  350. # 尝试定位到撤回的图片
  351. original_path = last_operation['original_source']
  352. if original_path and original_path in self.image_list:
  353. self.current_index = self.image_list.index(original_path)
  354. else:
  355. self.current_index = 0
  356. if self.image_list:
  357. self.show_current_image()
  358. messagebox.showinfo("成功", f"已撤回对 '{last_operation['original_filename']}' 的分类操作")
  359. else:
  360. messagebox.showinfo("提示", "撤回成功,但当前目录中没有图片")
  361. except Exception as e:
  362. messagebox.showerror("错误", f"撤回操作失败:{str(e)}")
  363. # 如果撤回失败,将操作重新加入历史记录
  364. self.operation_history.append(last_operation)
  365. if __name__ == "__main__":
  366. root = tk.Tk()
  367. app = TargetBasedClassifier(root)
  368. root.mainloop()