wmt14.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import tarfile
  15. import numpy as np
  16. from paddle.dataset.common import _check_exists_and_download
  17. from paddle.io import Dataset
  18. __all__ = []
  19. URL_DEV_TEST = (
  20. 'http://www-lium.univ-lemans.fr/~schwenk/'
  21. 'cslm_joint_paper/data/dev+test.tgz'
  22. )
  23. MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
  24. # this is a small set of data for test. The original data is too large and
  25. # will be add later.
  26. URL_TRAIN = 'http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz'
  27. MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
  28. START = "<s>"
  29. END = "<e>"
  30. UNK = "<unk>"
  31. UNK_IDX = 2
  32. class WMT14(Dataset):
  33. """
  34. Implementation of `WMT14 <http://www.statmt.org/wmt14/>`_ test dataset.
  35. The original WMT14 dataset is too large and a small set of data for set is
  36. provided. This module will download dataset from
  37. http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz .
  38. Args:
  39. data_file(str): path to data tar file, can be set None if
  40. :attr:`download` is True. Default None
  41. mode(str): 'train', 'test' or 'gen'. Default 'train'
  42. dict_size(int): word dictionary size. Default -1.
  43. download(bool): whether to download dataset automatically if
  44. :attr:`data_file` is not set. Default True
  45. Returns:
  46. Dataset: Instance of WMT14 dataset
  47. - src_ids (np.array) - The sequence of token ids of source language.
  48. - trg_ids (np.array) - The sequence of token ids of target language.
  49. - trg_ids_next (np.array) - The next sequence of token ids of target language.
  50. Examples:
  51. .. code-block:: python
  52. >>> import paddle
  53. >>> from paddle.text.datasets import WMT14
  54. >>> class SimpleNet(paddle.nn.Layer):
  55. ... def __init__(self):
  56. ... super().__init__()
  57. ...
  58. ... def forward(self, src_ids, trg_ids, trg_ids_next):
  59. ... return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next)
  60. >>> wmt14 = WMT14(mode='train', dict_size=50)
  61. >>> for i in range(10):
  62. ... src_ids, trg_ids, trg_ids_next = wmt14[i]
  63. ... src_ids = paddle.to_tensor(src_ids)
  64. ... trg_ids = paddle.to_tensor(trg_ids)
  65. ... trg_ids_next = paddle.to_tensor(trg_ids_next)
  66. ...
  67. ... model = SimpleNet()
  68. ... src_ids, trg_ids, trg_ids_next = model(src_ids, trg_ids, trg_ids_next)
  69. ... print(src_ids.item(), trg_ids.item(), trg_ids_next.item())
  70. 91 38 39
  71. 123 81 82
  72. 556 229 230
  73. 182 26 27
  74. 447 242 243
  75. 116 110 111
  76. 403 288 289
  77. 258 221 222
  78. 136 34 35
  79. 281 136 137
  80. """
  81. def __init__(
  82. self, data_file=None, mode='train', dict_size=-1, download=True
  83. ):
  84. assert mode.lower() in [
  85. 'train',
  86. 'test',
  87. 'gen',
  88. ], f"mode should be 'train', 'test' or 'gen', but got {mode}"
  89. self.mode = mode.lower()
  90. self.data_file = data_file
  91. if self.data_file is None:
  92. assert (
  93. download
  94. ), "data_file is not set and downloading automatically is disabled"
  95. self.data_file = _check_exists_and_download(
  96. data_file, URL_TRAIN, MD5_TRAIN, 'wmt14', download
  97. )
  98. # read dataset into memory
  99. assert dict_size > 0, "dict_size should be set as positive number"
  100. self.dict_size = dict_size
  101. self._load_data()
  102. def _load_data(self):
  103. def __to_dict(fd, size):
  104. out_dict = {}
  105. for line_count, line in enumerate(fd):
  106. if line_count < size:
  107. out_dict[line.strip().decode()] = line_count
  108. else:
  109. break
  110. return out_dict
  111. self.src_ids = []
  112. self.trg_ids = []
  113. self.trg_ids_next = []
  114. with tarfile.open(self.data_file, mode='r') as f:
  115. names = [
  116. each_item.name
  117. for each_item in f
  118. if each_item.name.endswith("src.dict")
  119. ]
  120. assert len(names) == 1
  121. self.src_dict = __to_dict(f.extractfile(names[0]), self.dict_size)
  122. names = [
  123. each_item.name
  124. for each_item in f
  125. if each_item.name.endswith("trg.dict")
  126. ]
  127. assert len(names) == 1
  128. self.trg_dict = __to_dict(f.extractfile(names[0]), self.dict_size)
  129. file_name = f"{self.mode}/{self.mode}"
  130. names = [
  131. each_item.name
  132. for each_item in f
  133. if each_item.name.endswith(file_name)
  134. ]
  135. for name in names:
  136. for line in f.extractfile(name):
  137. line = line.decode()
  138. line_split = line.strip().split('\t')
  139. if len(line_split) != 2:
  140. continue
  141. src_seq = line_split[0] # one source sequence
  142. src_words = src_seq.split()
  143. src_ids = [
  144. self.src_dict.get(w, UNK_IDX)
  145. for w in [START] + src_words + [END]
  146. ]
  147. trg_seq = line_split[1] # one target sequence
  148. trg_words = trg_seq.split()
  149. trg_ids = [self.trg_dict.get(w, UNK_IDX) for w in trg_words]
  150. # remove sequence whose length > 80 in training mode
  151. if len(src_ids) > 80 or len(trg_ids) > 80:
  152. continue
  153. trg_ids_next = trg_ids + [self.trg_dict[END]]
  154. trg_ids = [self.trg_dict[START]] + trg_ids
  155. self.src_ids.append(src_ids)
  156. self.trg_ids.append(trg_ids)
  157. self.trg_ids_next.append(trg_ids_next)
  158. def __getitem__(self, idx):
  159. return (
  160. np.array(self.src_ids[idx]),
  161. np.array(self.trg_ids[idx]),
  162. np.array(self.trg_ids_next[idx]),
  163. )
  164. def __len__(self):
  165. return len(self.src_ids)
  166. def get_dict(self, reverse=False):
  167. """
  168. Get the source and target dictionary.
  169. Args:
  170. reverse (bool): wether to reverse key and value in dictionary,
  171. i.e. key: value to value: key.
  172. Returns:
  173. Two dictionaries, the source and target dictionary.
  174. Examples:
  175. .. code-block:: python
  176. >>> from paddle.text.datasets import WMT14
  177. >>> wmt14 = WMT14(mode='train', dict_size=50)
  178. >>> src_dict, trg_dict = wmt14.get_dict()
  179. """
  180. src_dict, trg_dict = self.src_dict, self.trg_dict
  181. if reverse:
  182. src_dict = {v: k for k, v in src_dict.items()}
  183. trg_dict = {v: k for k, v in trg_dict.items()}
  184. return src_dict, trg_dict