utils.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from collections.abc import Mapping
  4. from typing import Any, Dict, List, Tuple, Union
  5. import json
  6. import numpy as np
  7. from transformers import AutoTokenizer
  8. from modelscope.metainfo import Models
  9. from modelscope.outputs import OutputKeys
  10. from modelscope.preprocessors.base import Preprocessor
  11. from modelscope.utils.constant import ModeKeys
  12. from modelscope.utils.hub import get_model_type, parse_label_mapping
  13. from modelscope.utils.logger import get_logger
# Module-level logger; labels_to_id uses it to report labels missing from the mapping.
logger = get_logger()
# Names exported by `from <this module> import *`.
__all__ = ['parse_text_and_label', 'labels_to_id']
  16. def parse_text_and_label(data,
  17. mode,
  18. first_sequence=None,
  19. second_sequence=None,
  20. label=None):
  21. """Parse the input and return the sentences and labels.
  22. When input type is tuple or list and its size is 2:
  23. If the pair param is False, data will be parsed as the first_sentence and the label,
  24. else it will be parsed as the first_sentence and the second_sentence.
  25. Args:
  26. data: The input data.
  27. mode: The mode of the preprocessor
  28. first_sequence: The key of the first sequence
  29. second_sequence: The key of the second sequence
  30. label: The key of the label
  31. Returns:
  32. The sentences and labels tuple.
  33. """
  34. text_a, text_b, labels = None, None, None
  35. if isinstance(data, str):
  36. text_a = data
  37. elif isinstance(data, tuple) or isinstance(data, list):
  38. if len(data) == 3:
  39. text_a, text_b, labels = data
  40. elif len(data) == 2:
  41. if mode == ModeKeys.INFERENCE:
  42. text_a, text_b = data
  43. else:
  44. text_a, labels = data
  45. elif isinstance(data, Mapping):
  46. text_a = data.get(first_sequence)
  47. text_b = data.get(second_sequence)
  48. if label is None or isinstance(label, str):
  49. labels = data.get(label)
  50. else:
  51. labels = [data.get(lb) for lb in label]
  52. return text_a, text_b, labels
  53. def labels_to_id(labels, output, label2id=None):
  54. """Turn the labels to id with the type int or float.
  55. If the original label's type is str or int, the label2id mapping will try to convert it to the final label.
  56. If the original label's type is float, or the label2id mapping does not exist,
  57. the original label will be returned.
  58. Args:
  59. label2id: An extra label2id mapping. If not provided, the label will not be translated to ids.
  60. labels: The input labels.
  61. output: The label id.
  62. Returns:
  63. The final labels.
  64. """
  65. def label_can_be_mapped(label):
  66. return isinstance(label, str) or isinstance(label, int)
  67. try:
  68. if isinstance(labels, (tuple, list)) and all([label_can_be_mapped(label) for label in labels]) \
  69. and label2id is not None:
  70. output[OutputKeys.LABELS] = [
  71. label2id[label] if label in label2id else label2id[str(label)]
  72. for label in labels
  73. ]
  74. elif label_can_be_mapped(labels) and label2id is not None:
  75. output[OutputKeys.LABELS] = label2id[
  76. labels] if labels in label2id else label2id[str(labels)]
  77. elif labels is not None:
  78. output[OutputKeys.LABELS] = labels
  79. except KeyError as e:
  80. logger.error(
  81. f'Label {labels} cannot be found in the label mapping {label2id},'
  82. f'which comes from the user input or the configuration files. '
  83. f'Please consider matching your labels with this mapping.')
  84. raise e