asr.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import random
  4. from pathlib import Path
  5. from typing import Any, Dict
  6. import librosa
  7. import soundfile as sf
  8. import torch
  9. from fairseq.data.audio.feature_transforms import \
  10. CompositeAudioFeatureTransform
  11. from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
  12. from modelscope.utils.chinese_utils import pre_chinese
  13. from modelscope.utils.constant import ModeKeys
  14. from .base import OfaBasePreprocessor
  15. from .utils.text2phone import Text2Phone
  16. class OfaASRPreprocessor(OfaBasePreprocessor):
  17. def __init__(self,
  18. cfg,
  19. model_dir,
  20. mode=ModeKeys.INFERENCE,
  21. *args,
  22. **kwargs):
  23. """preprocess the data
  24. Args:
  25. cfg(modelscope.utils.config.ConfigDict) : model config
  26. model_dir (str): model path,
  27. mode: preprocessor mode (model mode)
  28. """
  29. super(OfaASRPreprocessor, self).__init__(cfg, model_dir, mode, *args,
  30. **kwargs)
  31. # Initialize transform
  32. self.data_cfg = S2TDataConfig(
  33. Path(os.path.join(model_dir, 'fbank_config.yaml')))
  34. self.train_audio_feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
  35. self.data_cfg.get_feature_transforms('train', True))
  36. self.test_audio_feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
  37. self.data_cfg.get_feature_transforms('test', False))
  38. self.text2phone_tokenizer = Text2Phone(
  39. os.path.join(model_dir, 'text2phone_dict.txt'))
  40. self.phone_to_id, self.id_to_phone = self.build_phone_dict(
  41. os.path.join(model_dir, 'phone_dict.txt'))
  42. def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
  43. if self.mode == ModeKeys.TRAIN:
  44. return self._build_train_sample(data)
  45. else:
  46. return self._build_infer_sample(data)
  47. def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
  48. speed = random.choice([0.9, 1.0, 1.1])
  49. audio_bytes = self.get_audio_bytes(data[self.column_map['wav']])
  50. wav, sr = librosa.load(audio_bytes, sr=16000, mono=True)
  51. fbank = self.prepare_fbank(
  52. torch.tensor([wav], dtype=torch.float32),
  53. sr,
  54. speed,
  55. target_sample_rate=16000,
  56. is_train=True)
  57. fbank_mask = torch.tensor([True])
  58. sample = {
  59. 'fbank': fbank,
  60. 'fbank_mask': fbank_mask,
  61. 'label': data[self.column_map['text']]
  62. }
  63. target = sample['label']
  64. if self.language == 'zh':
  65. target = pre_chinese(target, self.max_tgt_length)
  66. sample['target'] = self.tokenize_text(target, add_bos=False)
  67. else:
  68. target = target.translate(self.transtab).strip()
  69. target_token_list = target.strip().split()
  70. target = ' '.join(target_token_list[:self.max_tgt_length])
  71. sample['target'] = self.tokenize_text(target, add_bos=False)
  72. phone_item = self.to_phone(target) + 1
  73. phone_mask = torch.tensor([False])
  74. sample['phone_item'] = phone_item + 3
  75. sample['phone_target'] = phone_item
  76. sample['phone_mask'] = phone_mask
  77. sample['prev_output_tokens'] = torch.cat(
  78. [self.bos_item, sample['target'][:-1]])
  79. return sample
  80. def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
  81. speed = 1.0
  82. audio_bytes = self.get_audio_bytes(data[self.column_map['wav']])
  83. wav, sr = librosa.load(audio_bytes, sr=16000, mono=True)
  84. fbank = self.prepare_fbank(
  85. torch.tensor([wav], dtype=torch.float32),
  86. sr,
  87. speed,
  88. target_sample_rate=16000,
  89. is_train=False)
  90. fbank_mask = torch.tensor([True])
  91. sample = {'fbank': fbank, 'fbank_mask': fbank_mask}
  92. if 'text' in self.column_map and self.column_map['text'] in data:
  93. sample['label'] = data[self.column_map['text']]
  94. # mock
  95. sample['phone_item'] = torch.tensor([6, 6, 6])
  96. sample['phone_mask'] = torch.tensor([False])
  97. return sample
  98. def to_phone(self, text):
  99. phones = self.text2phone_tokenizer.trans(text)
  100. ids = torch.tensor([self.phone_to_id[x] for x in phones.split(' ')])
  101. return ids
  102. def build_phone_dict(self, phone_dict_path):
  103. phone_to_id = dict()
  104. id_to_phone = dict()
  105. with open(phone_dict_path, 'r') as phone_dict_file:
  106. for i, line in enumerate(phone_dict_file):
  107. phone = line.strip().split(' ')[0]
  108. phone_to_id[phone] = i
  109. id_to_phone[i] = phone_to_id
  110. return phone_to_id, id_to_phone