feature_extractor.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import cv2
  2. import os
  3. import pickle
  4. from glob import glob
  5. class FalsePositiveFeatureExtractor:
  6. def __init__(self, feature_type='ORB', max_features=1000):
  7. self.feature_type = feature_type
  8. self.max_features = max_features
  9. # 调整ORB参数
  10. self.detector = cv2.ORB_create(
  11. nfeatures=max_features,
  12. scaleFactor=1.2, # 金字塔缩放因子
  13. edgeThreshold=15, # 边缘阈值
  14. patchSize=31 # 特征点邻域大小
  15. )
  16. def extract_features(self, img_path):
  17. # 检查文件是否存在
  18. if not os.path.exists(img_path):
  19. print(f"警告:文件不存在 {img_path}")
  20. return None
  21. # 读取图片并检查有效性
  22. img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
  23. if img is None:
  24. print(f"警告:无法读取图像 {img_path}")
  25. return None
  26. # 检查图片尺寸和内容
  27. if img.size == 0:
  28. print(f"警告:空图像 {img_path}")
  29. return None
  30. # 计算图像清晰度(拉普拉斯方差)
  31. blur_value = cv2.Laplacian(img, cv2.CV_64F).var()
  32. if blur_value < 50: # 阈值可根据实际情况调整
  33. print(f"警告:图像模糊 {img_path} (清晰度: {blur_value:.2f})")
  34. return None
  35. # 提取特征
  36. kp, des = self.detector.detectAndCompute(img, None)
  37. # 检查特征数量和质量
  38. if des is None or len(des) < 10:
  39. print(f"警告:特征不足 {img_path} (特征数: {len(kp) if kp else 0})")
  40. return None
  41. return des
  42. def build_feature_db(self, input_dir, output_file='false_positive_features.pkl'):
  43. all_features = {}
  44. for img_file in glob(os.path.join(input_dir, '*.jpg')) + glob(os.path.join(input_dir, '*.png')):
  45. features = self.extract_features(img_file)
  46. if features is not None and len(features) > 10: # 过滤无效特征
  47. all_features[os.path.basename(img_file)] = {
  48. 'features': features,
  49. 'image_size': cv2.imread(img_file).shape[:2]
  50. }
  51. with open(output_file, 'wb') as f:
  52. pickle.dump(all_features, f)
  53. return len(all_features)
  54. if __name__ == '__main__':
  55. import argparse
  56. parser = argparse.ArgumentParser()
  57. parser.add_argument('--input', type=str, default='false_positive',
  58. help='误报图像目录路径')
  59. parser.add_argument('--output', type=str, default='false_positive_features.pkl',
  60. help='输出特征文件路径')
  61. args = parser.parse_args()
  62. extractor = FalsePositiveFeatureExtractor()
  63. count = extractor.build_feature_db(args.input, args.output)
  64. print(f'成功提取{count}个误报样本的特征')