language_recognition_pipeline.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import io
  3. import os
  4. from typing import Union
  5. import numpy as np
  6. import soundfile as sf
  7. import torch
  8. import torchaudio
  9. from modelscope.fileio import File
  10. from modelscope.metainfo import Pipelines
  11. from modelscope.outputs import OutputKeys
  12. from modelscope.pipelines.base import InputModel, Pipeline
  13. from modelscope.pipelines.builder import PIPELINES
  14. from modelscope.utils.constant import Tasks
  15. from modelscope.utils.logger import get_logger
  16. logger = get_logger()
  17. __all__ = ['LanguageRecognitionPipeline']
  18. @PIPELINES.register_module(
  19. Tasks.speech_language_recognition,
  20. module_name=Pipelines.speech_language_recognition)
  21. class LanguageRecognitionPipeline(Pipeline):
  22. """Language Recognition Inference Pipeline
  23. use `model` to create a Language Recognition pipeline.
  24. Args:
  25. model (LanguageRecognitionPipeline): A model instance, or a model local dir, or a model id in the model hub.
  26. kwargs (dict, `optional`):
  27. Extra kwargs passed into the pipeline's constructor.
  28. Example:
  29. >>> from modelscope.pipelines import pipeline
  30. >>> from modelscope.utils.constant import Tasks
  31. >>> p = pipeline(
  32. >>> task=Tasks.speech_language_recognition, model='damo/speech_campplus_lre_en-cn_16k')
  33. >>> print(p(audio_in))
  34. """
  35. def __init__(self, model: InputModel, **kwargs):
  36. """use `model` to create a Language Recognition pipeline for prediction
  37. Args:
  38. model (str): a valid official model id
  39. """
  40. super().__init__(model=model, **kwargs)
  41. self.model_config = self.model.model_config
  42. self.languages = self.model_config['languages']
  43. def __call__(self,
  44. in_audios: Union[str, list, np.ndarray],
  45. out_file: str = None):
  46. wavs = self.preprocess(in_audios)
  47. scores, results = self.forward(wavs)
  48. outputs = self.postprocess(results, scores, in_audios, out_file)
  49. return outputs
  50. def forward(self, inputs: list):
  51. scores = []
  52. results = []
  53. for x in inputs:
  54. score, result = self.model(x)
  55. scores.append(score.tolist())
  56. results.append(result.item())
  57. return scores, results
  58. def postprocess(self,
  59. inputs: list,
  60. scores: list,
  61. in_audios: Union[str, list, np.ndarray],
  62. out_file=None):
  63. if isinstance(in_audios, str):
  64. output = {
  65. OutputKeys.TEXT: self.languages[inputs[0]],
  66. OutputKeys.SCORE: scores
  67. }
  68. else:
  69. output = {
  70. OutputKeys.TEXT: [self.languages[i] for i in inputs],
  71. OutputKeys.SCORE: scores
  72. }
  73. if out_file is not None:
  74. out_lines = []
  75. for i, audio in enumerate(in_audios):
  76. if isinstance(audio, str):
  77. audio_id = os.path.basename(audio).rsplit('.', 1)[0]
  78. else:
  79. audio_id = i
  80. out_lines.append('%s %s\n' %
  81. (audio_id, self.languages[inputs[i]]))
  82. with open(out_file, 'w') as f:
  83. for i in out_lines:
  84. f.write(i)
  85. return output
  86. def preprocess(self, inputs: Union[str, list, np.ndarray]):
  87. output = []
  88. if isinstance(inputs, str):
  89. file_bytes = File.read(inputs)
  90. data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
  91. if len(data.shape) == 2:
  92. data = data[:, 0]
  93. data = torch.from_numpy(data).unsqueeze(0)
  94. if fs != self.model_config['sample_rate']:
  95. logger.warning(
  96. 'The sample rate of audio is not %d, resample it.'
  97. % self.model_config['sample_rate'])
  98. data, fs = torchaudio.sox_effects.apply_effects_tensor(
  99. data,
  100. fs,
  101. effects=[['rate',
  102. str(self.model_config['sample_rate'])]])
  103. data = data.squeeze(0)
  104. output.append(data)
  105. else:
  106. for i in range(len(inputs)):
  107. if isinstance(inputs[i], str):
  108. file_bytes = File.read(inputs[i])
  109. data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
  110. if len(data.shape) == 2:
  111. data = data[:, 0]
  112. data = torch.from_numpy(data).unsqueeze(0)
  113. if fs != self.model_config['sample_rate']:
  114. logger.warning(
  115. 'The sample rate of audio is not %d, resample it.'
  116. % self.model_config['sample_rate'])
  117. data, fs = torchaudio.sox_effects.apply_effects_tensor(
  118. data,
  119. fs,
  120. effects=[[
  121. 'rate',
  122. str(self.model_config['sample_rate'])
  123. ]])
  124. data = data.squeeze(0)
  125. elif isinstance(inputs[i], np.ndarray):
  126. assert len(
  127. inputs[i].shape
  128. ) == 1, 'modelscope error: Input array should be [N, T]'
  129. data = inputs[i]
  130. if data.dtype in ['int16', 'int32', 'int64']:
  131. data = (data / (1 << 15)).astype('float32')
  132. else:
  133. data = data.astype('float32')
  134. data = torch.from_numpy(data)
  135. else:
  136. raise ValueError(
  137. 'modelscope error: The input type is restricted to audio address and nump array.'
  138. )
  139. output.append(data)
  140. return output