vqa_layoutlm.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. from paddle import nn
  19. from paddlenlp.transformers import (
  20. LayoutXLMModel,
  21. LayoutXLMForTokenClassification,
  22. LayoutXLMForRelationExtraction,
  23. )
  24. from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
  25. from paddlenlp.transformers import (
  26. LayoutLMv2Model,
  27. LayoutLMv2ForTokenClassification,
  28. LayoutLMv2ForRelationExtraction,
  29. )
  30. from paddlenlp.transformers import AutoModel
  31. __all__ = ["LayoutXLMForSer", "LayoutLMForSer"]
  32. pretrained_model_dict = {
  33. LayoutXLMModel: {
  34. "base": "layoutxlm-base-uncased",
  35. "vi": "vi-layoutxlm-base-uncased",
  36. },
  37. LayoutLMModel: {
  38. "base": "layoutlm-base-uncased",
  39. },
  40. LayoutLMv2Model: {
  41. "base": "layoutlmv2-base-uncased",
  42. "vi": "vi-layoutlmv2-base-uncased",
  43. },
  44. }
  45. class NLPBaseModel(nn.Layer):
  46. def __init__(
  47. self,
  48. base_model_class,
  49. model_class,
  50. mode="base",
  51. type="ser",
  52. pretrained=True,
  53. checkpoints=None,
  54. **kwargs,
  55. ):
  56. super(NLPBaseModel, self).__init__()
  57. if checkpoints is not None: # load the trained model
  58. self.model = model_class.from_pretrained(checkpoints)
  59. else: # load the pretrained-model
  60. pretrained_model_name = pretrained_model_dict[base_model_class][mode]
  61. if type == "ser":
  62. self.model = model_class.from_pretrained(
  63. pretrained_model_name, num_classes=kwargs["num_classes"], dropout=0
  64. )
  65. else:
  66. self.model = model_class.from_pretrained(
  67. pretrained_model_name, dropout=0
  68. )
  69. self.out_channels = 1
  70. self.use_visual_backbone = True
  71. class LayoutLMForSer(NLPBaseModel):
  72. def __init__(
  73. self, num_classes, pretrained=True, checkpoints=None, mode="base", **kwargs
  74. ):
  75. super(LayoutLMForSer, self).__init__(
  76. LayoutLMModel,
  77. LayoutLMForTokenClassification,
  78. mode,
  79. "ser",
  80. pretrained,
  81. checkpoints,
  82. num_classes=num_classes,
  83. )
  84. self.use_visual_backbone = False
  85. def forward(self, x):
  86. x = self.model(
  87. input_ids=x[0],
  88. bbox=x[1],
  89. attention_mask=x[2],
  90. token_type_ids=x[3],
  91. position_ids=None,
  92. output_hidden_states=False,
  93. )
  94. return x
  95. class LayoutLMv2ForSer(NLPBaseModel):
  96. def __init__(
  97. self, num_classes, pretrained=True, checkpoints=None, mode="base", **kwargs
  98. ):
  99. super(LayoutLMv2ForSer, self).__init__(
  100. LayoutLMv2Model,
  101. LayoutLMv2ForTokenClassification,
  102. mode,
  103. "ser",
  104. pretrained,
  105. checkpoints,
  106. num_classes=num_classes,
  107. )
  108. if (
  109. hasattr(self.model.layoutlmv2, "use_visual_backbone")
  110. and self.model.layoutlmv2.use_visual_backbone is False
  111. ):
  112. self.use_visual_backbone = False
  113. def forward(self, x):
  114. if self.use_visual_backbone is True:
  115. image = x[4]
  116. else:
  117. image = None
  118. x = self.model(
  119. input_ids=x[0],
  120. bbox=x[1],
  121. attention_mask=x[2],
  122. token_type_ids=x[3],
  123. image=image,
  124. position_ids=None,
  125. head_mask=None,
  126. labels=None,
  127. )
  128. if self.training:
  129. res = {"backbone_out": x[0]}
  130. res.update(x[1])
  131. return res
  132. else:
  133. return x
  134. class LayoutXLMForSer(NLPBaseModel):
  135. def __init__(
  136. self, num_classes, pretrained=True, checkpoints=None, mode="base", **kwargs
  137. ):
  138. super(LayoutXLMForSer, self).__init__(
  139. LayoutXLMModel,
  140. LayoutXLMForTokenClassification,
  141. mode,
  142. "ser",
  143. pretrained,
  144. checkpoints,
  145. num_classes=num_classes,
  146. )
  147. if (
  148. hasattr(self.model.layoutxlm, "use_visual_backbone")
  149. and self.model.layoutxlm.use_visual_backbone is False
  150. ):
  151. self.use_visual_backbone = False
  152. def forward(self, x):
  153. if self.use_visual_backbone is True:
  154. image = x[4]
  155. else:
  156. image = None
  157. x = self.model(
  158. input_ids=x[0],
  159. bbox=x[1],
  160. attention_mask=x[2],
  161. token_type_ids=x[3],
  162. image=image,
  163. position_ids=None,
  164. head_mask=None,
  165. labels=None,
  166. )
  167. if self.training:
  168. res = {"backbone_out": x[0]}
  169. res.update(x[1])
  170. return res
  171. else:
  172. return x
  173. class LayoutLMv2ForRe(NLPBaseModel):
  174. def __init__(self, pretrained=True, checkpoints=None, mode="base", **kwargs):
  175. super(LayoutLMv2ForRe, self).__init__(
  176. LayoutLMv2Model,
  177. LayoutLMv2ForRelationExtraction,
  178. mode,
  179. "re",
  180. pretrained,
  181. checkpoints,
  182. )
  183. if (
  184. hasattr(self.model.layoutlmv2, "use_visual_backbone")
  185. and self.model.layoutlmv2.use_visual_backbone is False
  186. ):
  187. self.use_visual_backbone = False
  188. def forward(self, x):
  189. x = self.model(
  190. input_ids=x[0],
  191. bbox=x[1],
  192. attention_mask=x[2],
  193. token_type_ids=x[3],
  194. image=x[4],
  195. position_ids=None,
  196. head_mask=None,
  197. labels=None,
  198. entities=x[5],
  199. relations=x[6],
  200. )
  201. return x
  202. class LayoutXLMForRe(NLPBaseModel):
  203. def __init__(self, pretrained=True, checkpoints=None, mode="base", **kwargs):
  204. super(LayoutXLMForRe, self).__init__(
  205. LayoutXLMModel,
  206. LayoutXLMForRelationExtraction,
  207. mode,
  208. "re",
  209. pretrained,
  210. checkpoints,
  211. )
  212. if (
  213. hasattr(self.model.layoutxlm, "use_visual_backbone")
  214. and self.model.layoutxlm.use_visual_backbone is False
  215. ):
  216. self.use_visual_backbone = False
  217. def forward(self, x):
  218. if self.use_visual_backbone is True:
  219. image = x[4]
  220. entities = x[5]
  221. relations = x[6]
  222. else:
  223. image = None
  224. entities = x[4]
  225. relations = x[5]
  226. x = self.model(
  227. input_ids=x[0],
  228. bbox=x[1],
  229. attention_mask=x[2],
  230. token_type_ids=x[3],
  231. image=image,
  232. position_ids=None,
  233. head_mask=None,
  234. labels=None,
  235. entities=entities,
  236. relations=relations,
  237. )
  238. return x