# UniMERNet.yaml

  # Global: run-wide settings. The anchors declared here (&rec_char_dict_path,
  # &input_size, &max_seq_len) are re-used via *aliases by PostProcess and by
  # the Train/Eval transform lists.
  Global:
    model_name: UniMERNet  # To use static model for inference.
    use_gpu: True
    epoch_num: 40
    log_smooth_window: 10
    print_batch_step: 10
    save_model_dir: ./output/rec/unimernet/
    save_epoch_step: 5
    # Evaluation is run every 37880 iterations, starting after iteration 0.
    eval_batch_step: [0, 37880]
    cal_metric_during_train: True
    # The next three keys are empty and therefore load as null.
    pretrained_model:
    checkpoints:
    save_inference_dir:
    use_visualdl: False
    infer_img: doc/datasets/pme_demo/0000013.png
    infer_mode: False
    use_space_char: False
    rec_char_dict_path: &rec_char_dict_path ppocr/utils/dict/unimernet_tokenizer
    # presumably [height, width] of the model input — TODO confirm
    input_size: &input_size [192, 672]
    max_seq_len: &max_seq_len 1024
    save_res_path: ./output/rec/predicts_unimernet.txt
    allow_resize_largeImg: False
    # NOTE(review): looks like [channels, height, width]; the leading 1 differs
    # from Architecture.in_channels (3) — confirm the expected export shape.
    d2s_train_image_shape: [1, 192, 672]
  # Optimizer: AdamW with a LinearWarmupCosine learning-rate schedule.
  Optimizer:
    name: AdamW
    beta1: 0.9
    beta2: 0.999
    weight_decay: 0.05
    lr:
      name: LinearWarmupCosine
      # Written as 1.0e-N (with a decimal point) on purpose: YAML 1.1 loaders
      # such as PyYAML only resolve floats that contain a '.', so a bare
      # `1e-4` would be loaded as the string '1e-4' instead of a number.
      learning_rate: 1.0e-4
      start_lr: 1.0e-5
      min_lr: 1.0e-8
      warmup_steps: 5000
  # Architecture: recognition model built from a DonutSwin backbone and a
  # UniMERNet decoding head. `Transform` is empty (null).
  Architecture:
    model_type: rec
    algorithm: UniMERNet
    in_channels: 3
    Transform:
    Backbone:
      name: DonutSwinModel
      hidden_size: 1024
      num_layers: 4
      # four entries, presumably one per layer (num_layers: 4) — confirm
      num_heads: [4, 8, 16, 32]
      add_pooling_layer: True
      use_mask_token: False
    Head:
      name: UniMERNetHead
      # NOTE(review): max_new_tokens (1536) is larger than Global.max_seq_len
      # (1024) used by UniMERNetLabelEncode — confirm this is intended.
      max_new_tokens: 1536
      decoder_start_token_id: 0
      # presumably only consulted when do_sample is True — confirm in
      # UniMERNetHead
      temperature: 0.2
      do_sample: False
      top_p: 0.95
      # same value as Backbone.hidden_size above
      encoder_hidden_size: 1024
      is_export: False
      length_aware: True
  # Loss: training objective for the UniMERNet head.
  Loss:
    name: UniMERNetLoss
  # PostProcess: decodes model output using the same tokenizer path as
  # Global.rec_char_dict_path (shared through the anchor alias).
  PostProcess:
    name: UniMERNetDecode
    rec_char_dict_path: *rec_char_dict_path
  # Metric: LaTeX OCR evaluation, with BLEU scoring enabled.
  Metric:
    name: LaTeXOCRMetric
    main_indicator: exp_rate
    cal_bleu_score: True
  # Train: training dataset, per-sample transform pipeline and data loader.
  Train:
    dataset:
      name: SimpleDataSet
      data_dir: ./train_data/UniMERNet/
      label_file_list: ["./train_data/UniMERNet/train_unimernet_1M.txt"]
      # Each list item maps one transform name to its parameters; items with
      # nothing after the colon carry a null parameter map.
      transforms:
        - UniMERNetImgDecode:
            input_size: *input_size
        - UniMERNetTrainTransform:
        - UniMERNetImageFormat:
        - UniMERNetLabelEncode:
            rec_char_dict_path: *rec_char_dict_path
            max_seq_len: *max_seq_len
        - KeepKeys:
            keep_keys: ['image', 'label', 'attention_mask']
    loader:
      # NOTE(review): training data is not shuffled — confirm this is intended.
      shuffle: False
      drop_last: False
      batch_size_per_card: 7
      num_workers: 0
      collate_fn: UniMERNetCollator
  # Eval: evaluation dataset (UniMER-Test "cpe" split), its transform pipeline
  # and data loader.
  Eval:
    dataset:
      name: SimpleDataSet
      data_dir: ./train_data/UniMERNet/UniMER-Test/cpe
      label_file_list: ["./train_data/UniMERNet/test_unimernet_cpe.txt"]
      # Same shape as Train.transforms, but with UniMERNetTestTransform in
      # place of UniMERNetTrainTransform.
      transforms:
        - UniMERNetImgDecode:
            input_size: *input_size
        - UniMERNetTestTransform:
        - UniMERNetImageFormat:
        - UniMERNetLabelEncode:
            max_seq_len: *max_seq_len
            rec_char_dict_path: *rec_char_dict_path
        - KeepKeys:
            keep_keys: ['image', 'label', 'attention_mask']
    loader:
      shuffle: False
      drop_last: False
      batch_size_per_card: 30
      num_workers: 0
      collate_fn: UniMERNetCollator