data_collators.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. # Part of the implementation is borrowed from huggingface/transformers.
  3. from collections import OrderedDict
  4. from collections.abc import Mapping
  5. from typing import Any, List, Optional, Tuple
  6. from .logger import get_logger
# Module-level logger, used by RemoveColumnsCollator to report ignored columns.
logger = get_logger()
  8. class RemoveColumnsCollator:
  9. """Remove specified columns from the input mini-batch, and convert them to attributes.
  10. For example: if columns_to_remove = ['id'], then user should call batch.id instead of batch['id'].
  11. Args:
  12. data_collator: An inner data collator to collate the mini-batch
  13. columns_to_remove(`List[str]`): The redundant columns to be removed from the mini-batch
  14. model_name(`Optional[str]`): An optional model name to print into log
  15. description(`Optional[str]`): An optional description to print into log
  16. """
  17. def __init__(
  18. self,
  19. data_collator,
  20. columns_to_remove: List[str],
  21. model_name: Optional[str] = None,
  22. description: Optional[str] = None,
  23. ):
  24. self.data_collator = data_collator
  25. self.columns_to_remove = columns_to_remove
  26. self.description = description
  27. self.model_name = model_name
  28. self.message_logged = False
  29. def _remove_columns(self, feature: Mapping) -> Tuple[Mapping, Any]:
  30. if not isinstance(feature, Mapping):
  31. return feature, None
  32. if not self.message_logged and self.model_name:
  33. ignored_columns = list(
  34. set(feature.keys()) - set(self.columns_to_remove))
  35. if len(ignored_columns) > 0:
  36. dset_description = '' if self.description is None else f'in the {self.description} set'
  37. logger.info(
  38. f"The following columns {dset_description} don't have a corresponding argument in "
  39. f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}."
  40. f"Legal columns: {', '.join(self.columns_to_remove)}."
  41. f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, "
  42. ' you can safely ignore this message.')
  43. self.message_logged = True
  44. feature_clean = {
  45. k: v
  46. for k, v in feature.items() if k in self.columns_to_remove
  47. }
  48. feature_unused = {
  49. k: v
  50. for k, v in feature.items() if k not in self.columns_to_remove
  51. }
  52. return feature_clean, feature_unused
  53. def __call__(self, features: List[Mapping]):
  54. features_clean = []
  55. features_unused = []
  56. for feature in features:
  57. feature, feature_unused = self._remove_columns(feature)
  58. features_clean.append(feature)
  59. features_unused.append(feature_unused)
  60. data = OrderedDict(self.data_collator(features_clean))
  61. if features_unused[0] is not None:
  62. for key in features_unused[0].keys():
  63. setattr(data, key, [
  64. feature_unused[key] for feature_unused in features_unused
  65. ])
  66. return data