ocr_db_crnn.cc 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679
  1. // Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  #include <algorithm>
  #include <chrono>
  #include <cstdlib>
  #include <cstring>
  #include <iostream>
  #include <map>
  #include <memory>
  #include <string>
  #include <vector>

  #if defined(__ARM_NEON) || defined(__ARM_NEON__)
  #include <arm_neon.h>
  #endif

  #include "paddle_api.h" // NOLINT
  #include "paddle_place.h"

  #include "AutoLog/auto_log/lite_autolog.h"
  #include "cls_process.h"
  #include "crnn_process.h"
  #include "db_post_process.h"

  using namespace paddle::lite_api; // NOLINT
  using namespace std;
  // Normalize an interleaved 3-channel float image and repack it from HWC
  // (NHWC with N == 1) to planar CHW (NCHW):
  //
  //   dout[c * size + i] = (din[3 * i + c] - mean[c]) * scale[c],  c = 0, 1, 2
  //
  // din   : `size` pixels of 3 interleaved floats (3 * size values).
  // dout  : room for 3 * size floats; channel c is written to
  //         dout[c * size .. (c + 1) * size).
  // size  : number of pixels (rows * cols of the image).
  // mean  : per-channel value subtracted first; must hold exactly 3 values.
  // scale : per-channel multiplier applied after the subtraction; exactly 3.
  //
  // Prints an error and terminates the process with exit code 1 when mean or
  // scale does not have 3 elements. When NEON is available, 4 pixels are
  // handled per iteration (vld3q_f32 de-interleaves the three channels). The
  // remaining pixels, or all of them on targets without NEON, go through the
  // scalar loop, which evaluates the same formula. The scalar loop makes the
  // function build on non-ARM hosts as well.
  void NeonMeanScale(const float *din, float *dout, int size,
                     const std::vector<float> &mean,
                     const std::vector<float> &scale) {
    if (mean.size() != 3 || scale.size() != 3) {
      std::cerr << "[ERROR] mean or scale size must equal to 3" << std::endl;
      exit(1);
    }
    float *dout_c0 = dout;
    float *dout_c1 = dout + size;
    float *dout_c2 = dout + size * 2;
    int i = 0;
  #if defined(__ARM_NEON) || defined(__ARM_NEON__)
    const float32x4_t vmean0 = vdupq_n_f32(mean[0]);
    const float32x4_t vmean1 = vdupq_n_f32(mean[1]);
    const float32x4_t vmean2 = vdupq_n_f32(mean[2]);
    const float32x4_t vscale0 = vdupq_n_f32(scale[0]);
    const float32x4_t vscale1 = vdupq_n_f32(scale[1]);
    const float32x4_t vscale2 = vdupq_n_f32(scale[2]);
    for (; i < size - 3; i += 4) {
      // Loads 12 floats (4 pixels) and splits them into one vector per channel.
      float32x4x3_t vin3 = vld3q_f32(din);
      float32x4_t vs0 = vmulq_f32(vsubq_f32(vin3.val[0], vmean0), vscale0);
      float32x4_t vs1 = vmulq_f32(vsubq_f32(vin3.val[1], vmean1), vscale1);
      float32x4_t vs2 = vmulq_f32(vsubq_f32(vin3.val[2], vmean2), vscale2);
      vst1q_f32(dout_c0, vs0);
      vst1q_f32(dout_c1, vs1);
      vst1q_f32(dout_c2, vs2);
      din += 12;
      dout_c0 += 4;
      dout_c1 += 4;
      dout_c2 += 4;
    }
  #endif
    // Scalar tail (and the full computation on targets without NEON).
    for (; i < size; i++) {
      *(dout_c0++) = (*(din++) - mean[0]) * scale[0];
      *(dout_c1++) = (*(din++) - mean[1]) * scale[1];
      *(dout_c2++) = (*(din++) - mean[2]) * scale[2];
    }
  }
  63. // resize image to a size multiple of 32 which is required by the network
  64. cv::Mat DetResizeImg(const cv::Mat img, int max_size_len,
  65. std::vector<float> &ratio_hw) {
  66. int w = img.cols;
  67. int h = img.rows;
  68. float ratio = 1.f;
  69. int max_wh = w >= h ? w : h;
  70. if (max_wh > max_size_len) {
  71. if (h > w) {
  72. ratio = static_cast<float>(max_size_len) / static_cast<float>(h);
  73. } else {
  74. ratio = static_cast<float>(max_size_len) / static_cast<float>(w);
  75. }
  76. }
  77. int resize_h = static_cast<int>(float(h) * ratio);
  78. int resize_w = static_cast<int>(float(w) * ratio);
  79. if (resize_h % 32 == 0)
  80. resize_h = resize_h;
  81. else if (resize_h / 32 < 1 + 1e-5)
  82. resize_h = 32;
  83. else
  84. resize_h = (resize_h / 32 - 1) * 32;
  85. if (resize_w % 32 == 0)
  86. resize_w = resize_w;
  87. else if (resize_w / 32 < 1 + 1e-5)
  88. resize_w = 32;
  89. else
  90. resize_w = (resize_w / 32 - 1) * 32;
  91. cv::Mat resize_img;
  92. cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
  93. ratio_hw.push_back(static_cast<float>(resize_h) / static_cast<float>(h));
  94. ratio_hw.push_back(static_cast<float>(resize_w) / static_cast<float>(w));
  95. return resize_img;
  96. }
  97. cv::Mat RunClsModel(cv::Mat img, std::shared_ptr<PaddlePredictor> predictor_cls,
  98. const float thresh = 0.9) {
  99. std::vector<float> mean = {0.5f, 0.5f, 0.5f};
  100. std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
  101. cv::Mat srcimg;
  102. img.copyTo(srcimg);
  103. cv::Mat crop_img;
  104. img.copyTo(crop_img);
  105. cv::Mat resize_img;
  106. int index = 0;
  107. float wh_ratio =
  108. static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows);
  109. resize_img = ClsResizeImg(crop_img);
  110. resize_img.convertTo(resize_img, CV_32FC3, 1 / 255.f);
  111. const float *dimg = reinterpret_cast<const float *>(resize_img.data);
  112. std::unique_ptr<Tensor> input_tensor0(std::move(predictor_cls->GetInput(0)));
  113. input_tensor0->Resize({1, 3, resize_img.rows, resize_img.cols});
  114. auto *data0 = input_tensor0->mutable_data<float>();
  115. NeonMeanScale(dimg, data0, resize_img.rows * resize_img.cols, mean, scale);
  116. // Run CLS predictor
  117. predictor_cls->Run();
  118. // Get output and run postprocess
  119. std::unique_ptr<const Tensor> softmax_out(
  120. std::move(predictor_cls->GetOutput(0)));
  121. auto *softmax_scores = softmax_out->mutable_data<float>();
  122. auto softmax_out_shape = softmax_out->shape();
  123. float score = 0;
  124. int label = 0;
  125. for (int i = 0; i < softmax_out_shape[1]; i++) {
  126. if (softmax_scores[i] > score) {
  127. score = softmax_scores[i];
  128. label = i;
  129. }
  130. }
  131. if (label % 2 == 1 && score > thresh) {
  132. cv::rotate(srcimg, srcimg, 1);
  133. }
  134. return srcimg;
  135. }
  136. void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
  137. std::shared_ptr<PaddlePredictor> predictor_crnn,
  138. std::vector<std::string> &rec_text,
  139. std::vector<float> &rec_text_score,
  140. std::vector<std::string> charactor_dict,
  141. std::shared_ptr<PaddlePredictor> predictor_cls,
  142. int use_direction_classify, std::vector<double> *times,
  143. int rec_image_height) {
  144. std::vector<float> mean = {0.5f, 0.5f, 0.5f};
  145. std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
  146. cv::Mat srcimg;
  147. img.copyTo(srcimg);
  148. cv::Mat crop_img;
  149. cv::Mat resize_img;
  150. int index = 0;
  151. std::vector<double> time_info = {0, 0, 0};
  152. for (int i = boxes.size() - 1; i >= 0; i--) {
  153. auto preprocess_start = std::chrono::steady_clock::now();
  154. crop_img = GetRotateCropImage(srcimg, boxes[i]);
  155. if (use_direction_classify >= 1) {
  156. crop_img = RunClsModel(crop_img, predictor_cls);
  157. }
  158. float wh_ratio =
  159. static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows);
  160. resize_img = CrnnResizeImg(crop_img, wh_ratio, rec_image_height);
  161. resize_img.convertTo(resize_img, CV_32FC3, 1 / 255.f);
  162. const float *dimg = reinterpret_cast<const float *>(resize_img.data);
  163. std::unique_ptr<Tensor> input_tensor0(
  164. std::move(predictor_crnn->GetInput(0)));
  165. input_tensor0->Resize({1, 3, resize_img.rows, resize_img.cols});
  166. auto *data0 = input_tensor0->mutable_data<float>();
  167. NeonMeanScale(dimg, data0, resize_img.rows * resize_img.cols, mean, scale);
  168. auto preprocess_end = std::chrono::steady_clock::now();
  169. //// Run CRNN predictor
  170. auto inference_start = std::chrono::steady_clock::now();
  171. predictor_crnn->Run();
  172. // Get output and run postprocess
  173. std::unique_ptr<const Tensor> output_tensor0(
  174. std::move(predictor_crnn->GetOutput(0)));
  175. auto *predict_batch = output_tensor0->data<float>();
  176. auto predict_shape = output_tensor0->shape();
  177. auto inference_end = std::chrono::steady_clock::now();
  178. // ctc decode
  179. auto postprocess_start = std::chrono::steady_clock::now();
  180. std::string str_res;
  181. int argmax_idx;
  182. int last_index = 0;
  183. float score = 0.f;
  184. int count = 0;
  185. float max_value = 0.0f;
  186. for (int n = 0; n < predict_shape[1]; n++) {
  187. argmax_idx = int(Argmax(&predict_batch[n * predict_shape[2]],
  188. &predict_batch[(n + 1) * predict_shape[2]]));
  189. max_value =
  190. float(*std::max_element(&predict_batch[n * predict_shape[2]],
  191. &predict_batch[(n + 1) * predict_shape[2]]));
  192. if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
  193. score += max_value;
  194. count += 1;
  195. str_res += charactor_dict[argmax_idx];
  196. }
  197. last_index = argmax_idx;
  198. }
  199. score /= count;
  200. rec_text.push_back(str_res);
  201. rec_text_score.push_back(score);
  202. auto postprocess_end = std::chrono::steady_clock::now();
  203. std::chrono::duration<float> preprocess_diff =
  204. preprocess_end - preprocess_start;
  205. time_info[0] += double(preprocess_diff.count() * 1000);
  206. std::chrono::duration<float> inference_diff =
  207. inference_end - inference_start;
  208. time_info[1] += double(inference_diff.count() * 1000);
  209. std::chrono::duration<float> postprocess_diff =
  210. postprocess_end - postprocess_start;
  211. time_info[2] += double(postprocess_diff.count() * 1000);
  212. }
  213. times->push_back(time_info[0]);
  214. times->push_back(time_info[1]);
  215. times->push_back(time_info[2]);
  216. }
  217. std::vector<std::vector<std::vector<int>>>
  218. RunDetModel(std::shared_ptr<PaddlePredictor> predictor, cv::Mat img,
  219. std::map<std::string, double> Config, std::vector<double> *times) {
  220. // Read img
  221. int max_side_len = int(Config["max_side_len"]);
  222. int det_db_use_dilate = int(Config["det_db_use_dilate"]);
  223. cv::Mat srcimg;
  224. img.copyTo(srcimg);
  225. auto preprocess_start = std::chrono::steady_clock::now();
  226. std::vector<float> ratio_hw;
  227. img = DetResizeImg(img, max_side_len, ratio_hw);
  228. cv::Mat img_fp;
  229. img.convertTo(img_fp, CV_32FC3, 1.0 / 255.f);
  230. // Prepare input data from image
  231. std::unique_ptr<Tensor> input_tensor0(std::move(predictor->GetInput(0)));
  232. input_tensor0->Resize({1, 3, img_fp.rows, img_fp.cols});
  233. auto *data0 = input_tensor0->mutable_data<float>();
  234. std::vector<float> mean = {0.485f, 0.456f, 0.406f};
  235. std::vector<float> scale = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
  236. const float *dimg = reinterpret_cast<const float *>(img_fp.data);
  237. NeonMeanScale(dimg, data0, img_fp.rows * img_fp.cols, mean, scale);
  238. auto preprocess_end = std::chrono::steady_clock::now();
  239. // Run predictor
  240. auto inference_start = std::chrono::steady_clock::now();
  241. predictor->Run();
  242. // Get output and post process
  243. std::unique_ptr<const Tensor> output_tensor(
  244. std::move(predictor->GetOutput(0)));
  245. auto *outptr = output_tensor->data<float>();
  246. auto shape_out = output_tensor->shape();
  247. auto inference_end = std::chrono::steady_clock::now();
  248. // Save output
  249. auto postprocess_start = std::chrono::steady_clock::now();
  250. float pred[shape_out[2] * shape_out[3]];
  251. unsigned char cbuf[shape_out[2] * shape_out[3]];
  252. for (int i = 0; i < int(shape_out[2] * shape_out[3]); i++) {
  253. pred[i] = static_cast<float>(outptr[i]);
  254. cbuf[i] = static_cast<unsigned char>((outptr[i]) * 255);
  255. }
  256. cv::Mat cbuf_map(shape_out[2], shape_out[3], CV_8UC1,
  257. reinterpret_cast<unsigned char *>(cbuf));
  258. cv::Mat pred_map(shape_out[2], shape_out[3], CV_32F,
  259. reinterpret_cast<float *>(pred));
  260. const double threshold = double(Config["det_db_thresh"]) * 255;
  261. const double max_value = 255;
  262. cv::Mat bit_map;
  263. cv::threshold(cbuf_map, bit_map, threshold, max_value, cv::THRESH_BINARY);
  264. if (det_db_use_dilate == 1) {
  265. cv::Mat dilation_map;
  266. cv::Mat dila_ele =
  267. cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
  268. cv::dilate(bit_map, dilation_map, dila_ele);
  269. bit_map = dilation_map;
  270. }
  271. auto boxes = BoxesFromBitmap(pred_map, bit_map, Config);
  272. std::vector<std::vector<std::vector<int>>> filter_boxes =
  273. FilterTagDetRes(boxes, ratio_hw[0], ratio_hw[1], srcimg);
  274. auto postprocess_end = std::chrono::steady_clock::now();
  275. std::chrono::duration<float> preprocess_diff =
  276. preprocess_end - preprocess_start;
  277. times->push_back(double(preprocess_diff.count() * 1000));
  278. std::chrono::duration<float> inference_diff = inference_end - inference_start;
  279. times->push_back(double(inference_diff.count() * 1000));
  280. std::chrono::duration<float> postprocess_diff =
  281. postprocess_end - postprocess_start;
  282. times->push_back(double(postprocess_diff.count() * 1000));
  283. return filter_boxes;
  284. }
  285. std::shared_ptr<PaddlePredictor> loadModel(std::string model_file,
  286. int num_threads) {
  287. MobileConfig config;
  288. config.set_model_from_file(model_file);
  289. config.set_threads(num_threads);
  290. std::shared_ptr<PaddlePredictor> predictor =
  291. CreatePaddlePredictor<MobileConfig>(config);
  292. return predictor;
  293. }
  294. cv::Mat Visualization(cv::Mat srcimg,
  295. std::vector<std::vector<std::vector<int>>> boxes) {
  296. cv::Point rook_points[boxes.size()][4];
  297. for (int n = 0; n < boxes.size(); n++) {
  298. for (int m = 0; m < boxes[0].size(); m++) {
  299. rook_points[n][m] = cv::Point(static_cast<int>(boxes[n][m][0]),
  300. static_cast<int>(boxes[n][m][1]));
  301. }
  302. }
  303. cv::Mat img_vis;
  304. srcimg.copyTo(img_vis);
  305. for (int n = 0; n < boxes.size(); n++) {
  306. const cv::Point *ppt[1] = {rook_points[n]};
  307. int npt[] = {4};
  308. cv::polylines(img_vis, ppt, npt, 1, 1, CV_RGB(0, 255, 0), 2, 8, 0);
  309. }
  310. cv::imwrite("./vis.jpg", img_vis);
  311. std::cout << "The detection visualized image saved in ./vis.jpg" << std::endl;
  312. return img_vis;
  313. }
  // Split `str` into tokens separated by ANY character contained in `delim`,
  // matching std::strtok semantics: runs of delimiters count as one
  // separator, and leading or trailing delimiters produce no empty tokens.
  //
  // An empty `str` yields an empty vector. An empty `delim` yields the whole
  // (non-empty) string as one token. No heap buffers are managed by hand, and
  // nothing touches strtok's hidden global state.
  std::vector<std::string> split(const std::string &str,
                                 const std::string &delim) {
    std::vector<std::string> res;
    std::string::size_type start = str.find_first_not_of(delim);
    while (start != std::string::npos) {
      const std::string::size_type end = str.find_first_of(delim, start);
      if (end == std::string::npos) {
        res.push_back(str.substr(start));
        break;
      }
      res.push_back(str.substr(start, end - start));
      start = str.find_first_not_of(delim, end);
    }
    return res;
  }
  331. std::map<std::string, double> LoadConfigTxt(std::string config_path) {
  332. auto config = ReadDict(config_path);
  333. std::map<std::string, double> dict;
  334. for (int i = 0; i < config.size(); i++) {
  335. std::vector<std::string> res = split(config[i], " ");
  336. dict[res[0]] = stod(res[1]);
  337. }
  338. return dict;
  339. }
  // Validate the command line before any mode reads its positional arguments.
  //
  // Required layouts (argv[0] is the program, argv[1] the mode):
  //   det    det_model runtime_device precision num_threads batchsize img_dir
  //          det_config lite_benchmark_value                       argc >= 10
  //   rec    rec_model runtime_device precision num_threads batchsize img_dir
  //          key_txt config_txt lite_benchmark_value               argc >= 11
  //   system det_model rec_model clas_model runtime_device precision
  //          num_threads batchsize img_dir det_config key_txt
  //          lite_benchmark_value                                  argc >= 13
  //
  // The mode handlers read up to argv[9], argv[10] and argv[12] respectively,
  // so a smaller argc would make them read past the end of argv. On an unknown
  // mode or too few arguments, prints a message to stderr and exits with
  // status 1.
  void check_params(int argc, char **argv) {
    if (argc <= 1 ||
        (strcmp(argv[1], "det") != 0 && strcmp(argv[1], "rec") != 0 &&
         strcmp(argv[1], "system") != 0)) {
      std::cerr << "Please choose one mode of [det, rec, system] !" << std::endl;
      exit(1);
    }
    if (strcmp(argv[1], "det") == 0 && argc < 10) {
      std::cerr << "[ERROR] usage:" << argv[0]
                << " det det_model runtime_device precision num_threads "
                   "batchsize img_dir det_config lite_benchmark_value"
                << std::endl;
      exit(1);
    }
    if (strcmp(argv[1], "rec") == 0 && argc < 11) {
      std::cerr << "[ERROR] usage:" << argv[0]
                << " rec rec_model runtime_device precision num_threads "
                   "batchsize img_dir key_txt config_txt lite_benchmark_value"
                << std::endl;
      exit(1);
    }
    if (strcmp(argv[1], "system") == 0 && argc < 13) {
      std::cerr << "[ERROR] usage:" << argv[0]
                << " system det_model rec_model clas_model runtime_device "
                   "precision num_threads batchsize img_dir det_config key_txt "
                   "lite_benchmark_value"
                << std::endl;
      exit(1);
    }
  }
  376. void system(char **argv) {
  377. std::string det_model_file = argv[2];
  378. std::string rec_model_file = argv[3];
  379. std::string cls_model_file = argv[4];
  380. std::string runtime_device = argv[5];
  381. std::string precision = argv[6];
  382. std::string num_threads = argv[7];
  383. std::string batchsize = argv[8];
  384. std::string img_dir = argv[9];
  385. std::string det_config_path = argv[10];
  386. std::string dict_path = argv[11];
  387. if (strcmp(argv[6], "FP32") != 0 && strcmp(argv[6], "INT8") != 0) {
  388. std::cerr << "Only support FP32 or INT8." << std::endl;
  389. exit(1);
  390. }
  391. std::vector<cv::String> cv_all_img_names;
  392. cv::glob(img_dir, cv_all_img_names);
  393. //// load config from txt file
  394. auto Config = LoadConfigTxt(det_config_path);
  395. int use_direction_classify = int(Config["use_direction_classify"]);
  396. int rec_image_height = int(Config["rec_image_height"]);
  397. auto charactor_dict = ReadDict(dict_path);
  398. charactor_dict.insert(charactor_dict.begin(), "#"); // blank char for ctc
  399. charactor_dict.push_back(" ");
  400. auto det_predictor = loadModel(det_model_file, std::stoi(num_threads));
  401. auto rec_predictor = loadModel(rec_model_file, std::stoi(num_threads));
  402. auto cls_predictor = loadModel(cls_model_file, std::stoi(num_threads));
  403. std::vector<double> det_time_info = {0, 0, 0};
  404. std::vector<double> rec_time_info = {0, 0, 0};
  405. for (int i = 0; i < cv_all_img_names.size(); ++i) {
  406. std::cout << "The predict img: " << cv_all_img_names[i] << std::endl;
  407. cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
  408. if (!srcimg.data) {
  409. std::cerr << "[ERROR] image read failed! image path: "
  410. << cv_all_img_names[i] << std::endl;
  411. exit(1);
  412. }
  413. std::vector<double> det_times;
  414. auto boxes = RunDetModel(det_predictor, srcimg, Config, &det_times);
  415. std::vector<std::string> rec_text;
  416. std::vector<float> rec_text_score;
  417. std::vector<double> rec_times;
  418. RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
  419. charactor_dict, cls_predictor, use_direction_classify,
  420. &rec_times, rec_image_height);
  421. //// visualization
  422. auto img_vis = Visualization(srcimg, boxes);
  423. //// print recognized text
  424. for (int i = 0; i < rec_text.size(); i++) {
  425. std::cout << i << "\t" << rec_text[i] << "\t" << rec_text_score[i]
  426. << std::endl;
  427. }
  428. det_time_info[0] += det_times[0];
  429. det_time_info[1] += det_times[1];
  430. det_time_info[2] += det_times[2];
  431. rec_time_info[0] += rec_times[0];
  432. rec_time_info[1] += rec_times[1];
  433. rec_time_info[2] += rec_times[2];
  434. }
  435. if (strcmp(argv[12], "True") == 0) {
  436. AutoLogger autolog_det(det_model_file, runtime_device,
  437. std::stoi(num_threads), std::stoi(batchsize),
  438. "dynamic", precision, det_time_info,
  439. cv_all_img_names.size());
  440. AutoLogger autolog_rec(rec_model_file, runtime_device,
  441. std::stoi(num_threads), std::stoi(batchsize),
  442. "dynamic", precision, rec_time_info,
  443. cv_all_img_names.size());
  444. autolog_det.report();
  445. std::cout << std::endl;
  446. autolog_rec.report();
  447. }
  448. }
  449. void det(int argc, char **argv) {
  450. std::string det_model_file = argv[2];
  451. std::string runtime_device = argv[3];
  452. std::string precision = argv[4];
  453. std::string num_threads = argv[5];
  454. std::string batchsize = argv[6];
  455. std::string img_dir = argv[7];
  456. std::string det_config_path = argv[8];
  457. if (strcmp(argv[4], "FP32") != 0 && strcmp(argv[4], "INT8") != 0) {
  458. std::cerr << "Only support FP32 or INT8." << std::endl;
  459. exit(1);
  460. }
  461. std::vector<cv::String> cv_all_img_names;
  462. cv::glob(img_dir, cv_all_img_names);
  463. //// load config from txt file
  464. auto Config = LoadConfigTxt(det_config_path);
  465. auto det_predictor = loadModel(det_model_file, std::stoi(num_threads));
  466. std::vector<double> time_info = {0, 0, 0};
  467. for (int i = 0; i < cv_all_img_names.size(); ++i) {
  468. std::cout << "The predict img: " << cv_all_img_names[i] << std::endl;
  469. cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
  470. if (!srcimg.data) {
  471. std::cerr << "[ERROR] image read failed! image path: "
  472. << cv_all_img_names[i] << std::endl;
  473. exit(1);
  474. }
  475. std::vector<double> times;
  476. auto boxes = RunDetModel(det_predictor, srcimg, Config, &times);
  477. //// visualization
  478. auto img_vis = Visualization(srcimg, boxes);
  479. std::cout << boxes.size() << " bboxes have detected:" << std::endl;
  480. for (int i = 0; i < boxes.size(); i++) {
  481. std::cout << "The " << i << " box:" << std::endl;
  482. for (int j = 0; j < 4; j++) {
  483. for (int k = 0; k < 2; k++) {
  484. std::cout << boxes[i][j][k] << "\t";
  485. }
  486. }
  487. std::cout << std::endl;
  488. }
  489. time_info[0] += times[0];
  490. time_info[1] += times[1];
  491. time_info[2] += times[2];
  492. }
  493. if (strcmp(argv[9], "True") == 0) {
  494. AutoLogger autolog(det_model_file, runtime_device, std::stoi(num_threads),
  495. std::stoi(batchsize), "dynamic", precision, time_info,
  496. cv_all_img_names.size());
  497. autolog.report();
  498. }
  499. }
  500. void rec(int argc, char **argv) {
  501. std::string rec_model_file = argv[2];
  502. std::string runtime_device = argv[3];
  503. std::string precision = argv[4];
  504. std::string num_threads = argv[5];
  505. std::string batchsize = argv[6];
  506. std::string img_dir = argv[7];
  507. std::string dict_path = argv[8];
  508. std::string config_path = argv[9];
  509. if (strcmp(argv[4], "FP32") != 0 && strcmp(argv[4], "INT8") != 0) {
  510. std::cerr << "Only support FP32 or INT8." << std::endl;
  511. exit(1);
  512. }
  513. auto Config = LoadConfigTxt(config_path);
  514. int rec_image_height = int(Config["rec_image_height"]);
  515. std::vector<cv::String> cv_all_img_names;
  516. cv::glob(img_dir, cv_all_img_names);
  517. auto charactor_dict = ReadDict(dict_path);
  518. charactor_dict.insert(charactor_dict.begin(), "#"); // blank char for ctc
  519. charactor_dict.push_back(" ");
  520. auto rec_predictor = loadModel(rec_model_file, std::stoi(num_threads));
  521. std::shared_ptr<PaddlePredictor> cls_predictor;
  522. std::vector<double> time_info = {0, 0, 0};
  523. for (int i = 0; i < cv_all_img_names.size(); ++i) {
  524. std::cout << "The predict img: " << cv_all_img_names[i] << std::endl;
  525. cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
  526. if (!srcimg.data) {
  527. std::cerr << "[ERROR] image read failed! image path: "
  528. << cv_all_img_names[i] << std::endl;
  529. exit(1);
  530. }
  531. int width = srcimg.cols;
  532. int height = srcimg.rows;
  533. std::vector<int> upper_left = {0, 0};
  534. std::vector<int> upper_right = {width, 0};
  535. std::vector<int> lower_right = {width, height};
  536. std::vector<int> lower_left = {0, height};
  537. std::vector<std::vector<int>> box = {upper_left, upper_right, lower_right,
  538. lower_left};
  539. std::vector<std::vector<std::vector<int>>> boxes = {box};
  540. std::vector<std::string> rec_text;
  541. std::vector<float> rec_text_score;
  542. std::vector<double> times;
  543. RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
  544. charactor_dict, cls_predictor, 0, &times, rec_image_height);
  545. //// print recognized text
  546. for (int i = 0; i < rec_text.size(); i++) {
  547. std::cout << i << "\t" << rec_text[i] << "\t" << rec_text_score[i]
  548. << std::endl;
  549. }
  550. time_info[0] += times[0];
  551. time_info[1] += times[1];
  552. time_info[2] += times[2];
  553. }
  554. // TODO: support autolog
  555. if (strcmp(argv[9], "True") == 0) {
  556. AutoLogger autolog(rec_model_file, runtime_device, std::stoi(num_threads),
  557. std::stoi(batchsize), "dynamic", precision, time_info,
  558. cv_all_img_names.size());
  559. autolog.report();
  560. }
  561. }
  562. int main(int argc, char **argv) {
  563. check_params(argc, argv);
  564. std::cout << "mode: " << argv[1] << endl;
  565. if (strcmp(argv[1], "system") == 0) {
  566. system(argv);
  567. }
  568. if (strcmp(argv[1], "det") == 0) {
  569. det(argc, argv);
  570. }
  571. if (strcmp(argv[1], "rec") == 0) {
  572. rec(argc, argv);
  573. }
  574. return 0;
  575. }