wmt16.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import tarfile
  16. from collections import defaultdict
  17. import numpy as np
  18. import paddle
  19. from paddle.dataset.common import _check_exists_and_download
  20. from paddle.io import Dataset
  21. __all__ = []
  22. DATA_URL = "http://paddlemodels.bj.bcebos.com/wmt/wmt16.tar.gz"
  23. DATA_MD5 = "0c38be43600334966403524a40dcd81e"
  24. TOTAL_EN_WORDS = 11250
  25. TOTAL_DE_WORDS = 19220
  26. START_MARK = "<s>"
  27. END_MARK = "<e>"
  28. UNK_MARK = "<unk>"
  29. class WMT16(Dataset):
  30. """
  31. Implementation of `WMT16 <http://www.statmt.org/wmt16/>`_ test dataset.
  32. ACL2016 Multimodal Machine Translation. Please see this website for more
  33. details: http://www.statmt.org/wmt16/multimodal-task.html#task1
  34. If you use the dataset created for your task, please cite the following paper:
  35. Multi30K: Multilingual English-German Image Descriptions.
  36. .. code-block:: text
  37. @article{elliott-EtAl:2016:VL16,
  38. author = {{Elliott}, D. and {Frank}, S. and {Sima"an}, K. and {Specia}, L.},
  39. title = {Multi30K: Multilingual English-German Image Descriptions},
  40. booktitle = {Proceedings of the 6th Workshop on Vision and Language},
  41. year = {2016},
  42. pages = {70--74},
  43. year = 2016
  44. }
  45. Args:
  46. data_file(str): path to data tar file, can be set None if
  47. :attr:`download` is True. Default None.
  48. mode(str): 'train', 'test' or 'val'. Default 'train'.
  49. src_dict_size(int): word dictionary size for source language word. Default -1.
  50. trg_dict_size(int): word dictionary size for target language word. Default -1.
  51. lang(str): source language, 'en' or 'de'. Default 'en'.
  52. download(bool): whether to download dataset automatically if
  53. :attr:`data_file` is not set. Default True.
  54. Returns:
  55. Dataset: Instance of WMT16 dataset. The instance of dataset has 3 fields:
  56. - src_ids (np.array) - The sequence of token ids of source language.
  57. - trg_ids (np.array) - The sequence of token ids of target language.
  58. - trg_ids_next (np.array) - The next sequence of token ids of target language.
  59. Examples:
  60. .. code-block:: python
  61. >>> import paddle
  62. >>> from paddle.text.datasets import WMT16
  63. >>> class SimpleNet(paddle.nn.Layer):
  64. ... def __init__(self):
  65. ... super().__init__()
  66. ...
  67. ... def forward(self, src_ids, trg_ids, trg_ids_next):
  68. ... return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next)
  69. >>> wmt16 = WMT16(mode='train', src_dict_size=50, trg_dict_size=50)
  70. >>> for i in range(10):
  71. ... src_ids, trg_ids, trg_ids_next = wmt16[i]
  72. ... src_ids = paddle.to_tensor(src_ids)
  73. ... trg_ids = paddle.to_tensor(trg_ids)
  74. ... trg_ids_next = paddle.to_tensor(trg_ids_next)
  75. ...
  76. ... model = SimpleNet()
  77. ... src_ids, trg_ids, trg_ids_next = model(src_ids, trg_ids, trg_ids_next)
  78. ... print(src_ids.item(), trg_ids.item(), trg_ids_next.item())
  79. 89 32 33
  80. 79 18 19
  81. 55 26 27
  82. 147 36 37
  83. 106 22 23
  84. 135 50 51
  85. 54 43 44
  86. 217 30 31
  87. 146 51 52
  88. 55 24 25
  89. """
  90. def __init__(
  91. self,
  92. data_file=None,
  93. mode='train',
  94. src_dict_size=-1,
  95. trg_dict_size=-1,
  96. lang='en',
  97. download=True,
  98. ):
  99. assert mode.lower() in [
  100. 'train',
  101. 'test',
  102. 'val',
  103. ], f"mode should be 'train', 'test' or 'val', but got {mode}"
  104. self.mode = mode.lower()
  105. self.data_file = data_file
  106. if self.data_file is None:
  107. assert (
  108. download
  109. ), "data_file is not set and downloading automatically is disabled"
  110. self.data_file = _check_exists_and_download(
  111. data_file, DATA_URL, DATA_MD5, 'wmt16', download
  112. )
  113. self.lang = lang
  114. assert src_dict_size > 0, "dict_size should be set as positive number"
  115. assert trg_dict_size > 0, "dict_size should be set as positive number"
  116. self.src_dict_size = min(
  117. src_dict_size, (TOTAL_EN_WORDS if lang == "en" else TOTAL_DE_WORDS)
  118. )
  119. self.trg_dict_size = min(
  120. trg_dict_size, (TOTAL_DE_WORDS if lang == "en" else TOTAL_EN_WORDS)
  121. )
  122. # load source and target word dict
  123. self.src_dict = self._load_dict(lang, src_dict_size)
  124. self.trg_dict = self._load_dict(
  125. "de" if lang == "en" else "en", trg_dict_size
  126. )
  127. # load data
  128. self.data = self._load_data()
  129. def _load_dict(self, lang, dict_size, reverse=False):
  130. dict_path = os.path.join(
  131. paddle.dataset.common.DATA_HOME,
  132. "wmt16/%s_%d.dict" % (lang, dict_size),
  133. )
  134. dict_found = False
  135. if os.path.exists(dict_path):
  136. with open(dict_path, "rb") as d:
  137. dict_found = len(d.readlines()) == dict_size
  138. if not dict_found:
  139. self._build_dict(dict_path, dict_size, lang)
  140. word_dict = {}
  141. with open(dict_path, "rb") as fdict:
  142. for idx, line in enumerate(fdict):
  143. if reverse:
  144. word_dict[idx] = line.strip().decode()
  145. else:
  146. word_dict[line.strip().decode()] = idx
  147. return word_dict
  148. def _build_dict(self, dict_path, dict_size, lang):
  149. word_dict = defaultdict(int)
  150. with tarfile.open(self.data_file, mode="r") as f:
  151. for line in f.extractfile("wmt16/train"):
  152. line = line.decode()
  153. line_split = line.strip().split("\t")
  154. if len(line_split) != 2:
  155. continue
  156. sen = line_split[0] if self.lang == "en" else line_split[1]
  157. for w in sen.split():
  158. word_dict[w] += 1
  159. with open(dict_path, "wb") as fout:
  160. fout.write((f"{START_MARK}\n{END_MARK}\n{UNK_MARK}\n").encode())
  161. for idx, word in enumerate(
  162. sorted(word_dict.items(), key=lambda x: x[1], reverse=True)
  163. ):
  164. if idx + 3 == dict_size:
  165. break
  166. fout.write(word[0].encode())
  167. fout.write(b'\n')
  168. def _load_data(self):
  169. # the index for start mark, end mark, and unk are the same in source
  170. # language and target language. Here uses the source language
  171. # dictionary to determine their indices.
  172. start_id = self.src_dict[START_MARK]
  173. end_id = self.src_dict[END_MARK]
  174. unk_id = self.src_dict[UNK_MARK]
  175. src_col = 0 if self.lang == "en" else 1
  176. trg_col = 1 - src_col
  177. self.src_ids = []
  178. self.trg_ids = []
  179. self.trg_ids_next = []
  180. with tarfile.open(self.data_file, mode="r") as f:
  181. for line in f.extractfile(f"wmt16/{self.mode}"):
  182. line = line.decode()
  183. line_split = line.strip().split("\t")
  184. if len(line_split) != 2:
  185. continue
  186. src_words = line_split[src_col].split()
  187. src_ids = (
  188. [start_id]
  189. + [self.src_dict.get(w, unk_id) for w in src_words]
  190. + [end_id]
  191. )
  192. trg_words = line_split[trg_col].split()
  193. trg_ids = [self.trg_dict.get(w, unk_id) for w in trg_words]
  194. trg_ids_next = trg_ids + [end_id]
  195. trg_ids = [start_id] + trg_ids
  196. self.src_ids.append(src_ids)
  197. self.trg_ids.append(trg_ids)
  198. self.trg_ids_next.append(trg_ids_next)
  199. def __getitem__(self, idx):
  200. return (
  201. np.array(self.src_ids[idx]),
  202. np.array(self.trg_ids[idx]),
  203. np.array(self.trg_ids_next[idx]),
  204. )
  205. def __len__(self):
  206. return len(self.src_ids)
  207. def get_dict(self, lang, reverse=False):
  208. """
  209. return the word dictionary for the specified language.
  210. Args:
  211. lang(string): A string indicating which language is the source
  212. language. Available options are: "en" for English
  213. and "de" for Germany.
  214. reverse(bool): If reverse is set to False, the returned python
  215. dictionary will use word as key and use index as value.
  216. If reverse is set to True, the returned python
  217. dictionary will use index as key and word as value.
  218. Returns:
  219. dict: The word dictionary for the specific language.
  220. Examples:
  221. .. code-block:: python
  222. >>> from paddle.text.datasets import WMT16
  223. >>> wmt16 = WMT16(mode='train', src_dict_size=50, trg_dict_size=50)
  224. >>> en_dict = wmt16.get_dict('en')
  225. """
  226. dict_size = (
  227. self.src_dict_size if lang == self.lang else self.trg_dict_size
  228. )
  229. dict_path = os.path.join(
  230. paddle.dataset.common.DATA_HOME,
  231. "wmt16/%s_%d.dict" % (lang, dict_size),
  232. )
  233. assert os.path.exists(dict_path), "Word dictionary does not exist. "
  234. "Please invoke paddle.dataset.wmt16.train/test/validation first "
  235. "to build the dictionary."
  236. return self._load_dict(lang, dict_size)