common.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import time
  3. from collections.abc import Sequence
  4. from typing import Mapping
  5. import numpy as np
  6. import torch
  7. from modelscope.utils.registry import default_group
  8. from .builder import PREPROCESSORS, build_preprocessor
  9. @PREPROCESSORS.register_module()
  10. class Compose(object):
  11. """Compose a data pipeline with a sequence of transforms.
  12. Args:
  13. transforms (list[dict | callable]):
  14. Either config dicts of transforms or transform objects.
  15. profiling (bool, optional): If set True, will profile and
  16. print preprocess time for each step.
  17. """
  18. def __init__(self, transforms, field_name=None, profiling=False):
  19. assert isinstance(transforms, Sequence)
  20. self.profiling = profiling
  21. self.transforms = []
  22. self.field_name = field_name
  23. for transform in transforms:
  24. if isinstance(transform, dict):
  25. if self.field_name is None:
  26. transform = build_preprocessor(transform, default_group)
  27. else:
  28. # if not found key in field_name, try field_name=None(default_group)
  29. try:
  30. transform = build_preprocessor(transform, field_name)
  31. except KeyError:
  32. transform = build_preprocessor(transform,
  33. default_group)
  34. elif callable(transform):
  35. pass
  36. else:
  37. raise TypeError('transform must be callable or a dict, but got'
  38. f' {type(transform)}')
  39. self.transforms.append(transform)
  40. def __call__(self, data):
  41. for t in self.transforms:
  42. if self.profiling:
  43. start = time.time()
  44. data = t(data)
  45. if self.profiling:
  46. print(f'{t} time {time.time()-start}')
  47. if data is None:
  48. return None
  49. return data
  50. def __repr__(self):
  51. format_string = self.__class__.__name__ + '('
  52. for t in self.transforms:
  53. format_string += f'\n {t}'
  54. format_string += '\n)'
  55. return format_string
  56. def to_tensor(data):
  57. """Convert objects of various python types to :obj:`torch.Tensor`.
  58. Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
  59. :class:`Sequence`, :class:`int` and :class:`float`.
  60. Args:
  61. data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
  62. be converted.
  63. """
  64. if isinstance(data, torch.Tensor):
  65. return data
  66. elif isinstance(data, np.ndarray):
  67. return torch.from_numpy(data)
  68. elif isinstance(data, Sequence) and not isinstance(data, str):
  69. return torch.tensor(data)
  70. elif isinstance(data, int):
  71. return torch.LongTensor([data])
  72. elif isinstance(data, float):
  73. return torch.FloatTensor([data])
  74. else:
  75. raise TypeError(f'type {type(data)} cannot be converted to tensor.')
  76. @PREPROCESSORS.register_module()
  77. class ToTensor(object):
  78. """Convert target object to tensor.
  79. Args:
  80. keys (Sequence[str]): Key of data to be converted to Tensor.
  81. Only valid when data is type of `Mapping`. If `keys` is None,
  82. all values of keys ​​will be converted to tensor by default.
  83. """
  84. def __init__(self, keys=None):
  85. self.keys = keys
  86. def __call__(self, data):
  87. if isinstance(data, Mapping):
  88. if self.keys is None:
  89. self.keys = list(data.keys())
  90. for key in self.keys:
  91. if key in data:
  92. data[key] = to_tensor(data[key])
  93. else:
  94. data = to_tensor(data)
  95. return data
  96. def __repr__(self):
  97. return self.__class__.__name__ + f'(keys={self.keys})'
  98. @PREPROCESSORS.register_module()
  99. class Filter(object):
  100. """This is usually the last stage of the dataloader transform.
  101. Only data of reserved keys will be kept and passed directly to the model, others will be removed.
  102. Args:
  103. keys (Sequence[str]): Keys of data to be reserved, others will be removed.
  104. """
  105. def __init__(self, reserved_keys):
  106. self.reserved_keys = reserved_keys
  107. def __call__(self, data):
  108. assert isinstance(data, Mapping)
  109. reserved_data = {}
  110. for key in self.reserved_keys:
  111. if key in data:
  112. reserved_data[key] = data[key]
  113. return reserved_data
  114. def __repr__(self):
  115. return self.__class__.__name__ + f'(keys={self.reserved_keys})'
  116. def to_numpy(data):
  117. """Convert objects of various python types to `numpy.ndarray`.
  118. Args:
  119. data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
  120. be converted.
  121. """
  122. if isinstance(data, torch.Tensor):
  123. return data.numpy()
  124. elif isinstance(data, np.ndarray):
  125. return data
  126. elif isinstance(data, Sequence) and not isinstance(data, str):
  127. return np.asarray(data)
  128. elif isinstance(data, int):
  129. return np.asarray(data, dtype=np.int64)
  130. elif isinstance(data, float):
  131. return np.asarray(data, dtype=np.float64)
  132. else:
  133. raise TypeError(f'type {type(data)} cannot be converted to tensor.')
  134. @PREPROCESSORS.register_module()
  135. class ToNumpy(object):
  136. """Convert target object to numpy.ndarray.
  137. Args:
  138. keys (Sequence[str]): Key of data to be converted to numpy.ndarray.
  139. Only valid when data is type of `Mapping`. If `keys` is None,
  140. all values of keys ​​will be converted to numpy.ndarray by default.
  141. """
  142. def __init__(self, keys=None):
  143. self.keys = keys
  144. def __call__(self, data):
  145. if isinstance(data, Mapping):
  146. if self.keys is None:
  147. self.keys = list(data.keys())
  148. for key in self.keys:
  149. if key in data:
  150. data[key] = to_numpy(data[key])
  151. else:
  152. data = to_numpy(data)
  153. return data
  154. def __repr__(self):
  155. return self.__class__.__name__ + f'(keys={self.keys})'
  156. @PREPROCESSORS.register_module()
  157. class Rename(object):
  158. """Change the name of the input keys to output keys, respectively.
  159. """
  160. def __init__(self, input_keys=[], output_keys=[]):
  161. self.input_keys = input_keys
  162. self.output_keys = output_keys
  163. def __call__(self, data):
  164. if isinstance(data, Mapping):
  165. for in_key, out_key in zip(self.input_keys, self.output_keys):
  166. if in_key in data and out_key not in data:
  167. data[out_key] = data[in_key]
  168. data.pop(in_key)
  169. return data
  170. def __repr__(self):
  171. return self.__class__.__name__ + f'(keys={self.keys})'
  172. @PREPROCESSORS.register_module()
  173. class Identity(object):
  174. def __init__(self):
  175. pass
  176. def __call__(self, item):
  177. return item