| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- import random
- from pathlib import Path
- from typing import Any, Dict
- import librosa
- import soundfile as sf
- import torch
- from fairseq.data.audio.feature_transforms import \
- CompositeAudioFeatureTransform
- from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
- from modelscope.utils.chinese_utils import pre_chinese
- from modelscope.utils.constant import ModeKeys
- from .base import OfaBasePreprocessor
- from .utils.text2phone import Text2Phone
class OfaASRPreprocessor(OfaBasePreprocessor):
    """OFA preprocessor that turns audio records into ASR model inputs.

    In ``ModeKeys.TRAIN`` mode a record is converted into filter-bank
    features plus text and phone targets; in every other mode only the
    filter-bank features (and placeholder phone inputs) are produced.
    """

    def __init__(self,
                 cfg,
                 model_dir,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path; ``fbank_config.yaml``,
                ``text2phone_dict.txt`` and ``phone_dict.txt`` are read
                from this directory.
            mode: preprocessor mode (model mode)
        """
        super(OfaASRPreprocessor, self).__init__(cfg, model_dir, mode, *args,
                                                 **kwargs)
        # Initialize transform
        self.data_cfg = S2TDataConfig(
            Path(os.path.join(model_dir, 'fbank_config.yaml')))
        self.train_audio_feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
            self.data_cfg.get_feature_transforms('train', True))
        self.test_audio_feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
            self.data_cfg.get_feature_transforms('test', False))
        self.text2phone_tokenizer = Text2Phone(
            os.path.join(model_dir, 'text2phone_dict.txt'))
        self.phone_to_id, self.id_to_phone = self.build_phone_dict(
            os.path.join(model_dir, 'phone_dict.txt'))

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Build a sample from ``data`` according to the preprocessor mode.

        Args:
            data: Input record; the audio entry is looked up through
                ``self.column_map['wav']``.

        Returns:
            The training sample when ``self.mode`` is ``ModeKeys.TRAIN``,
            otherwise the inference sample.
        """
        if self.mode == ModeKeys.TRAIN:
            return self._build_train_sample(data)
        return self._build_infer_sample(data)

    def _extract_fbank(self, data: Dict[str, Any], speed: float,
                       is_train: bool):
        """Load the record's audio at 16 kHz mono and compute its fbank."""
        audio_bytes = self.get_audio_bytes(data[self.column_map['wav']])
        wav, sr = librosa.load(audio_bytes, sr=16000, mono=True)
        # Convert the array directly and add the leading axis afterwards:
        # wrapping it in a Python list first (``torch.tensor([wav])``) gives
        # the same (1, num_samples) tensor but goes through torch's slow
        # list-of-ndarray path.
        waveform = torch.tensor(wav, dtype=torch.float32).unsqueeze(0)
        return self.prepare_fbank(
            waveform, sr, speed, target_sample_rate=16000, is_train=is_train)

    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Build a training sample with audio features and targets.

        Args:
            data: Input record holding the audio under
                ``self.column_map['wav']`` and the transcript under
                ``self.column_map['text']``.

        Returns:
            Dict with the keys ``fbank``, ``fbank_mask``, ``label``,
            ``target``, ``phone_item``, ``phone_target``, ``phone_mask``
            and ``prev_output_tokens``.
        """
        # Random speed perturbation, applied inside ``prepare_fbank``.
        speed = random.choice([0.9, 1.0, 1.1])
        fbank = self._extract_fbank(data, speed, is_train=True)
        sample = {
            'fbank': fbank,
            'fbank_mask': torch.tensor([True]),
            'label': data[self.column_map['text']]
        }

        target = sample['label']
        if self.language == 'zh':
            target = pre_chinese(target, self.max_tgt_length)
        else:
            # Apply the translation table, then keep at most
            # ``max_tgt_length`` whitespace-separated tokens.
            target = target.translate(self.transtab).strip()
            target = ' '.join(target.split()[:self.max_tgt_length])
        sample['target'] = self.tokenize_text(target, add_bos=False)

        # NOTE(review): the +1 / +3 offsets presumably skip ids reserved for
        # special symbols in the phone vocabularies -- confirm against the
        # model definition.
        phone_item = self.to_phone(target) + 1
        sample['phone_item'] = phone_item + 3
        sample['phone_target'] = phone_item
        sample['phone_mask'] = torch.tensor([False])
        # Decoder input: BOS followed by the target without its last token.
        sample['prev_output_tokens'] = torch.cat(
            [self.bos_item, sample['target'][:-1]])
        return sample

    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Build an inference sample (no speed perturbation, no targets).

        Args:
            data: Input record holding the audio under
                ``self.column_map['wav']``; the transcript is optional.

        Returns:
            Dict with ``fbank``, ``fbank_mask``, placeholder ``phone_item``
            and ``phone_mask``, plus ``label`` when the record carries the
            mapped text column.
        """
        fbank = self._extract_fbank(data, 1.0, is_train=False)
        sample = {'fbank': fbank, 'fbank_mask': torch.tensor([True])}
        if 'text' in self.column_map and self.column_map['text'] in data:
            sample['label'] = data[self.column_map['text']]
        # mock: fixed placeholder phone inputs, no real phones at inference
        sample['phone_item'] = torch.tensor([6, 6, 6])
        sample['phone_mask'] = torch.tensor([False])
        return sample

    def to_phone(self, text):
        """Convert ``text`` to a 1-D tensor of phone ids.

        The text is converted with ``self.text2phone_tokenizer.trans`` and
        the result is split on single spaces; each piece is looked up in
        ``self.phone_to_id``.

        Raises:
            KeyError: if a produced phone is missing from ``phone_to_id``.
        """
        phones = self.text2phone_tokenizer.trans(text)
        ids = torch.tensor([self.phone_to_id[x] for x in phones.split(' ')])
        return ids

    def build_phone_dict(self, phone_dict_path):
        """Read the phone vocabulary file and build both lookup tables.

        The first space-separated field of every line is taken as the phone
        and the zero-based line index as its id.

        Args:
            phone_dict_path (str): path of the phone dictionary file.

        Returns:
            tuple: ``(phone_to_id, id_to_phone)`` where ``phone_to_id`` maps
            phone -> id and ``id_to_phone`` maps id -> phone.
        """
        phone_to_id = dict()
        id_to_phone = dict()
        # Explicit encoding so decoding does not depend on the host locale.
        with open(phone_dict_path, 'r', encoding='utf-8') as phone_dict_file:
            for i, line in enumerate(phone_dict_file):
                phone = line.strip().split(' ')[0]
                phone_to_id[phone] = i
                # Store the phone itself; the reverse table previously
                # received the whole ``phone_to_id`` dict for every id.
                id_to_phone[i] = phone
        return phone_to_id, id_to_phone
|