text2phone.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from modelscope.utils.chinese_utils import normalize_chinese_number
  3. class TrieNode(object):
  4. def __init__(self):
  5. """
  6. Initialize your data structure here.
  7. """
  8. self.data = {}
  9. self.is_word = False
  10. class Trie(object):
  11. """
  12. trie-tree
  13. """
  14. def __init__(self):
  15. """
  16. Initialize your data structure here.
  17. """
  18. self.root = TrieNode()
  19. def insert(self, word):
  20. """
  21. Inserts a word into the trie.
  22. :type word: str
  23. :rtype: void
  24. """
  25. node = self.root
  26. for chars in word:
  27. child = node.data.get(chars)
  28. if not child:
  29. node.data[chars] = TrieNode()
  30. node = node.data[chars]
  31. node.is_word = True
  32. def search(self, word):
  33. """
  34. Returns if the word is in the trie.
  35. :type word: str
  36. :rtype: bool
  37. """
  38. node = self.root
  39. for chars in word:
  40. node = node.data.get(chars)
  41. if not node:
  42. return False
  43. return node.is_word
  44. def startsWith(self, prefix):
  45. """
  46. Returns if there is any word in the trie that starts with the given prefix.
  47. :type prefix: str
  48. :rtype: bool
  49. """
  50. node = self.root
  51. for chars in prefix:
  52. node = node.data.get(chars)
  53. if not node:
  54. return False
  55. return True
  56. def get_start(self, prefix):
  57. """
  58. Returns words started with prefix
  59. :param prefix:
  60. :return: words (list)
  61. """
  62. def get_key(pre, pre_node):
  63. word_list = []
  64. if pre_node.is_word:
  65. word_list.append(pre)
  66. for x in pre_node.data.keys():
  67. word_list.extend(get_key(pre + str(x), pre_node.data.get(x)))
  68. return word_list
  69. words = []
  70. if not self.startsWith(prefix):
  71. return words
  72. if self.search(prefix):
  73. words.append(prefix)
  74. return words
  75. node = self.root
  76. for chars in prefix:
  77. node = node.data.get(chars)
  78. return get_key(prefix, node)
  79. class TrieTokenizer(Trie):
  80. """
  81. word_split based on trie-tree
  82. """
  83. def __init__(self, dict_path):
  84. super(TrieTokenizer, self).__init__()
  85. self.dict_path = dict_path
  86. self.create_trie_tree()
  87. def load_dict(self):
  88. words = []
  89. with open(self.dict_path, mode='r', encoding='utf-8') as file:
  90. for line in file:
  91. words.append(line.strip().split('\t')[0].encode(
  92. 'utf-8').decode('utf-8-sig'))
  93. return words
  94. def create_trie_tree(self):
  95. words = self.load_dict()
  96. for word in words:
  97. self.insert(word)
  98. def mine_tree(self, tree, sentence, trace_index):
  99. if trace_index <= (len(sentence) - 1):
  100. if sentence[trace_index] in tree.data:
  101. trace_index = trace_index + 1
  102. trace_index = self.mine_tree(
  103. tree.data[sentence[trace_index - 1]], sentence,
  104. trace_index)
  105. return trace_index
  106. def tokenize(self, sentence):
  107. tokens = []
  108. sentence_len = len(sentence)
  109. while sentence_len != 0:
  110. trace_index = 0
  111. trace_index = self.mine_tree(self.root, sentence, trace_index)
  112. if trace_index == 0:
  113. tokens.append(sentence[0:1])
  114. sentence = sentence[1:len(sentence)]
  115. sentence_len = len(sentence)
  116. else:
  117. tokens.append(sentence[0:trace_index])
  118. sentence = sentence[trace_index:len(sentence)]
  119. sentence_len = len(sentence)
  120. return tokens
  121. def combine(self, token_list):
  122. flag = 0
  123. output = []
  124. temp = []
  125. for i in token_list:
  126. if len(i) != 1:
  127. if flag == 0:
  128. output.append(i[::])
  129. else:
  130. output.append(''.join(temp))
  131. output.append(i[::])
  132. temp = []
  133. flag = 0
  134. else:
  135. if flag == 0:
  136. temp.append(i)
  137. flag = 1
  138. else:
  139. temp.append(i)
  140. return output
  141. class Text2Phone:
  142. def __init__(self, phone_dict_path):
  143. self.trie_cws = TrieTokenizer(phone_dict_path)
  144. self.phone_map = self.get_phone_map(phone_dict_path)
  145. def get_phone_map(self, phone_dict_path):
  146. phone_map = dict()
  147. with open(phone_dict_path, 'r') as phone_map_file_reader:
  148. for line in phone_map_file_reader:
  149. key, phone_series = line.strip().split('\t')
  150. if key not in phone_map:
  151. phone_map[key] = phone_series
  152. return phone_map
  153. def trans(self, text):
  154. text = normalize_chinese_number(text)
  155. tokens = self.trie_cws.tokenize(text)
  156. phones = []
  157. for word in tokens:
  158. if word in self.phone_map:
  159. phones.append(self.phone_map[word])
  160. elif len(word) > 1:
  161. for char in word:
  162. if char in self.phone_map:
  163. phones.append(self.phone_map[char])
  164. return ' '.join(phones)