model.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import copy
  3. import logging
  4. import math
  5. import os
  6. import re
  7. import sys
  8. import json
  9. import torch
  10. import torch.distributed as dist
  11. import torch.nn as nn
  12. from torchvision.ops import roi_align
  13. from modelscope.metainfo import Models
  14. from modelscope.models import TorchModel
  15. from modelscope.models.builder import MODELS
  16. from modelscope.models.multi_modal.vldoc.conv_fpn_trans import FPNTrans
  17. from modelscope.models.multi_modal.vldoc.modeling_layout_roberta import (
  18. LayoutRobertaModel, LayoutRobertaPreTrainedModel)
  19. from modelscope.models.multi_modal.vldoc.transformer_local import (
  20. TransformerDecoder, TransformerDecoderLayer)
  21. from modelscope.utils.constant import ModeKeys, ModelFile, Tasks
  22. from modelscope.utils.logger import get_logger
  23. logger = get_logger()
  24. __all__ = ['VLDocForDocVLEmbedding']
  25. class GeoVLDocModelOutputs(object):
  26. def __init__(
  27. self,
  28. text_features,
  29. text_mm_features,
  30. block_vis_features,
  31. block_vis_mm_features,
  32. image_mm_features,
  33. ):
  34. # [batch size, sequence length, hidden size]
  35. self.text_features = text_features
  36. # [batch size, sequence length, hidden size]
  37. self.text_mm_features = text_mm_features
  38. # [batch size, block num, hidden size]
  39. self.block_vis_features = block_vis_features
  40. # [batch size, block num, hidden size]
  41. self.block_vis_mm_features = block_vis_mm_features
  42. # [batch size, hidden size]
  43. self.image_mm_features = image_mm_features
  44. class GeoVLDocModel(LayoutRobertaPreTrainedModel):
  45. def __init__(self, config, hard_negtive_sampling=False):
  46. super().__init__(config)
  47. self.config = config
  48. self.hard_negtive_sampling = hard_negtive_sampling
  49. if getattr(self.config, 'architectures', None):
  50. if self.config.architectures[0] == 'LayoutRobertaModel':
  51. self.text_encoder = LayoutRobertaModel(config)
  52. else:
  53. self.text_encoder = LayoutRobertaModel(config)
  54. else:
  55. self.text_encoder = LayoutRobertaModel(config)
  56. self.visual_encoder = FPNTrans(
  57. img_size=self.config.image_size, inner_vit=False)
  58. self.pool = nn.AdaptiveAvgPool2d([1, 1])
  59. self.vis_linear = nn.Linear(256, self.config.hidden_size)
  60. cross_modal_text_layer = TransformerDecoderLayer(
  61. self.config.hidden_size,
  62. self.config.num_attention_heads,
  63. self.config.intermediate_size,
  64. self_attn=True)
  65. self.cross_modal_text = TransformerDecoder(cross_modal_text_layer, 1)
  66. cross_modal_visual_layer = TransformerDecoderLayer(
  67. self.config.hidden_size,
  68. self.config.num_attention_heads,
  69. self.config.intermediate_size,
  70. self_attn=True)
  71. self.cross_modal_visual = TransformerDecoder(cross_modal_visual_layer,
  72. 1)
  73. self.init_weights()
  74. def from_pretrained(self, ckpt_path: str):
  75. state_dict = torch.load(ckpt_path, map_location='cpu')
  76. state_dict_new = {}
  77. for k, v in state_dict.items():
  78. k = k.replace('geo_vl_doc_model.', '')
  79. state_dict_new[k] = v
  80. self.load_state_dict(state_dict_new)
  81. def forward(self,
  82. input_ids=None,
  83. image=None,
  84. bbox=None,
  85. bbox_4p_normalized=None,
  86. attention_mask=None,
  87. first_token_idxes=None,
  88. first_token_idxes_mask=None,
  89. token_type_ids=None,
  90. position_ids=None,
  91. head_mask=None,
  92. inputs_embeds=None,
  93. encoder_hidden_states=None,
  94. encoder_attention_mask=None,
  95. past_key_values=None,
  96. use_cache=None,
  97. output_attentions=None,
  98. output_hidden_states=None,
  99. return_dict=None,
  100. **kwargs):
  101. batch_size, seq_len = input_ids.shape
  102. return_dict = (
  103. return_dict
  104. if return_dict is not None else self.config.use_return_dict)
  105. kwargs['line_bbox'] = bbox
  106. # ################ get text representation ################
  107. if self.config.architectures[0] == 'LayoutRobertaModel':
  108. outputs = self.text_encoder(
  109. input_ids,
  110. bbox=bbox_4p_normalized,
  111. attention_mask=attention_mask,
  112. token_type_ids=token_type_ids,
  113. position_ids=position_ids,
  114. head_mask=head_mask,
  115. inputs_embeds=inputs_embeds,
  116. output_attentions=output_attentions,
  117. output_hidden_states=output_hidden_states,
  118. return_dict=return_dict,
  119. **kwargs)
  120. else:
  121. outputs = self.text_encoder(
  122. input_ids,
  123. bbox=bbox_4p_normalized,
  124. attention_mask=attention_mask,
  125. token_type_ids=token_type_ids,
  126. position_ids=position_ids,
  127. head_mask=head_mask,
  128. inputs_embeds=inputs_embeds,
  129. output_attentions=output_attentions,
  130. output_hidden_states=output_hidden_states,
  131. return_dict=return_dict,
  132. **kwargs)
  133. # sequence_output: [batch_size, seq_len, hidden_size]
  134. # pooled_output: [batch_size, hidden_size]
  135. sequence_output, pooled_output = outputs[:2]
  136. # ################ get visual representation ################
  137. _, num_first = first_token_idxes.shape
  138. B_batch_dim = torch.arange(
  139. 0, batch_size,
  140. device=input_ids.device).reshape(batch_size,
  141. 1).expand(batch_size, num_first)
  142. feature_bbox = bbox[B_batch_dim, first_token_idxes]
  143. _, block_num, _ = feature_bbox.shape
  144. visual_out = self.visual_encoder(image)
  145. batch_idxs = torch.arange(
  146. 0, batch_size, device=sequence_output.device).reshape(
  147. batch_size, 1).expand(batch_size, block_num).unsqueeze(-1)
  148. # [batch_size*block_num, 5]
  149. batch_idx_with_bbox = torch.cat(
  150. (batch_idxs, feature_bbox),
  151. 2).reshape(batch_size * block_num,
  152. 5).to(dtype=visual_out['feat_ms'].dtype)
  153. if visual_out['feat_ms'].dtype == torch.float16:
  154. # [batch_size*block_num, 256, 1, 1]
  155. blk_vis_features = roi_align(
  156. visual_out['feat_ms'].to(torch.float32),
  157. batch_idx_with_bbox.to(torch.float32),
  158. 1,
  159. spatial_scale=visual_out['feat_ms'].size(-1) / 1000.0)
  160. blk_vis_features = blk_vis_features.to(
  161. dtype=visual_out['feat_ms'].dtype)
  162. else:
  163. blk_vis_features = roi_align(
  164. visual_out['feat_ms'],
  165. batch_idx_with_bbox.to(torch.float32),
  166. 1,
  167. spatial_scale=visual_out['feat_ms'].size(-1) / 1000.0)
  168. # [batch_size*block_num, 256]
  169. blk_vis_features = blk_vis_features.squeeze(2).squeeze(2).reshape(
  170. batch_size, block_num, 256)
  171. # visual block features:
  172. # blk_vis_features: [batch_size, block_num, hidden_size]
  173. blk_vis_features = self.vis_linear(blk_vis_features)
  174. blk_vis_features = blk_vis_features * first_token_idxes_mask.unsqueeze(
  175. 2)
  176. # [batch_size, 256]
  177. full_img_features = self.pool(
  178. visual_out['feat_ms']).squeeze(2).squeeze(2)
  179. # [batch_size, hidden_size]
  180. full_img_features = self.vis_linear(full_img_features).unsqueeze(1)
  181. # ################ multi-modal fusion ################
  182. # cross attention inputs
  183. vis_inps = torch.cat((full_img_features, blk_vis_features), 1)
  184. glb_feat_attn = torch.ones((batch_size, 1)).to(input_ids.device)
  185. vis_mask = torch.cat((glb_feat_attn, first_token_idxes_mask), 1)
  186. # When we use transformer in torch.nn, the input size is
  187. # [seq_len, batch_size, hidden_size]
  188. # In attention_mask, 1 denotes masked
  189. new_attention_mask = (1 - attention_mask) > 0
  190. new_vis_mask = (1 - vis_mask) > 0
  191. text_mm_feat = self.cross_modal_text(
  192. tgt=sequence_output.transpose(0, 1),
  193. memory=vis_inps.transpose(0, 1),
  194. tgt_key_padding_mask=new_attention_mask,
  195. memory_key_padding_mask=new_vis_mask)
  196. vis_mm_feat = self.cross_modal_visual(
  197. tgt=vis_inps.transpose(0, 1),
  198. memory=sequence_output.transpose(0, 1),
  199. tgt_key_padding_mask=new_vis_mask,
  200. memory_key_padding_mask=new_attention_mask,
  201. )
  202. # [batch_size, seq_len, hidden_size]
  203. text_mm_feat = text_mm_feat.transpose(0, 1)
  204. # [batch_size, 1+block_num, hidden_size]
  205. vis_mm_feat = vis_mm_feat.transpose(0, 1)
  206. # image_mm_features = vis_mm_feat[:, 0, :]
  207. block_vis_mm_features = vis_mm_feat[:, 1:]
  208. return GeoVLDocModelOutputs(
  209. text_features=sequence_output,
  210. text_mm_features=text_mm_feat,
  211. block_vis_features=blk_vis_features,
  212. block_vis_mm_features=block_vis_mm_features,
  213. image_mm_features=vis_mm_feat,
  214. )
  215. @MODELS.register_module(Tasks.document_vl_embedding, module_name=Models.vldoc)
  216. class VLDocForDocVLEmbedding(TorchModel):
  217. """
  218. Generate multi-modal document embeddings in segment-level and token-level.
  219. Args:
  220. model_dir:
  221. the path in model hub, e.g., 'damo/multi-modal_convnext-roberta-base_vldoc-embedding'
  222. """
  223. def __init__(self, model_dir: str, *args, **kwargs):
  224. super().__init__(model_dir=model_dir, *args, **kwargs)
  225. # Initialize the model.
  226. from modelscope.models.multi_modal.vldoc.modeling_layout_roberta import LayoutRobertaConfig
  227. model_cfg_path = os.path.join(model_dir, 'config.json')
  228. logger.info('Loading config file from {}'.format(model_cfg_path))
  229. assert os.path.exists(model_cfg_path)
  230. self.config = LayoutRobertaConfig.from_json_file(model_cfg_path)
  231. self.doc_model = GeoVLDocModel(self.config)
  232. # restore the pretrained weight
  233. model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
  234. assert os.path.exists(model_path)
  235. self.doc_model.from_pretrained(model_path)
  236. logger.info('Loading model from {}'.format(model_path))
  237. # Initialize the tokenizer.
  238. from modelscope.models.multi_modal.vldoc.tokenization import VLDocXLMTokenizer
  239. tokenizer_path = os.path.join(model_dir, ModelFile.TOKENIZER_FOLDER)
  240. self.tokenizer = VLDocXLMTokenizer.from_pretrained(tokenizer_path)
  241. # place the model
  242. self.device = 'cuda:{}'.format(int(os.environ.get(
  243. 'LOCAL_RANK', 0))) if torch.cuda.is_available() else 'cpu'
  244. if torch.cuda.is_available():
  245. self.doc_model.to(self.device)
  246. logger.info('Use GPU {} for finetuning & inference'.format(
  247. int(os.environ.get('LOCAL_RANK', 0))))
  248. else:
  249. self.doc_model.float()
  250. logger.info('Use CPU for finetuning & inference')
  251. def forward(self,
  252. input_ids=None,
  253. image=None,
  254. bbox=None,
  255. bbox_4p_normalized=None,
  256. attention_mask=None,
  257. first_token_idxes=None,
  258. first_token_idxes_mask=None,
  259. token_type_ids=None,
  260. position_ids=None,
  261. head_mask=None,
  262. inputs_embeds=None,
  263. encoder_hidden_states=None,
  264. encoder_attention_mask=None,
  265. past_key_values=None,
  266. use_cache=None,
  267. output_attentions=None,
  268. output_hidden_states=None,
  269. return_dict=None,
  270. **kwargs):
  271. """
  272. Args:
  273. - input_ids: :math:`(B, T, E)`, the input tokens, where B is the batch size,
  274. T is the max token size, E is the embedding dimension.
  275. - image: :math:`(B, C, H, W)`, normalized images.
  276. - bbox: :math:`(B, T, 4)`, segment boxes denoted by top-left and bottom-right
  277. vertexes whose values are normalized to [0, 1000).
  278. - bbox_4p_normalized: :math:`(B, T, 8)`, word boxes denoted by 4 vertexes, whose
  279. values are normalized to [0, 1).
  280. - attention_mask: :math:`(B, T)`, mask for input tokens, where 0 means masked.
  281. - first_token_idxes: :math:`(B, S)`, indexes of the corresponding first tokens
  282. of all segments, where S is the max segment size.
  283. - first_token_idxes_mask: :math:`(B, S)`, mask for segments, where 0 means masked.
  284. Optional:
  285. - line_rank_id: :math:`(B, T)`, orders of segments.
  286. - line_rank_inner_id: :math:`(B, T)`, BIE-like tags.
  287. To be more specific, please refer to the class `TextLayoutSerializer` in
  288. `modelscope/models/multi_modal/vldoc/processing.py`.
  289. """
  290. vldoc_outputs = self.doc_model(
  291. input_ids=input_ids,
  292. image=image,
  293. bbox=bbox,
  294. bbox_4p_normalized=bbox_4p_normalized,
  295. attention_mask=attention_mask,
  296. first_token_idxes=first_token_idxes,
  297. first_token_idxes_mask=first_token_idxes_mask,
  298. token_type_ids=token_type_ids,
  299. position_ids=position_ids,
  300. head_mask=head_mask,
  301. inputs_embeds=inputs_embeds,
  302. encoder_hidden_states=encoder_hidden_states,
  303. encoder_attention_mask=encoder_attention_mask,
  304. past_key_values=past_key_values,
  305. use_cache=use_cache,
  306. output_attentions=output_attentions,
  307. output_hidden_states=output_hidden_states,
  308. return_dict=return_dict,
  309. **kwargs)
  310. return dict(
  311. img_embedding=vldoc_outputs.image_mm_features,
  312. text_embedding=vldoc_outputs.text_mm_features,
  313. )
  314. def init_pretrained_weight(
  315. model,
  316. pretrained_model_path,
  317. state_dict=None,
  318. cache_dir=None,
  319. init_backbone='roberta',
  320. ):
  321. if state_dict is None:
  322. state_dict = torch.load(pretrained_model_path, map_location='cpu')
  323. old_keys = []
  324. new_keys = []
  325. state_dict_keys = list(state_dict.keys())
  326. if init_backbone == 'roberta':
  327. for i in range(len(state_dict_keys)):
  328. key = state_dict_keys[i]
  329. new_key = None
  330. if key.startswith('roberta.'):
  331. new_key = key.replace('roberta.',
  332. 'geo_vl_doc_model.text_encoder.')
  333. key = copy.deepcopy(new_key)
  334. if new_key:
  335. old_keys.append(state_dict_keys[i])
  336. new_keys.append(new_key)
  337. for old_key, new_key in zip(old_keys, new_keys):
  338. state_dict[new_key] = state_dict.pop(old_key)
  339. missing_keys = []
  340. unexpected_keys = []
  341. error_msgs = []
  342. # copy state_dict so _load_from_state_dict can modify it
  343. metadata = getattr(state_dict, '_metadata', None)
  344. state_dict = state_dict.copy()
  345. if metadata is not None:
  346. state_dict._metadata = metadata
  347. def load(module, prefix=''):
  348. local_metadata = {} if metadata is None else metadata.get(
  349. prefix[:-1], {})
  350. module._load_from_state_dict(state_dict, prefix, local_metadata, True,
  351. missing_keys, unexpected_keys, error_msgs)
  352. for name, child in module._modules.items():
  353. if child is not None:
  354. load(child, prefix + name + '.')
  355. start_prefix = ''
  356. if not hasattr(model, 'geo_vl_doc_model') and any(
  357. s.startswith('geo_vl_doc_model.') for s in state_dict.keys()):
  358. start_prefix = 'geo_vl_doc_model.'
  359. load(model, prefix=start_prefix)
  360. if len(missing_keys) > 0:
  361. logger.info(
  362. 'Weights of {} not initialized from pretrained model: {}'.format(
  363. model.__class__.__name__, missing_keys))
  364. if len(unexpected_keys) > 0:
  365. logger.info('Weights from pretrained model not used in {}: {}'.format(
  366. model.__class__.__name__, unexpected_keys))
  367. if len(error_msgs) > 0:
  368. raise RuntimeError(
  369. 'Error(s) in loading state_dict for {}:\n\t{}'.format(
  370. model.__class__.__name__, '\n\t'.join(error_msgs)))
  371. return model