// Vision.h
  1. #pragma once
  // Export/import decoration applied to the public OrtToolkit classes and functions.
  //
  // When ORTTOOLKIT is already defined at the point this header is reached (presumably by the
  // library's own build - confirm against the project settings) the symbols are exported;
  // otherwise they are imported. The incoming definition is removed with #undef first, because
  // redefining a macro with a different replacement list is ill-formed (MSVC warning C4005).
  // Toolchains without __declspec fall back to default symbol visibility or to nothing.
  #if defined(_WIN32) || defined(__CYGWIN__)
  #  ifdef ORTTOOLKIT
  #    undef ORTTOOLKIT
  #    define ORTTOOLKIT __declspec(dllexport)
  #  else
  #    define ORTTOOLKIT __declspec(dllimport)
  #  endif
  #else
  #  ifdef ORTTOOLKIT
  #    undef ORTTOOLKIT
  #  endif
  #  if defined(__GNUC__) || defined(__clang__)
  #    define ORTTOOLKIT __attribute__((visibility("default")))
  #  else
  #    define ORTTOOLKIT
  #  endif
  #endif
  #include <array>
  #include <cstddef>
  #include <cstdint>
  #include <fstream>
  #include <string>
  #include <vector>
  #include <opencv/cv.h>
  #include <onnxruntime/onnxruntime_cxx_api.h>
  #include "Type.h"
  12. namespace OrtToolkit {
  13. namespace Vision {
  14. namespace Detect {
  15. /// <summary>
  16. /// 飞桨PP_YOLOE转ONNX模型BatchSize只能为1
  17. /// </summary>
  18. class ORTTOOLKIT PP_YOLOE
  19. {
  20. public:
  21. /// <summary>
  22. /// PP_YOLOE模型初始化
  23. /// </summary>
  24. /// <param name="model_path">模型路径</param>
  25. /// <param name="confThreshold">置信度阈值</param>
  26. /// <param name="ExecutionProvider">ONNX概念下的 执行提供程序 (EP),参数为字符串,默认为CPU,其他:CUDA</param>
  27. PP_YOLOE(std::string model_path, float confThreshold, std::string ExecutionProvider = "CUDA");
  28. /// <summary>
  29. /// 目标检测
  30. /// </summary>
  31. /// <param name="cv_image">输入OpenCV的BGR图像</param>
  32. void Predict(cv::Mat& cv_image, Type::Res_PP_YOLOE& res);
  33. private:
  34. float confThreshold;//置信度阈值
  35. cv::Mat preprocess(cv::Mat srcimg);
  36. const int inpWidth = 640;
  37. const int inpHeight = 640;
  38. std::vector<float> mean_ = { 0.485, 0.456, 0.406 };
  39. std::vector<float> std_ = { 0.229, 0.224, 0.225 };
  40. std::vector<float> input_image_;
  41. std::vector<float> scale_factor = { 1,1 };
  42. Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "PP_YOLOE");
  43. Ort::Session* ort_session = nullptr;
  44. Ort::SessionOptions sessionOptions = Ort::SessionOptions();
  45. std::vector<char*> input_names;
  46. std::vector<char*> output_names;
  47. std::vector<std::vector<int64_t>> input_node_dims; // 输入节点维度
  48. std::vector<std::vector<int64_t>> output_node_dims; // 输出节点维度
  49. };
  50. }
  51. namespace FaceDetect {
  52. class ORTTOOLKIT YOLOV7_face
  53. {
  54. public:
  55. YOLOV7_face(float confThreshold, float nmsThreshold, std::string modelpath, std::string ExecutionProvider = "CUDA");
  56. void Predict(cv::Mat& frame, Type::Res_YOLOV7_FACE& generate_boxes);
  57. private:
  58. int inpWidth;
  59. int inpHeight;
  60. int nout;
  61. int num_proposal;
  62. float confThreshold;//置信度阈值
  63. float nmsThreshold;//nms阈值
  64. std::vector<float> mean_ = { 0.5,0.5,0.5 };
  65. std::vector<float> std_ = { 0.5,0.5,0.5 };
  66. std::vector<float> input_image_;//模型输入图片
  67. void nms(std::vector<Type::Type_Iou_YOLOV7FACE>Items, Type::Res_YOLOV7_FACE& Res);//NMS
  68. bool has_postprocess;
  69. std::vector<Ort::AllocatedStringPtr>In_AllocatedStringPtr;
  70. std::vector<Ort::AllocatedStringPtr>Out_AllocatedStringPtr;
  71. Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "YOLOV7_face");
  72. Ort::Session* ort_session = nullptr;//OnnxRunTime会话
  73. Ort::SessionOptions sessionOptions = Ort::SessionOptions();
  74. std::vector<char*> input_names;
  75. std::vector<char*> output_names;
  76. std::vector<std::vector<int64_t>> input_node_dims; // >=1 outputs
  77. std::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs
  78. };
  79. }
  80. namespace FaceAlign {
  81. class ORTTOOLKIT PFLD98
  82. {
  83. public:
  84. PFLD98(std::string ModPath);
  85. void Predict(cv::Mat& Img, std::vector<cv::Point2f>& Density);
  86. ~PFLD98();
  87. private:
  88. Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "MCNN");
  89. Ort::SessionOptions sessionOptions = Ort::SessionOptions();
  90. Ort::AllocatorWithDefaultOptions allocator;
  91. const char* input_name;
  92. const char* output_nameA;
  93. const char* output_nameB;
  94. std::vector<int64_t> input_dims;
  95. std::vector<int64_t> output_dims;
  96. Ort::Session* ort_session;
  97. std::vector<const char*> input_names;
  98. std::vector<const char*> output_names;
  99. std::vector<const char*> input_node_names;
  100. std::vector<const char*> output_node_names;
  101. private:
  102. int DensityRow;
  103. int DensityCol;
  104. std::vector<float> input_image_;
  105. };
  106. }
  107. namespace FaceId {
  108. class ORTTOOLKIT Arcface {
  109. public:
  110. Arcface(std::string modelpath, std::string ExecutionProvider = "CUDA");
  111. void Predict(cv::Mat& frame, std::vector<float>& FaceVector);
  112. private:
  113. int inpWidth;
  114. int inpHeight;
  115. int nout;
  116. std::vector<float> input_image_;//模型输入图片
  117. std::vector<float> mean_ = { 0.5, 0.5, 0.5 };
  118. std::vector<float> std_ = { 0.5, 0.5, 0.5 };
  119. std::vector<Ort::AllocatedStringPtr>In_AllocatedStringPtr;
  120. std::vector<Ort::AllocatedStringPtr>Out_AllocatedStringPtr;
  121. Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "Arcface");//日志记录
  122. Ort::Session* ort_session = nullptr;//OnnxRunTime会话
  123. Ort::SessionOptions sessionOptions = Ort::SessionOptions();
  124. std::vector<char*> input_names;
  125. std::vector<char*> output_names;
  126. std::vector<std::vector<int64_t>> input_node_dims; // >=1 outputs
  127. std::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs
  128. };
  129. }
  130. namespace CrowdCounting {
  131. class ORTTOOLKIT MCNN
  132. {
  133. public:
  134. MCNN(std::string ModPath);
  135. void Predict(cv::Mat& Img, std::vector<float>& Density);
  136. ~MCNN();
  137. private:
  138. Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "MCNN");
  139. Ort::SessionOptions sessionOptions = Ort::SessionOptions();
  140. Ort::AllocatorWithDefaultOptions allocator;
  141. const char* input_name;
  142. const char* output_name;
  143. std::vector<int64_t> input_dims;
  144. std::vector<int64_t> output_dims;
  145. Ort::Session* ort_session;
  146. std::vector<const char*> input_names;
  147. std::vector<const char*> output_names;
  148. std::vector<const char*> input_node_names;
  149. std::vector<const char*> output_node_names;
  150. std::vector<float> input_image_;
  151. private:
  152. int DensityRow;
  153. int DensityCol;
  154. };
  155. }
  156. }
  157. namespace Util{
  158. void ORTTOOLKIT overflow(float& info, int lower, int upper);
  159. void ORTTOOLKIT BoxOverflow(Type::Base_Type_Box& Box, int lower, int upper);
  160. float ORTTOOLKIT CosineSimilarity(const std::vector<float>& a, const std::vector<float>& b, bool normalized = false);
  161. std::string ORTTOOLKIT create_uuid();
  162. uint16_t floatToHalf(float f);
  163. float halfToFloat(uint16_t h);
  /// Loads the raw bytes of a binary file into a contiguous container.
  ///
  /// The container is resized to the number of whole T elements contained in the file and
  /// those elements are read directly into x.data(). Trailing bytes that do not form a
  /// complete T are ignored. If the file cannot be opened, or its size cannot be determined,
  /// x is left untouched.
  ///
  /// @tparam T  Element type used to convert the file size in bytes into an element count.
  /// @tparam U  Container type providing resize(), size() and data() (e.g. std::vector<T>).
  /// @param x     Destination container; resized and overwritten on success.
  /// @param path  Path of the binary file to read.
  template <typename T, typename U>
  void ReadBin(U& x, const std::string& path) {
    // Open positioned at the end so tellg() yields the file size.
    std::ifstream file(path, std::ios::binary | std::ios::ate);
    if (!file) {
      return;
    }

    const std::streamoff byte_count = file.tellg();
    if (byte_count < 0) {
      return;  // tellg() reports failure as -1; do not turn that into a huge resize
    }
    file.seekg(0, std::ios::beg);

    const std::size_t element_count = static_cast<std::size_t>(byte_count) / sizeof(T);
    x.resize(element_count);

    // Read whole elements only, and never more bytes than the container actually holds.
    // Reading the raw file size would overrun the buffer whenever the size is not a
    // multiple of sizeof(T).
    const std::size_t wanted_bytes = element_count * sizeof(T);
    const std::size_t buffer_bytes = x.size() * sizeof(*x.data());
    const std::size_t bytes_to_read = wanted_bytes < buffer_bytes ? wanted_bytes : buffer_bytes;
    if (bytes_to_read > 0) {
      file.read(reinterpret_cast<char*>(x.data()), static_cast<std::streamsize>(bytes_to_read));
    }
    // The stream closes itself when it goes out of scope.
  }
  176. }
  177. }