wmt14.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
WMT14 dataset.
The original WMT14 dataset is too large, so only a small subset of the data
is provided. This module downloads that subset from ``URL_TRAIN``
(http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz) and parses its training,
test and generation splits into paddle reader creators.
  20. """
  21. import tarfile
  22. import paddle.dataset.common
  23. from paddle.utils import deprecated
  24. __all__ = []
  25. URL_DEV_TEST = (
  26. 'http://www-lium.univ-lemans.fr/~schwenk/'
  27. 'cslm_joint_paper/data/dev+test.tgz'
  28. )
  29. MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
  30. # this is a small set of data for test. The original data is too large and
  31. # will be add later.
  32. URL_TRAIN = 'http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz'
  33. MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
  34. # BLEU of this trained model is 26.92
  35. URL_MODEL = 'http://paddlemodels.bj.bcebos.com/wmt%2Fwmt14.tgz'
  36. MD5_MODEL = '0cb4a5366189b6acba876491c8724fa3'
  37. START = "<s>"
  38. END = "<e>"
  39. UNK = "<unk>"
  40. UNK_IDX = 2
  41. def __read_to_dict(tar_file, dict_size):
  42. def __to_dict(fd, size):
  43. out_dict = {}
  44. for line_count, line in enumerate(fd):
  45. if line_count < size:
  46. out_dict[line.strip().decode()] = line_count
  47. else:
  48. break
  49. return out_dict
  50. with tarfile.open(tar_file, mode='r') as f:
  51. names = [
  52. each_item.name
  53. for each_item in f
  54. if each_item.name.endswith("src.dict")
  55. ]
  56. assert len(names) == 1
  57. src_dict = __to_dict(f.extractfile(names[0]), dict_size)
  58. names = [
  59. each_item.name
  60. for each_item in f
  61. if each_item.name.endswith("trg.dict")
  62. ]
  63. assert len(names) == 1
  64. trg_dict = __to_dict(f.extractfile(names[0]), dict_size)
  65. return src_dict, trg_dict
  66. def reader_creator(tar_file, file_name, dict_size):
  67. def reader():
  68. src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
  69. with tarfile.open(tar_file, mode='r') as f:
  70. names = [
  71. each_item.name
  72. for each_item in f
  73. if each_item.name.endswith(file_name)
  74. ]
  75. for name in names:
  76. for line in f.extractfile(name):
  77. line = line.decode()
  78. line_split = line.strip().split('\t')
  79. if len(line_split) != 2:
  80. continue
  81. src_seq = line_split[0] # one source sequence
  82. src_words = src_seq.split()
  83. src_ids = [
  84. src_dict.get(w, UNK_IDX)
  85. for w in [START] + src_words + [END]
  86. ]
  87. trg_seq = line_split[1] # one target sequence
  88. trg_words = trg_seq.split()
  89. trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words]
  90. # remove sequence whose length > 80 in training mode
  91. if len(src_ids) > 80 or len(trg_ids) > 80:
  92. continue
  93. trg_ids_next = trg_ids + [trg_dict[END]]
  94. trg_ids = [trg_dict[START]] + trg_ids
  95. yield src_ids, trg_ids, trg_ids_next
  96. return reader
  97. @deprecated(
  98. since="2.0.0",
  99. update_to="paddle.text.datasets.WMT14",
  100. level=1,
  101. reason="Please use new dataset API which supports paddle.io.DataLoader",
  102. )
  103. def train(dict_size):
  104. """
  105. WMT14 training set creator.
  106. It returns a reader creator, each sample in the reader is source language
  107. word ID sequence, target language word ID sequence and next word ID
  108. sequence.
  109. :return: Training reader creator
  110. :rtype: callable
  111. """
  112. return reader_creator(
  113. paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
  114. 'train/train',
  115. dict_size,
  116. )
  117. @deprecated(
  118. since="2.0.0",
  119. update_to="paddle.text.datasets.WMT14",
  120. level=1,
  121. reason="Please use new dataset API which supports paddle.io.DataLoader",
  122. )
  123. def test(dict_size):
  124. """
  125. WMT14 test set creator.
  126. It returns a reader creator, each sample in the reader is source language
  127. word ID sequence, target language word ID sequence and next word ID
  128. sequence.
  129. :return: Test reader creator
  130. :rtype: callable
  131. """
  132. return reader_creator(
  133. paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
  134. 'test/test',
  135. dict_size,
  136. )
  137. @deprecated(
  138. since="2.0.0",
  139. update_to="paddle.text.datasets.WMT14",
  140. level=1,
  141. reason="Please use new dataset API which supports paddle.io.DataLoader",
  142. )
  143. def gen(dict_size):
  144. return reader_creator(
  145. paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
  146. 'gen/gen',
  147. dict_size,
  148. )
  149. @deprecated(
  150. since="2.0.0",
  151. update_to="paddle.text.datasets.WMT14",
  152. level=1,
  153. reason="Please use new dataset API which supports paddle.io.DataLoader",
  154. )
  155. def get_dict(dict_size, reverse=True):
  156. # if reverse = False, return dict = {'a':'001', 'b':'002', ...}
  157. # else reverse = true, return dict = {'001':'a', '002':'b', ...}
  158. tar_file = paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
  159. src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
  160. if reverse:
  161. src_dict = {v: k for k, v in src_dict.items()}
  162. trg_dict = {v: k for k, v in trg_dict.items()}
  163. return src_dict, trg_dict
  164. @deprecated(
  165. since="2.0.0",
  166. update_to="paddle.text.datasets.WMT14",
  167. level=1,
  168. reason="Please use new dataset API which supports paddle.io.DataLoader",
  169. )
  170. def fetch():
  171. paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
  172. paddle.dataset.common.download(URL_MODEL, 'wmt14', MD5_MODEL)