| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import time
- from collections.abc import Mapping
- import torch
- from torch import distributed as dist
- from modelscope.metainfo import Trainers
- from modelscope.trainers.builder import TRAINERS
- from modelscope.trainers.trainer import EpochBasedTrainer
- from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
- ConfigKeys, Hubs, ModeKeys, ModelFile,
- Tasks, TrainerStages)
- from modelscope.utils.data_utils import to_device
- from modelscope.utils.file_utils import func_receive_dict_inputs
- @TRAINERS.register_module(module_name=Trainers.ocr_recognition)
- class OCRRecognitionTrainer(EpochBasedTrainer):
- def evaluate(self, *args, **kwargs):
- metric_values = super().evaluate(*args, **kwargs)
- return metric_values
- def prediction_step(self, model, inputs):
- pass
- def train_step(self, model, inputs):
- """ Perform a training step on a batch of inputs.
- Subclass and override to inject custom behavior.
- Args:
- model (`TorchModel`): The model to train.
- inputs (`Dict[str, Union[torch.Tensor, Any]]`):
- The inputs and targets of the model.
- The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
- argument `labels`. Check your model's documentation for all accepted arguments.
- Return:
- `torch.Tensor`: The tensor with training loss on this batch.
- """
- # EvaluationHook will do evaluate and change mode to val, return to train mode
- # TODO: find more pretty way to change mode
- model.train()
- self._mode = ModeKeys.TRAIN
- train_outputs = model.do_step(inputs)
- if not isinstance(train_outputs, dict):
- raise TypeError('"model.forward()" must return a dict')
- # add model output info to log
- if 'log_vars' not in train_outputs:
- default_keys_pattern = ['loss']
- match_keys = set([])
- for key_p in default_keys_pattern:
- match_keys.update(
- [key for key in train_outputs.keys() if key_p in key])
- log_vars = {}
- for key in match_keys:
- value = train_outputs.get(key, None)
- if value is not None:
- if dist.is_available() and dist.is_initialized():
- value = value.data.clone()
- dist.all_reduce(value.div_(dist.get_world_size()))
- log_vars.update({key: value.item()})
- self.log_buffer.update(log_vars)
- else:
- self.log_buffer.update(train_outputs['log_vars'])
- self.train_outputs = train_outputs
- def evaluation_step(self, data):
- """Perform a evaluation step on a batch of inputs.
- Subclass and override to inject custom behavior.
- """
- model = self.model.module if self._dist else self.model
- model.eval()
- result = model.do_step(data)
- return result
|