postprocess.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. // Copyright (c) 2021 by Rockchip Electronics Co., Ltd. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "postprocess.h"
  15. #include <math.h>
  16. #include <stdint.h>
  17. #include <stdio.h>
  18. #include <stdlib.h>
  19. #include <string.h>
  20. #include <sys/time.h>
  21. #include <set>
  22. #include <vector>
  23. #define LABEL_NALE_TXT_PATH "./coco_80_labels_list.txt"
  24. static char *labels[OBJ_CLASS_NUM];
  25. const int anchor0[6] = {10, 13, 16, 30, 33, 23};
  26. const int anchor1[6] = {30, 61, 62, 45, 59, 119};
  27. const int anchor2[6] = {116, 90, 156, 198, 373, 326};
  28. inline static int clamp(float val, int min, int max) { return val > min ? (val < max ? val : max) : min; }
  29. char *readLine(FILE *fp, char *buffer, int *len)
  30. {
  31. int ch;
  32. int i = 0;
  33. size_t buff_len = 0;
  34. buffer = (char *)malloc(buff_len + 1);
  35. if (!buffer)
  36. return NULL; // Out of memory
  37. while ((ch = fgetc(fp)) != '\n' && ch != EOF)
  38. {
  39. buff_len++;
  40. void *tmp = realloc(buffer, buff_len + 1);
  41. if (tmp == NULL)
  42. {
  43. free(buffer);
  44. return NULL; // Out of memory
  45. }
  46. buffer = (char *)tmp;
  47. buffer[i] = (char)ch;
  48. i++;
  49. }
  50. buffer[i] = '\0';
  51. *len = buff_len;
  52. // Detect end
  53. if (ch == EOF && (i == 0 || ferror(fp)))
  54. {
  55. free(buffer);
  56. return NULL;
  57. }
  58. return buffer;
  59. }
  60. int readLines(const char *fileName, char *lines[], int max_line)
  61. {
  62. FILE *file = fopen(fileName, "r");
  63. char *s;
  64. int i = 0;
  65. int n = 0;
  66. if (file == NULL)
  67. {
  68. printf("Open %s fail!\n", fileName);
  69. return -1;
  70. }
  71. while ((s = readLine(file, s, &n)) != NULL)
  72. {
  73. lines[i++] = s;
  74. if (i >= max_line)
  75. break;
  76. }
  77. fclose(file);
  78. return i;
  79. }
  80. int loadLabelName(const char *locationFilename, char *label[])
  81. {
  82. printf("loadLabelName %s\n", locationFilename);
  83. readLines(locationFilename, label, OBJ_CLASS_NUM);
  84. return 0;
  85. }
  86. static float CalculateOverlap(float xmin0, float ymin0, float xmax0, float ymax0, float xmin1, float ymin1, float xmax1,
  87. float ymax1)
  88. {
  89. float w = fmax(0.f, fmin(xmax0, xmax1) - fmax(xmin0, xmin1) + 1.0);
  90. float h = fmax(0.f, fmin(ymax0, ymax1) - fmax(ymin0, ymin1) + 1.0);
  91. float i = w * h;
  92. float u = (xmax0 - xmin0 + 1.0) * (ymax0 - ymin0 + 1.0) + (xmax1 - xmin1 + 1.0) * (ymax1 - ymin1 + 1.0) - i;
  93. return u <= 0.f ? 0.f : (i / u);
  94. }
  95. static int nms(int validCount, std::vector<float> &outputLocations, std::vector<int> classIds, std::vector<int> &order,
  96. int filterId, float threshold)
  97. {
  98. for (int i = 0; i < validCount; ++i)
  99. {
  100. if (order[i] == -1 || classIds[i] != filterId)
  101. {
  102. continue;
  103. }
  104. int n = order[i];
  105. for (int j = i + 1; j < validCount; ++j)
  106. {
  107. int m = order[j];
  108. if (m == -1 || classIds[i] != filterId)
  109. {
  110. continue;
  111. }
  112. float xmin0 = outputLocations[n * 4 + 0];
  113. float ymin0 = outputLocations[n * 4 + 1];
  114. float xmax0 = outputLocations[n * 4 + 0] + outputLocations[n * 4 + 2];
  115. float ymax0 = outputLocations[n * 4 + 1] + outputLocations[n * 4 + 3];
  116. float xmin1 = outputLocations[m * 4 + 0];
  117. float ymin1 = outputLocations[m * 4 + 1];
  118. float xmax1 = outputLocations[m * 4 + 0] + outputLocations[m * 4 + 2];
  119. float ymax1 = outputLocations[m * 4 + 1] + outputLocations[m * 4 + 3];
  120. float iou = CalculateOverlap(xmin0, ymin0, xmax0, ymax0, xmin1, ymin1, xmax1, ymax1);
  121. if (iou > threshold)
  122. {
  123. order[j] = -1;
  124. }
  125. }
  126. }
  127. return 0;
  128. }
  129. static int quick_sort_indice_inverse(std::vector<float> &input, int left, int right, std::vector<int> &indices)
  130. {
  131. float key;
  132. int key_index;
  133. int low = left;
  134. int high = right;
  135. if (left < right)
  136. {
  137. key_index = indices[left];
  138. key = input[left];
  139. while (low < high)
  140. {
  141. while (low < high && input[high] <= key)
  142. {
  143. high--;
  144. }
  145. input[low] = input[high];
  146. indices[low] = indices[high];
  147. while (low < high && input[low] >= key)
  148. {
  149. low++;
  150. }
  151. input[high] = input[low];
  152. indices[high] = indices[low];
  153. }
  154. input[low] = key;
  155. indices[low] = key_index;
  156. quick_sort_indice_inverse(input, left, low - 1, indices);
  157. quick_sort_indice_inverse(input, low + 1, right, indices);
  158. }
  159. return low;
  160. }
  161. static float sigmoid(float x) { return 1.0 / (1.0 + expf(-x)); }
  162. static float unsigmoid(float y) { return -1.0 * logf((1.0 / y) - 1.0); }
  163. inline static int32_t __clip(float val, float min, float max)
  164. {
  165. float f = val <= min ? min : (val >= max ? max : val);
  166. return f;
  167. }
  168. static int8_t qnt_f32_to_affine(float f32, int32_t zp, float scale)
  169. {
  170. float dst_val = (f32 / scale) + zp;
  171. int8_t res = (int8_t)__clip(dst_val, -128, 127);
  172. return res;
  173. }
  174. static float deqnt_affine_to_f32(int8_t qnt, int32_t zp, float scale) { return ((float)qnt - (float)zp) * scale; }
  175. static int process(int8_t *input, int *anchor, int grid_h, int grid_w, int height, int width, int stride,
  176. std::vector<float> &boxes, std::vector<float> &objProbs, std::vector<int> &classId, float threshold,
  177. int32_t zp, float scale)
  178. {
  179. int validCount = 0;
  180. int grid_len = grid_h * grid_w;
  181. int8_t thres_i8 = qnt_f32_to_affine(threshold, zp, scale);
  182. for (int a = 0; a < 3; a++)
  183. {
  184. for (int i = 0; i < grid_h; i++)
  185. {
  186. for (int j = 0; j < grid_w; j++)
  187. {
  188. int8_t box_confidence = input[(PROP_BOX_SIZE * a + 4) * grid_len + i * grid_w + j];
  189. if (box_confidence >= thres_i8)
  190. {
  191. int offset = (PROP_BOX_SIZE * a) * grid_len + i * grid_w + j;
  192. int8_t *in_ptr = input + offset;
  193. float box_x = (deqnt_affine_to_f32(*in_ptr, zp, scale)) * 2.0 - 0.5;
  194. float box_y = (deqnt_affine_to_f32(in_ptr[grid_len], zp, scale)) * 2.0 - 0.5;
  195. float box_w = (deqnt_affine_to_f32(in_ptr[2 * grid_len], zp, scale)) * 2.0;
  196. float box_h = (deqnt_affine_to_f32(in_ptr[3 * grid_len], zp, scale)) * 2.0;
  197. box_x = (box_x + j) * (float)stride;
  198. box_y = (box_y + i) * (float)stride;
  199. box_w = box_w * box_w * (float)anchor[a * 2];
  200. box_h = box_h * box_h * (float)anchor[a * 2 + 1];
  201. box_x -= (box_w / 2.0);
  202. box_y -= (box_h / 2.0);
  203. int8_t maxClassProbs = in_ptr[5 * grid_len];
  204. int maxClassId = 0;
  205. for (int k = 1; k < OBJ_CLASS_NUM; ++k)
  206. {
  207. int8_t prob = in_ptr[(5 + k) * grid_len];
  208. if (prob > maxClassProbs)
  209. {
  210. maxClassId = k;
  211. maxClassProbs = prob;
  212. }
  213. }
  214. if (maxClassProbs > thres_i8)
  215. {
  216. objProbs.push_back((deqnt_affine_to_f32(maxClassProbs, zp, scale)) * (deqnt_affine_to_f32(box_confidence, zp, scale)));
  217. classId.push_back(maxClassId);
  218. validCount++;
  219. boxes.push_back(box_x);
  220. boxes.push_back(box_y);
  221. boxes.push_back(box_w);
  222. boxes.push_back(box_h);
  223. }
  224. }
  225. }
  226. }
  227. }
  228. return validCount;
  229. }
  230. int post_process(int8_t *input0, int8_t *input1, int8_t *input2, int model_in_h, int model_in_w, float conf_threshold,
  231. float nms_threshold, float scale_w, float scale_h, std::vector<int32_t> &qnt_zps,
  232. std::vector<float> &qnt_scales, detect_result_group_t *group)
  233. {
  234. static int init = -1;
  235. if (init == -1)
  236. {
  237. int ret = 0;
  238. ret = loadLabelName(LABEL_NALE_TXT_PATH, labels);
  239. if (ret < 0)
  240. {
  241. return -1;
  242. }
  243. init = 0;
  244. }
  245. memset(group, 0, sizeof(detect_result_group_t));
  246. std::vector<float> filterBoxes;
  247. std::vector<float> objProbs;
  248. std::vector<int> classId;
  249. // stride 8
  250. int stride0 = 8;
  251. int grid_h0 = model_in_h / stride0;
  252. int grid_w0 = model_in_w / stride0;
  253. int validCount0 = 0;
  254. validCount0 = process(input0, (int *)anchor0, grid_h0, grid_w0, model_in_h, model_in_w, stride0, filterBoxes, objProbs,
  255. classId, conf_threshold, qnt_zps[0], qnt_scales[0]);
  256. // stride 16
  257. int stride1 = 16;
  258. int grid_h1 = model_in_h / stride1;
  259. int grid_w1 = model_in_w / stride1;
  260. int validCount1 = 0;
  261. validCount1 = process(input1, (int *)anchor1, grid_h1, grid_w1, model_in_h, model_in_w, stride1, filterBoxes, objProbs,
  262. classId, conf_threshold, qnt_zps[1], qnt_scales[1]);
  263. // stride 32
  264. int stride2 = 32;
  265. int grid_h2 = model_in_h / stride2;
  266. int grid_w2 = model_in_w / stride2;
  267. int validCount2 = 0;
  268. validCount2 = process(input2, (int *)anchor2, grid_h2, grid_w2, model_in_h, model_in_w, stride2, filterBoxes, objProbs,
  269. classId, conf_threshold, qnt_zps[2], qnt_scales[2]);
  270. int validCount = validCount0 + validCount1 + validCount2;
  271. // no object detect
  272. if (validCount <= 0)
  273. {
  274. return 0;
  275. }
  276. std::vector<int> indexArray;
  277. for (int i = 0; i < validCount; ++i)
  278. {
  279. indexArray.push_back(i);
  280. }
  281. quick_sort_indice_inverse(objProbs, 0, validCount - 1, indexArray);
  282. std::set<int> class_set(std::begin(classId), std::end(classId));
  283. for (auto c : class_set)
  284. {
  285. nms(validCount, filterBoxes, classId, indexArray, c, nms_threshold);
  286. }
  287. int last_count = 0;
  288. group->count = 0;
  289. /* box valid detect target */
  290. for (int i = 0; i < validCount; ++i)
  291. {
  292. if (indexArray[i] == -1 || last_count >= OBJ_NUMB_MAX_SIZE)
  293. {
  294. continue;
  295. }
  296. int n = indexArray[i];
  297. float x1 = filterBoxes[n * 4 + 0];
  298. float y1 = filterBoxes[n * 4 + 1];
  299. float x2 = x1 + filterBoxes[n * 4 + 2];
  300. float y2 = y1 + filterBoxes[n * 4 + 3];
  301. int id = classId[n];
  302. float obj_conf = objProbs[i];
  303. group->results[last_count].box.left = (int)(clamp(x1, 0, model_in_w) / scale_w);
  304. group->results[last_count].box.top = (int)(clamp(y1, 0, model_in_h) / scale_h);
  305. group->results[last_count].box.right = (int)(clamp(x2, 0, model_in_w) / scale_w);
  306. group->results[last_count].box.bottom = (int)(clamp(y2, 0, model_in_h) / scale_h);
  307. group->results[last_count].prop = obj_conf;
  308. char *label = labels[id];
  309. strncpy(group->results[last_count].name, label, OBJ_NAME_MAX_SIZE);
  310. // printf("result %2d: (%4d, %4d, %4d, %4d), %s\n", i, group->results[last_count].box.left,
  311. // group->results[last_count].box.top,
  312. // group->results[last_count].box.right, group->results[last_count].box.bottom, label);
  313. last_count++;
  314. }
  315. group->count = last_count;
  316. return 0;
  317. }
  318. void deinitPostProcess()
  319. {
  320. for (int i = 0; i < OBJ_CLASS_NUM; i++)
  321. {
  322. if (labels[i] != nullptr)
  323. {
  324. free(labels[i]);
  325. labels[i] = nullptr;
  326. }
  327. }
  328. }