cli.cc 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cstdlib>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "src/api/models/doc_img_orientation_classification.h"
#include "src/api/models/text_detection.h"
#include "src/api/models/text_image_unwarping.h"
#include "src/api/models/text_recognition.h"
#include "src/api/models/textline_orientation_classification.h"
#include "src/api/pipelines/doc_preprocessor.h"
#include "src/api/pipelines/ocr.h"
#include "src/utils/args.h"
  29. static const std::unordered_set<std::string> SUPPORT_MODE_PIPELINE = {
  30. "ocr",
  31. "doc_preprocessor",
  32. };
  33. static const std::unordered_set<std::string> SUPPORT_MODE_MODEL = {
  34. "text_image_unwarping", "doc_img_orientation_classification",
  35. "textline_orientation_classification", "text_detection",
  36. "text_recognition"};
  37. void PrintErrorInfo(const std::string &msg, const std::string &main_mode = "") {
  38. auto join_modes =
  39. [](const std::unordered_set<std::string> &modes) -> std::string {
  40. std::string result;
  41. for (const auto &mode : modes) {
  42. result += mode + ", ";
  43. }
  44. if (!result.empty()) {
  45. result.pop_back();
  46. result.pop_back();
  47. }
  48. return result;
  49. };
  50. std::string pipeline_modes = join_modes(SUPPORT_MODE_PIPELINE);
  51. std::string model_modes = join_modes(SUPPORT_MODE_MODEL);
  52. INFOE("%s%s", msg.c_str(),
  53. main_mode.empty() ? "" : (": \"" + main_mode + "\"").c_str());
  54. INFO("==========================================");
  55. INFO("Supported pipeline : [%s]", pipeline_modes.c_str());
  56. INFO("Supported model : [%s]", model_modes.c_str());
  57. INFO("==========================================");
  58. }
  59. std::tuple<PaddleOCRParams, DocPreprocessorParams,
  60. DocImgOrientationClassificationParams, TextImageUnwarpingParams,
  61. TextDetectionParams, TextLineOrientationClassificationParams,
  62. TextRecognitionParams>
  63. GetPipelineMoudleParams() {
  64. PaddleOCRParams ocr_params;
  65. DocPreprocessorParams doc_pre_params;
  66. DocImgOrientationClassificationParams doc_orient_params;
  67. TextImageUnwarpingParams unwarp_params;
  68. TextDetectionParams det_params;
  69. TextLineOrientationClassificationParams teline_orient_params;
  70. TextRecognitionParams rec_params;
  71. if (!FLAGS_doc_orientation_classify_model_name.empty()) {
  72. ocr_params.doc_orientation_classify_model_name =
  73. FLAGS_doc_orientation_classify_model_name;
  74. doc_pre_params.doc_orientation_classify_model_name =
  75. FLAGS_doc_orientation_classify_model_name;
  76. doc_orient_params.model_name = FLAGS_doc_orientation_classify_model_name;
  77. }
  78. if (!FLAGS_doc_orientation_classify_model_dir.empty()) {
  79. ocr_params.doc_orientation_classify_model_dir =
  80. FLAGS_doc_orientation_classify_model_dir;
  81. doc_pre_params.doc_orientation_classify_model_dir =
  82. FLAGS_doc_orientation_classify_model_dir;
  83. doc_orient_params.model_dir = FLAGS_doc_orientation_classify_model_dir;
  84. }
  85. if (!FLAGS_doc_unwarping_model_name.empty()) {
  86. ocr_params.doc_unwarping_model_name = FLAGS_doc_unwarping_model_name;
  87. doc_pre_params.doc_unwarping_model_name = FLAGS_doc_unwarping_model_name;
  88. unwarp_params.model_name = FLAGS_doc_unwarping_model_name;
  89. }
  90. if (!FLAGS_doc_unwarping_model_dir.empty()) {
  91. ocr_params.doc_unwarping_model_dir = FLAGS_doc_unwarping_model_dir;
  92. doc_pre_params.doc_unwarping_model_dir = FLAGS_doc_unwarping_model_dir;
  93. unwarp_params.model_dir = FLAGS_doc_unwarping_model_dir;
  94. }
  95. if (!FLAGS_text_detection_model_name.empty()) {
  96. ocr_params.text_detection_model_name = FLAGS_text_detection_model_name;
  97. det_params.model_name = FLAGS_text_detection_model_name;
  98. }
  99. if (!FLAGS_text_detection_model_dir.empty()) {
  100. ocr_params.text_detection_model_dir = FLAGS_text_detection_model_dir;
  101. det_params.model_dir = FLAGS_text_detection_model_dir;
  102. }
  103. if (!FLAGS_textline_orientation_model_name.empty()) {
  104. ocr_params.textline_orientation_model_name =
  105. FLAGS_textline_orientation_model_name;
  106. teline_orient_params.model_name = FLAGS_textline_orientation_model_name;
  107. }
  108. if (!FLAGS_textline_orientation_model_dir.empty()) {
  109. ocr_params.textline_orientation_model_dir =
  110. FLAGS_textline_orientation_model_dir;
  111. teline_orient_params.model_dir = FLAGS_textline_orientation_model_dir;
  112. }
  113. if (!FLAGS_textline_orientation_batch_size.empty()) {
  114. ocr_params.textline_orientation_batch_size =
  115. std::stoi(FLAGS_textline_orientation_batch_size);
  116. }
  117. if (!FLAGS_text_recognition_model_name.empty()) {
  118. ocr_params.text_recognition_model_name = FLAGS_text_recognition_model_name;
  119. rec_params.model_name = FLAGS_text_recognition_model_name;
  120. }
  121. if (!FLAGS_text_recognition_model_dir.empty()) {
  122. ocr_params.text_recognition_model_dir = FLAGS_text_recognition_model_dir;
  123. rec_params.model_dir = FLAGS_text_recognition_model_dir;
  124. }
  125. if (!FLAGS_text_recognition_batch_size.empty()) {
  126. ocr_params.text_recognition_batch_size =
  127. std::stoi(FLAGS_text_recognition_batch_size);
  128. rec_params.batch_size = std::stoi(FLAGS_text_recognition_batch_size);
  129. rec_params.input_shape =
  130. YamlConfig::SmartParseVector(FLAGS_text_rec_input_shape).vec_int;
  131. }
  132. if (!FLAGS_use_doc_orientation_classify.empty()) {
  133. ocr_params.use_doc_orientation_classify =
  134. Utility::StringToBool(FLAGS_use_doc_orientation_classify);
  135. doc_pre_params.use_doc_orientation_classify =
  136. Utility::StringToBool(FLAGS_use_doc_orientation_classify);
  137. }
  138. if (!FLAGS_use_doc_unwarping.empty()) {
  139. ocr_params.use_doc_unwarping =
  140. Utility::StringToBool(FLAGS_use_doc_unwarping);
  141. doc_pre_params.use_doc_unwarping =
  142. Utility::StringToBool(FLAGS_use_doc_unwarping);
  143. }
  144. if (!FLAGS_use_textline_orientation.empty()) {
  145. ocr_params.use_textline_orientation =
  146. Utility::StringToBool(FLAGS_use_textline_orientation);
  147. }
  148. if (!FLAGS_text_det_limit_side_len.empty()) {
  149. ocr_params.text_det_limit_side_len =
  150. std::stoi(FLAGS_text_det_limit_side_len);
  151. }
  152. if (!FLAGS_text_det_limit_type.empty()) {
  153. ocr_params.text_det_limit_type = FLAGS_text_det_limit_type;
  154. det_params.limit_type = FLAGS_text_det_limit_type;
  155. }
  156. if (!FLAGS_text_det_thresh.empty()) {
  157. ocr_params.text_det_thresh = std::stof(FLAGS_text_det_thresh);
  158. det_params.thresh = std::stof(FLAGS_text_det_thresh);
  159. }
  160. if (!FLAGS_text_det_box_thresh.empty()) {
  161. ocr_params.text_det_box_thresh = std::stof(FLAGS_text_det_box_thresh);
  162. det_params.box_thresh = std::stof(FLAGS_text_det_box_thresh);
  163. }
  164. if (!FLAGS_text_det_unclip_ratio.empty()) {
  165. ocr_params.text_det_unclip_ratio = std::stof(FLAGS_text_det_unclip_ratio);
  166. det_params.unclip_ratio = std::stof(FLAGS_text_det_unclip_ratio);
  167. }
  168. if (!FLAGS_text_det_input_shape.empty()) {
  169. ocr_params.text_det_input_shape =
  170. YamlConfig::SmartParseVector(FLAGS_text_det_input_shape).vec_int;
  171. det_params.input_shape =
  172. YamlConfig::SmartParseVector(FLAGS_text_det_input_shape).vec_int;
  173. }
  174. if (!FLAGS_text_rec_score_thresh.empty()) {
  175. ocr_params.text_rec_score_thresh = std::stof(FLAGS_text_rec_score_thresh);
  176. }
  177. if (!FLAGS_text_rec_input_shape.empty()) {
  178. ocr_params.text_rec_input_shape =
  179. YamlConfig::SmartParseVector(FLAGS_text_rec_input_shape).vec_int;
  180. }
  181. if (!FLAGS_lang.empty()) {
  182. ocr_params.lang = FLAGS_lang;
  183. }
  184. if (!FLAGS_ocr_version.empty()) {
  185. ocr_params.ocr_version = FLAGS_ocr_version;
  186. }
  187. if (!FLAGS_vis_font_dir.empty()) {
  188. ocr_params.vis_font_dir = FLAGS_vis_font_dir;
  189. rec_params.vis_font_dir = FLAGS_vis_font_dir;
  190. }
  191. if (!FLAGS_device.empty()) {
  192. ocr_params.device = FLAGS_device;
  193. doc_pre_params.device = FLAGS_device;
  194. doc_orient_params.device = FLAGS_device;
  195. unwarp_params.device = FLAGS_device;
  196. teline_orient_params.device = FLAGS_device;
  197. det_params.device = FLAGS_device;
  198. rec_params.device = FLAGS_device;
  199. }
  200. if (!FLAGS_precision.empty()) {
  201. ocr_params.precision = FLAGS_precision;
  202. doc_pre_params.precision = FLAGS_precision;
  203. doc_orient_params.precision = FLAGS_precision;
  204. unwarp_params.precision = FLAGS_precision;
  205. teline_orient_params.precision = FLAGS_precision;
  206. det_params.precision = FLAGS_precision;
  207. rec_params.precision = FLAGS_precision;
  208. }
  209. if (!FLAGS_enable_mkldnn.empty()) {
  210. ocr_params.enable_mkldnn = Utility::StringToBool(FLAGS_enable_mkldnn);
  211. doc_pre_params.enable_mkldnn = Utility::StringToBool(FLAGS_enable_mkldnn);
  212. doc_orient_params.enable_mkldnn =
  213. Utility::StringToBool(FLAGS_enable_mkldnn);
  214. unwarp_params.enable_mkldnn = Utility::StringToBool(FLAGS_enable_mkldnn);
  215. teline_orient_params.enable_mkldnn =
  216. Utility::StringToBool(FLAGS_enable_mkldnn);
  217. det_params.enable_mkldnn = Utility::StringToBool(FLAGS_enable_mkldnn);
  218. rec_params.enable_mkldnn = Utility::StringToBool(FLAGS_enable_mkldnn);
  219. }
  220. if (!FLAGS_mkldnn_cache_capacity.empty()) {
  221. ocr_params.mkldnn_cache_capacity = std::stoi(FLAGS_mkldnn_cache_capacity);
  222. doc_pre_params.mkldnn_cache_capacity =
  223. std::stoi(FLAGS_mkldnn_cache_capacity);
  224. doc_orient_params.mkldnn_cache_capacity =
  225. std::stoi(FLAGS_mkldnn_cache_capacity);
  226. unwarp_params.mkldnn_cache_capacity =
  227. std::stoi(FLAGS_mkldnn_cache_capacity);
  228. teline_orient_params.mkldnn_cache_capacity =
  229. std::stoi(FLAGS_mkldnn_cache_capacity);
  230. det_params.mkldnn_cache_capacity = std::stoi(FLAGS_mkldnn_cache_capacity);
  231. rec_params.mkldnn_cache_capacity = std::stoi(FLAGS_mkldnn_cache_capacity);
  232. }
  233. if (!FLAGS_cpu_threads.empty()) {
  234. ocr_params.cpu_threads = std::stoi(FLAGS_cpu_threads);
  235. doc_pre_params.cpu_threads = std::stoi(FLAGS_cpu_threads);
  236. doc_orient_params.cpu_threads = std::stoi(FLAGS_cpu_threads);
  237. unwarp_params.cpu_threads = std::stoi(FLAGS_cpu_threads);
  238. teline_orient_params.cpu_threads = std::stoi(FLAGS_cpu_threads);
  239. det_params.cpu_threads = std::stoi(FLAGS_cpu_threads);
  240. rec_params.cpu_threads = std::stoi(FLAGS_cpu_threads);
  241. }
  242. if (!FLAGS_thread_num.empty()) {
  243. ocr_params.thread_num = std::stoi(FLAGS_thread_num);
  244. doc_pre_params.thread_num = std::stoi(FLAGS_thread_num);
  245. }
  246. if (!FLAGS_paddlex_config.empty()) {
  247. ocr_params.paddlex_config = FLAGS_paddlex_config;
  248. doc_pre_params.paddlex_config = FLAGS_paddlex_config;
  249. }
  250. return std::make_tuple(ocr_params, doc_pre_params, doc_orient_params,
  251. unwarp_params, det_params, teline_orient_params,
  252. rec_params);
  253. }
  254. int main(int argc, char *argv[]) {
  255. gflags::ParseCommandLineFlags(&argc, &argv, true);
  256. if (FLAGS_input.empty()) {
  257. INFOE("Require input, such as ./build/ppocr <pipeline_or_module> --input "
  258. "your_image_path [--param1] [--param2] [...]");
  259. exit(-1);
  260. }
  261. std::string main_mode = "";
  262. if (argc > 1) {
  263. main_mode = argv[1];
  264. if (SUPPORT_MODE_PIPELINE.count(main_mode) == 0 &&
  265. SUPPORT_MODE_MODEL.count(main_mode) == 0) {
  266. PrintErrorInfo("ERROR: Unsupported pipeline or module", main_mode);
  267. exit(-1);
  268. }
  269. } else {
  270. PrintErrorInfo(
  271. "Must provide pipeline or module name, such as ./build/ppocr "
  272. "<pipeline_or_module> [--param1] [--param2] [...]");
  273. exit(-1);
  274. }
  275. auto params = GetPipelineMoudleParams();
  276. using PredFunc = std::function<std::vector<std::unique_ptr<BaseCVResult>>(
  277. const std::string &)>;
  278. std::unordered_map<std::string, PredFunc> pred_map = {
  279. {"ocr",
  280. [&params](const std::string &input) {
  281. return PaddleOCR(std::get<0>(params)).Predict(input);
  282. }},
  283. {"doc_preprocessor",
  284. [&params](const std::string &input) {
  285. return DocPreprocessor(std::get<1>(params)).Predict(input);
  286. }},
  287. {"doc_img_orientation_classification",
  288. [&params](const std::string &input) {
  289. return DocImgOrientationClassification(std::get<2>(params))
  290. .Predict(input);
  291. }},
  292. {"text_image_unwarping",
  293. [&params](const std::string &input) {
  294. return TextImageUnwarping(std::get<3>(params)).Predict(input);
  295. }},
  296. {"text_detection",
  297. [&params](const std::string &input) {
  298. return TextDetection(std::get<4>(params)).Predict(input);
  299. }},
  300. {"textline_orientation_classification",
  301. [&params](const std::string &input) {
  302. return TextLineOrientationClassification(std::get<5>(params))
  303. .Predict(input);
  304. }},
  305. {"text_recognition",
  306. [&params](const std::string &input) {
  307. return TextRecognition(std::get<6>(params)).Predict(input);
  308. }},
  309. };
  310. auto it = pred_map.find(main_mode);
  311. auto outputs = it->second(FLAGS_input);
  312. for (auto &output : outputs) {
  313. output->Print();
  314. output->SaveToImg(FLAGS_save_path);
  315. output->SaveToJson(FLAGS_save_path);
  316. }
  317. return 0;
  318. }