speech_tts_autolabel.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import argparse
  2. import os
  3. import sys
  4. import zipfile
  5. from modelscope.hub.check_model import check_local_model_is_latest
  6. from modelscope.hub.snapshot_download import snapshot_download
  7. from modelscope.utils.constant import ThirdParty
  8. from modelscope.utils.logger import get_logger
  9. try:
  10. from tts_autolabel import AutoLabeling
  11. except ImportError:
  12. raise ImportError('pls install tts-autolabel with \
  13. "pip install tts-autolabel -f \
  14. https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html"'
  15. )
  16. DEFAULT_RESOURCE_MODEL_ID = 'damo/speech_ptts_autolabel_16k'
  17. logger = get_logger()
# Suggested parameters, e.g.:
#   --para_ids all --resource_revision v1.0.2 --input_wav data/test/audios/autolabel
#   --work_dir ../ptts/test/diff2 --develop_mode 1 --stage 1 --process_num 2 --no_para --disable_enh
  21. def run_auto_label(input_wav,
  22. work_dir,
  23. para_ids='all',
  24. resource_model_id=DEFAULT_RESOURCE_MODEL_ID,
  25. resource_revision=None,
  26. gender='female',
  27. stage=1,
  28. process_num=4,
  29. develop_mode=0,
  30. has_para=False,
  31. enable_enh=False):
  32. if not os.path.exists(input_wav):
  33. raise ValueError(f'input_wav: {input_wav} not exists')
  34. if not os.path.exists(work_dir):
  35. raise ValueError(f'work_dir: {work_dir} not exists')
  36. def _download_and_unzip_resource(model, model_revision=None):
  37. if os.path.exists(model):
  38. model_cache_dir = model if os.path.isdir(
  39. model) else os.path.dirname(model)
  40. check_local_model_is_latest(
  41. model_cache_dir,
  42. user_agent={ThirdParty.KEY: 'speech_tts_autolabel'})
  43. else:
  44. model_cache_dir = snapshot_download(
  45. model,
  46. revision=model_revision,
  47. user_agent={ThirdParty.KEY: 'speech_tts_autolabel'})
  48. if not os.path.exists(model_cache_dir):
  49. raise ValueError(f'model_cache_dir: {model_cache_dir} not exists')
  50. zip_file = os.path.join(model_cache_dir, 'model.zip')
  51. if not os.path.exists(zip_file):
  52. raise ValueError(f'zip_file: {zip_file} not exists')
  53. z = zipfile.ZipFile(zip_file)
  54. z.extractall(model_cache_dir)
  55. target_resource = os.path.join(model_cache_dir, 'model')
  56. return target_resource
  57. model_resource = _download_and_unzip_resource(resource_model_id,
  58. resource_revision)
  59. auto_labeling = AutoLabeling(
  60. os.path.abspath(input_wav),
  61. model_resource,
  62. False,
  63. os.path.abspath(work_dir),
  64. gender,
  65. develop_mode,
  66. has_para,
  67. para_ids,
  68. stage,
  69. process_num,
  70. enable_enh=enable_enh)
  71. ret_code, report = auto_labeling.run()
  72. return ret_code, report
  73. if __name__ == '__main__':
  74. parser = argparse.ArgumentParser()
  75. parser.add_argument(
  76. '--para_ids',
  77. default='all',
  78. help=
  79. 'you can use this variable to config your auto labeling paragraph ids, \
  80. all means all in the dir, none means no paragraph 1 means 1 para only, \
  81. 1 2 means 1 and 2, transcipt/prosody/wav should be named exactly the same!!!'
  82. )
  83. parser.add_argument(
  84. '--resource', type=str, default=DEFAULT_RESOURCE_MODEL_ID)
  85. parser.add_argument(
  86. '--resource_revision',
  87. type=str,
  88. default=None,
  89. help='resource directory')
  90. parser.add_argument('--input_wav', help='personal user input wav dir')
  91. parser.add_argument('--work_dir', help='autolabel work dir')
  92. parser.add_argument(
  93. '--gender', default='female', help='personal user gender')
  94. parser.add_argument('--develop_mode', type=int, default=1)
  95. parser.add_argument(
  96. '--stage',
  97. type=int,
  98. default=1,
  99. help='auto labeling stage, 0 means qualification and 1 means labeling')
  100. parser.add_argument(
  101. '--process_num',
  102. type=int,
  103. default=4,
  104. help='kaldi bin parallel execution process number')
  105. parser.add_argument(
  106. '--has_para', dest='has_para', action='store_true', help='paragraph')
  107. parser.add_argument(
  108. '--no_para',
  109. dest='has_para',
  110. action='store_false',
  111. help='no paragraph')
  112. parser.add_argument(
  113. '--enable_enh',
  114. dest='enable_enh',
  115. action='store_true',
  116. help='enable audio enhancement')
  117. parser.add_argument(
  118. '--disable_enh',
  119. dest='enable_enh',
  120. action='store_false',
  121. help='disable audio enhancement')
  122. parser.set_defaults(has_para=True)
  123. parser.set_defaults(enable_enh=False)
  124. args = parser.parse_args()
  125. logger.info(args.enable_enh)
  126. ret_code, report = run_auto_label(args.input_wav, args.work_dir,
  127. args.para_ids, args.resource,
  128. args.resource_revision, args.gender,
  129. args.stage, args.process_num,
  130. args.develop_mode, args.has_para,
  131. args.enable_enh)
  132. logger.info(f'ret_code={ret_code}')
  133. logger.info(f'report={report}')