# ocr_recognition_trainer.py (3.1 KB)

  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import time
  3. from collections.abc import Mapping
  4. import torch
  5. from torch import distributed as dist
  6. from modelscope.metainfo import Trainers
  7. from modelscope.trainers.builder import TRAINERS
  8. from modelscope.trainers.trainer import EpochBasedTrainer
  9. from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
  10. ConfigKeys, Hubs, ModeKeys, ModelFile,
  11. Tasks, TrainerStages)
  12. from modelscope.utils.data_utils import to_device
  13. from modelscope.utils.file_utils import func_receive_dict_inputs
  @TRAINERS.register_module(module_name=Trainers.ocr_recognition)
  class OCRRecognitionTrainer(EpochBasedTrainer):
      """Epoch-based trainer registered under ``Trainers.ocr_recognition``.

      Both the training and the evaluation step hand the batch to the model's
      ``do_step`` method and work with whatever that method returns.
      """

      def evaluate(self, *args, **kwargs):
          """Run the parent trainer's evaluation and return its metric values.

          All positional and keyword arguments are forwarded unchanged to
          ``EpochBasedTrainer.evaluate``.
          """
          metric_values = super().evaluate(*args, **kwargs)
          return metric_values

      def prediction_step(self, model, inputs):
          """Placeholder for a prediction step.

          Not implemented: the method does nothing and returns ``None``.
          """
          pass

      def train_step(self, model, inputs):
          """Perform a training step on a batch of inputs.

          Subclass and override to inject custom behavior.

          The model is switched to train mode, ``model.do_step(inputs)`` is
          called, loss values are pushed into ``self.log_buffer`` and the raw
          outputs are stored on ``self.train_outputs``. Nothing is returned.

          Args:
              model: The model to train. It must provide ``train()`` and
                  ``do_step(inputs)``.
              inputs: The batch handed as-is to ``model.do_step``.

          Raises:
              TypeError: If ``model.do_step`` does not return a ``dict``.
          """
          # EvaluationHook will do evaluate and change mode to val, return to
          # train mode
          # TODO: find more pretty way to change mode
          model.train()
          self._mode = ModeKeys.TRAIN
          train_outputs = model.do_step(inputs)

          if not isinstance(train_outputs, dict):
              raise TypeError('"model.do_step()" must return a dict')

          # add model output info to log
          if 'log_vars' not in train_outputs:
              default_keys_pattern = ['loss']
              # Collect matching keys in the dict's insertion order. Iterating a
              # ``set`` of strings is not guaranteed to give the same order in
              # every process (string hashing is randomized per process), and
              # each key triggers a collective ``all_reduce`` that must be
              # issued in the same order on all ranks.
              match_keys = [
                  key for key in train_outputs
                  if any(key_p in key for key_p in default_keys_pattern)
              ]
              reduce_across_ranks = (
                  dist.is_available() and dist.is_initialized())
              log_vars = {}
              for key in match_keys:
                  value = train_outputs[key]
                  if value is None:
                      continue
                  if reduce_across_ranks:
                      # Work on a detached copy so the tensor kept in
                      # ``train_outputs`` is not modified in place; dividing
                      # by the world size before the (default, summing)
                      # all_reduce yields the mean over ranks.
                      value = value.data.clone()
                      dist.all_reduce(value.div_(dist.get_world_size()))
                  log_vars[key] = value.item()
              self.log_buffer.update(log_vars)
          else:
              # The model supplied its own log variables; use them verbatim.
              self.log_buffer.update(train_outputs['log_vars'])

          self.train_outputs = train_outputs

      def evaluation_step(self, data):
          """Perform an evaluation step on a batch of inputs.

          Subclass and override to inject custom behavior.

          Args:
              data: The batch handed as-is to ``do_step``.

          Returns:
              Whatever the model's ``do_step`` returns for ``data``.
          """
          # When ``self._dist`` is truthy the model is read from
          # ``self.model.module`` (presumably a distributed wrapper -- confirm
          # against the parent trainer).
          model = self.model.module if self._dist else self.model
          model.eval()
          result = model.do_step(data)
          return result