# test_formula_model.py: output-shape checks for the formula-recognition backbones and decoder heads.
  1. import sys
  2. import os
  3. from pathlib import Path
  4. from typing import Any
  5. import paddle
  6. import pytest
  7. current_dir = os.path.dirname(os.path.abspath(__file__))
  8. sys.path.append(os.path.abspath(os.path.join(current_dir, "..")))
  9. from ppocr.modeling.backbones.rec_donut_swin import DonutSwinModel, DonutSwinModelOutput
  10. from ppocr.modeling.backbones.rec_pphgnetv2 import PPHGNetV2_B4_Formula
  11. from ppocr.modeling.backbones.rec_vary_vit import Vary_VIT_B_Formula
  12. from ppocr.modeling.heads.rec_unimernet_head import UniMERNetHead
  13. from ppocr.modeling.heads.rec_ppformulanet_head import PPFormulaNet_Head
  14. @pytest.fixture
  15. def sample_image():
  16. return paddle.randn([1, 1, 192, 672])
  17. @pytest.fixture
  18. def sample_image_ppformulanet_s():
  19. return paddle.randn([1, 1, 384, 384])
  20. @pytest.fixture
  21. def sample_image_ppformulanet_l():
  22. return paddle.randn([1, 1, 768, 768])
  23. @pytest.fixture
  24. def encoder_feat():
  25. encoded_feat = paddle.randn([1, 126, 1024])
  26. return DonutSwinModelOutput(
  27. last_hidden_state=encoded_feat,
  28. )
  29. @pytest.fixture
  30. def encoder_feat_ppformulanet_s():
  31. encoded_feat = paddle.randn([1, 144, 2048])
  32. return DonutSwinModelOutput(
  33. last_hidden_state=encoded_feat,
  34. )
  35. @pytest.fixture
  36. def encoder_feat_ppformulanet_l():
  37. encoded_feat = paddle.randn([1, 144, 1024])
  38. return DonutSwinModelOutput(
  39. last_hidden_state=encoded_feat,
  40. )
  41. def test_unimernet_backbone(sample_image):
  42. """
  43. Test UniMERNet backbone.
  44. Args:
  45. sample_image: sample image to be processed.
  46. """
  47. backbone = DonutSwinModel(
  48. hidden_size=1024,
  49. num_layers=4,
  50. num_heads=[4, 8, 16, 32],
  51. add_pooling_layer=True,
  52. use_mask_token=False,
  53. )
  54. backbone.eval()
  55. with paddle.no_grad():
  56. result = backbone(sample_image)
  57. encoder_feat = result[0]
  58. assert encoder_feat.shape == [1, 126, 1024]
  59. def test_unimernet_head(encoder_feat):
  60. """
  61. Test UniMERNet head.
  62. Args:
  63. encoder_feat: encoder feature from unimernet backbone.
  64. """
  65. head = UniMERNetHead(
  66. max_new_tokens=5,
  67. decoder_start_token_id=0,
  68. temperature=0.2,
  69. do_sample=False,
  70. top_p=0.95,
  71. encoder_hidden_size=1024,
  72. is_export=False,
  73. length_aware=True,
  74. )
  75. head.eval()
  76. with paddle.no_grad():
  77. result = head(encoder_feat)
  78. assert result.shape == [1, 6]
  79. def test_ppformulanet_s_backbone(sample_image_ppformulanet_s):
  80. """
  81. Test PP-FormulaNet-S backbone.
  82. Args:
  83. sample_image_ppformulanet_s: sample image to be processed.
  84. """
  85. backbone = PPHGNetV2_B4_Formula(
  86. class_num=1024,
  87. )
  88. backbone.eval()
  89. with paddle.no_grad():
  90. result = backbone(sample_image_ppformulanet_s)
  91. encoder_feat = result[0]
  92. assert encoder_feat.shape == [1, 144, 2048]
  93. def test_ppformulanet_s_head(encoder_feat_ppformulanet_s):
  94. """
  95. Test PP-FormulaNet-S head.
  96. Args:
  97. encoder_feat_ppformulanet_s: encoder feature from PP-FormulaNet-S backbone.
  98. """
  99. head = PPFormulaNet_Head(
  100. max_new_tokens=6,
  101. decoder_start_token_id=0,
  102. decoder_ffn_dim=1536,
  103. decoder_hidden_size=384,
  104. decoder_layers=2,
  105. temperature=0.2,
  106. do_sample=False,
  107. top_p=0.95,
  108. encoder_hidden_size=2048,
  109. is_export=False,
  110. length_aware=True,
  111. use_parallel=True,
  112. parallel_step=3,
  113. )
  114. head.eval()
  115. with paddle.no_grad():
  116. result = head(encoder_feat_ppformulanet_s)
  117. assert result.shape == [1, 9]
  118. def test_ppformulanet_l_backbone(sample_image_ppformulanet_l):
  119. """
  120. Test PP-FormulaNet-L backbone.
  121. Args:
  122. sample_image_ppformulanet_l: sample image to be processed.
  123. """
  124. backbone = Vary_VIT_B_Formula(
  125. image_size=768,
  126. encoder_embed_dim=768,
  127. encoder_depth=12,
  128. encoder_num_heads=12,
  129. encoder_global_attn_indexes=[2, 5, 8, 11],
  130. )
  131. backbone.eval()
  132. with paddle.no_grad():
  133. result = backbone(sample_image_ppformulanet_l)
  134. encoder_feat = result[0]
  135. assert encoder_feat.shape == [1, 144, 1024]
  136. def test_ppformulanet_l_head(encoder_feat_ppformulanet_l):
  137. """
  138. Test PP-FormulaNet-L head.
  139. Args:
  140. encoder_feat_ppformulanet_l: encoder feature from PP-FormulaNet-L Head.
  141. """
  142. head = PPFormulaNet_Head(
  143. max_new_tokens=6,
  144. decoder_start_token_id=0,
  145. decoder_ffn_dim=2048,
  146. decoder_hidden_size=512,
  147. decoder_layers=8,
  148. temperature=0.2,
  149. do_sample=False,
  150. top_p=0.95,
  151. encoder_hidden_size=1024,
  152. is_export=False,
  153. length_aware=False,
  154. use_parallel=False,
  155. parallel_step=0,
  156. )
  157. head.eval()
  158. with paddle.no_grad():
  159. result = head(encoder_feat_ppformulanet_l)
  160. assert result.shape == [1, 7]