| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import copy
- import logging
- import math
- import os
- import re
- import sys
- import json
- import torch
- import torch.distributed as dist
- import torch.nn as nn
- from torchvision.ops import roi_align
- from modelscope.metainfo import Models
- from modelscope.models import TorchModel
- from modelscope.models.builder import MODELS
- from modelscope.models.multi_modal.vldoc.conv_fpn_trans import FPNTrans
- from modelscope.models.multi_modal.vldoc.modeling_layout_roberta import (
- LayoutRobertaModel, LayoutRobertaPreTrainedModel)
- from modelscope.models.multi_modal.vldoc.transformer_local import (
- TransformerDecoder, TransformerDecoderLayer)
- from modelscope.utils.constant import ModeKeys, ModelFile, Tasks
- from modelscope.utils.logger import get_logger
- logger = get_logger()
- __all__ = ['VLDocForDocVLEmbedding']
- class GeoVLDocModelOutputs(object):
- def __init__(
- self,
- text_features,
- text_mm_features,
- block_vis_features,
- block_vis_mm_features,
- image_mm_features,
- ):
- # [batch size, sequence length, hidden size]
- self.text_features = text_features
- # [batch size, sequence length, hidden size]
- self.text_mm_features = text_mm_features
- # [batch size, block num, hidden size]
- self.block_vis_features = block_vis_features
- # [batch size, block num, hidden size]
- self.block_vis_mm_features = block_vis_mm_features
- # [batch size, hidden size]
- self.image_mm_features = image_mm_features
- class GeoVLDocModel(LayoutRobertaPreTrainedModel):
- def __init__(self, config, hard_negtive_sampling=False):
- super().__init__(config)
- self.config = config
- self.hard_negtive_sampling = hard_negtive_sampling
- if getattr(self.config, 'architectures', None):
- if self.config.architectures[0] == 'LayoutRobertaModel':
- self.text_encoder = LayoutRobertaModel(config)
- else:
- self.text_encoder = LayoutRobertaModel(config)
- else:
- self.text_encoder = LayoutRobertaModel(config)
- self.visual_encoder = FPNTrans(
- img_size=self.config.image_size, inner_vit=False)
- self.pool = nn.AdaptiveAvgPool2d([1, 1])
- self.vis_linear = nn.Linear(256, self.config.hidden_size)
- cross_modal_text_layer = TransformerDecoderLayer(
- self.config.hidden_size,
- self.config.num_attention_heads,
- self.config.intermediate_size,
- self_attn=True)
- self.cross_modal_text = TransformerDecoder(cross_modal_text_layer, 1)
- cross_modal_visual_layer = TransformerDecoderLayer(
- self.config.hidden_size,
- self.config.num_attention_heads,
- self.config.intermediate_size,
- self_attn=True)
- self.cross_modal_visual = TransformerDecoder(cross_modal_visual_layer,
- 1)
- self.init_weights()
- def from_pretrained(self, ckpt_path: str):
- state_dict = torch.load(ckpt_path, map_location='cpu')
- state_dict_new = {}
- for k, v in state_dict.items():
- k = k.replace('geo_vl_doc_model.', '')
- state_dict_new[k] = v
- self.load_state_dict(state_dict_new)
- def forward(self,
- input_ids=None,
- image=None,
- bbox=None,
- bbox_4p_normalized=None,
- attention_mask=None,
- first_token_idxes=None,
- first_token_idxes_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- past_key_values=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- **kwargs):
- batch_size, seq_len = input_ids.shape
- return_dict = (
- return_dict
- if return_dict is not None else self.config.use_return_dict)
- kwargs['line_bbox'] = bbox
- # ################ get text representation ################
- if self.config.architectures[0] == 'LayoutRobertaModel':
- outputs = self.text_encoder(
- input_ids,
- bbox=bbox_4p_normalized,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- **kwargs)
- else:
- outputs = self.text_encoder(
- input_ids,
- bbox=bbox_4p_normalized,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- **kwargs)
- # sequence_output: [batch_size, seq_len, hidden_size]
- # pooled_output: [batch_size, hidden_size]
- sequence_output, pooled_output = outputs[:2]
- # ################ get visual representation ################
- _, num_first = first_token_idxes.shape
- B_batch_dim = torch.arange(
- 0, batch_size,
- device=input_ids.device).reshape(batch_size,
- 1).expand(batch_size, num_first)
- feature_bbox = bbox[B_batch_dim, first_token_idxes]
- _, block_num, _ = feature_bbox.shape
- visual_out = self.visual_encoder(image)
- batch_idxs = torch.arange(
- 0, batch_size, device=sequence_output.device).reshape(
- batch_size, 1).expand(batch_size, block_num).unsqueeze(-1)
- # [batch_size*block_num, 5]
- batch_idx_with_bbox = torch.cat(
- (batch_idxs, feature_bbox),
- 2).reshape(batch_size * block_num,
- 5).to(dtype=visual_out['feat_ms'].dtype)
- if visual_out['feat_ms'].dtype == torch.float16:
- # [batch_size*block_num, 256, 1, 1]
- blk_vis_features = roi_align(
- visual_out['feat_ms'].to(torch.float32),
- batch_idx_with_bbox.to(torch.float32),
- 1,
- spatial_scale=visual_out['feat_ms'].size(-1) / 1000.0)
- blk_vis_features = blk_vis_features.to(
- dtype=visual_out['feat_ms'].dtype)
- else:
- blk_vis_features = roi_align(
- visual_out['feat_ms'],
- batch_idx_with_bbox.to(torch.float32),
- 1,
- spatial_scale=visual_out['feat_ms'].size(-1) / 1000.0)
- # [batch_size*block_num, 256]
- blk_vis_features = blk_vis_features.squeeze(2).squeeze(2).reshape(
- batch_size, block_num, 256)
- # visual block features:
- # blk_vis_features: [batch_size, block_num, hidden_size]
- blk_vis_features = self.vis_linear(blk_vis_features)
- blk_vis_features = blk_vis_features * first_token_idxes_mask.unsqueeze(
- 2)
- # [batch_size, 256]
- full_img_features = self.pool(
- visual_out['feat_ms']).squeeze(2).squeeze(2)
- # [batch_size, hidden_size]
- full_img_features = self.vis_linear(full_img_features).unsqueeze(1)
- # ################ multi-modal fusion ################
- # cross attention inputs
- vis_inps = torch.cat((full_img_features, blk_vis_features), 1)
- glb_feat_attn = torch.ones((batch_size, 1)).to(input_ids.device)
- vis_mask = torch.cat((glb_feat_attn, first_token_idxes_mask), 1)
- # When we use transformer in torch.nn, the input size is
- # [seq_len, batch_size, hidden_size]
- # In attention_mask, 1 denotes masked
- new_attention_mask = (1 - attention_mask) > 0
- new_vis_mask = (1 - vis_mask) > 0
- text_mm_feat = self.cross_modal_text(
- tgt=sequence_output.transpose(0, 1),
- memory=vis_inps.transpose(0, 1),
- tgt_key_padding_mask=new_attention_mask,
- memory_key_padding_mask=new_vis_mask)
- vis_mm_feat = self.cross_modal_visual(
- tgt=vis_inps.transpose(0, 1),
- memory=sequence_output.transpose(0, 1),
- tgt_key_padding_mask=new_vis_mask,
- memory_key_padding_mask=new_attention_mask,
- )
- # [batch_size, seq_len, hidden_size]
- text_mm_feat = text_mm_feat.transpose(0, 1)
- # [batch_size, 1+block_num, hidden_size]
- vis_mm_feat = vis_mm_feat.transpose(0, 1)
- # image_mm_features = vis_mm_feat[:, 0, :]
- block_vis_mm_features = vis_mm_feat[:, 1:]
- return GeoVLDocModelOutputs(
- text_features=sequence_output,
- text_mm_features=text_mm_feat,
- block_vis_features=blk_vis_features,
- block_vis_mm_features=block_vis_mm_features,
- image_mm_features=vis_mm_feat,
- )
- @MODELS.register_module(Tasks.document_vl_embedding, module_name=Models.vldoc)
- class VLDocForDocVLEmbedding(TorchModel):
- """
- Generate multi-modal document embeddings in segment-level and token-level.
- Args:
- model_dir:
- the path in model hub, e.g., 'damo/multi-modal_convnext-roberta-base_vldoc-embedding'
- """
- def __init__(self, model_dir: str, *args, **kwargs):
- super().__init__(model_dir=model_dir, *args, **kwargs)
- # Initialize the model.
- from modelscope.models.multi_modal.vldoc.modeling_layout_roberta import LayoutRobertaConfig
- model_cfg_path = os.path.join(model_dir, 'config.json')
- logger.info('Loading config file from {}'.format(model_cfg_path))
- assert os.path.exists(model_cfg_path)
- self.config = LayoutRobertaConfig.from_json_file(model_cfg_path)
- self.doc_model = GeoVLDocModel(self.config)
- # restore the pretrained weight
- model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
- assert os.path.exists(model_path)
- self.doc_model.from_pretrained(model_path)
- logger.info('Loading model from {}'.format(model_path))
- # Initialize the tokenizer.
- from modelscope.models.multi_modal.vldoc.tokenization import VLDocXLMTokenizer
- tokenizer_path = os.path.join(model_dir, ModelFile.TOKENIZER_FOLDER)
- self.tokenizer = VLDocXLMTokenizer.from_pretrained(tokenizer_path)
- # place the model
- self.device = 'cuda:{}'.format(int(os.environ.get(
- 'LOCAL_RANK', 0))) if torch.cuda.is_available() else 'cpu'
- if torch.cuda.is_available():
- self.doc_model.to(self.device)
- logger.info('Use GPU {} for finetuning & inference'.format(
- int(os.environ.get('LOCAL_RANK', 0))))
- else:
- self.doc_model.float()
- logger.info('Use CPU for finetuning & inference')
- def forward(self,
- input_ids=None,
- image=None,
- bbox=None,
- bbox_4p_normalized=None,
- attention_mask=None,
- first_token_idxes=None,
- first_token_idxes_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- past_key_values=None,
- use_cache=None,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=None,
- **kwargs):
- """
- Args:
- - input_ids: :math:`(B, T, E)`, the input tokens, where B is the batch size,
- T is the max token size, E is the embedding dimension.
- - image: :math:`(B, C, H, W)`, normalized images.
- - bbox: :math:`(B, T, 4)`, segment boxes denoted by top-left and bottom-right
- vertexes whose values are normalized to [0, 1000).
- - bbox_4p_normalized: :math:`(B, T, 8)`, word boxes denoted by 4 vertexes, whose
- values are normalized to [0, 1).
- - attention_mask: :math:`(B, T)`, mask for input tokens, where 0 means masked.
- - first_token_idxes: :math:`(B, S)`, indexes of the corresponding first tokens
- of all segments, where S is the max segment size.
- - first_token_idxes_mask: :math:`(B, S)`, mask for segments, where 0 means masked.
- Optional:
- - line_rank_id: :math:`(B, T)`, orders of segments.
- - line_rank_inner_id: :math:`(B, T)`, BIE-like tags.
- To be more specific, please refer to the class `TextLayoutSerializer` in
- `modelscope/models/multi_modal/vldoc/processing.py`.
- """
- vldoc_outputs = self.doc_model(
- input_ids=input_ids,
- image=image,
- bbox=bbox,
- bbox_4p_normalized=bbox_4p_normalized,
- attention_mask=attention_mask,
- first_token_idxes=first_token_idxes,
- first_token_idxes_mask=first_token_idxes_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- past_key_values=past_key_values,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- **kwargs)
- return dict(
- img_embedding=vldoc_outputs.image_mm_features,
- text_embedding=vldoc_outputs.text_mm_features,
- )
- def init_pretrained_weight(
- model,
- pretrained_model_path,
- state_dict=None,
- cache_dir=None,
- init_backbone='roberta',
- ):
- if state_dict is None:
- state_dict = torch.load(pretrained_model_path, map_location='cpu')
- old_keys = []
- new_keys = []
- state_dict_keys = list(state_dict.keys())
- if init_backbone == 'roberta':
- for i in range(len(state_dict_keys)):
- key = state_dict_keys[i]
- new_key = None
- if key.startswith('roberta.'):
- new_key = key.replace('roberta.',
- 'geo_vl_doc_model.text_encoder.')
- key = copy.deepcopy(new_key)
- if new_key:
- old_keys.append(state_dict_keys[i])
- new_keys.append(new_key)
- for old_key, new_key in zip(old_keys, new_keys):
- state_dict[new_key] = state_dict.pop(old_key)
- missing_keys = []
- unexpected_keys = []
- error_msgs = []
- # copy state_dict so _load_from_state_dict can modify it
- metadata = getattr(state_dict, '_metadata', None)
- state_dict = state_dict.copy()
- if metadata is not None:
- state_dict._metadata = metadata
- def load(module, prefix=''):
- local_metadata = {} if metadata is None else metadata.get(
- prefix[:-1], {})
- module._load_from_state_dict(state_dict, prefix, local_metadata, True,
- missing_keys, unexpected_keys, error_msgs)
- for name, child in module._modules.items():
- if child is not None:
- load(child, prefix + name + '.')
- start_prefix = ''
- if not hasattr(model, 'geo_vl_doc_model') and any(
- s.startswith('geo_vl_doc_model.') for s in state_dict.keys()):
- start_prefix = 'geo_vl_doc_model.'
- load(model, prefix=start_prefix)
- if len(missing_keys) > 0:
- logger.info(
- 'Weights of {} not initialized from pretrained model: {}'.format(
- model.__class__.__name__, missing_keys))
- if len(unexpected_keys) > 0:
- logger.info('Weights from pretrained model not used in {}: {}'.format(
- model.__class__.__name__, unexpected_keys))
- if len(error_msgs) > 0:
- raise RuntimeError(
- 'Error(s) in loading state_dict for {}:\n\t{}'.format(
- model.__class__.__name__, '\n\t'.join(error_msgs)))
- return model
|