| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371 |
- # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- ACL2016 Multimodal Machine Translation. Please see this website for more
- details: http://www.statmt.org/wmt16/multimodal-task.html#task1
- If you use the dataset created for your task, please cite the following paper:
- Multi30K: Multilingual English-German Image Descriptions.
- @inproceedings{elliott-EtAl:2016:VL16,
- author = {{Elliott}, D. and {Frank}, S. and {Sima'an}, K. and {Specia}, L.},
- title = {Multi30K: Multilingual English-German Image Descriptions},
- booktitle = {Proceedings of the 6th Workshop on Vision and Language},
- year = {2016},
- pages = {70--74}
- }
- """
- import os
- import tarfile
- from collections import defaultdict
- import paddle
- from paddle.utils import deprecated
# Nothing is exported through ``from <module> import *``.
__all__ = []

# Location and MD5 checksum of the dataset archive; both are handed to
# paddle.dataset.common.download by the public reader functions.
DATA_URL = "http://paddlemodels.bj.bcebos.com/wmt/wmt16.tar.gz"
DATA_MD5 = "0c38be43600334966403524a40dcd81e"

# Upper bounds applied (via min) to the English / German dictionary sizes
# requested by callers.
TOTAL_EN_WORDS = 11250
TOTAL_DE_WORDS = 19220

# Special tokens; they are written as the first three lines of every
# dictionary file, in this order.
START_MARK = "<s>"
END_MARK = "<e>"
UNK_MARK = "<unk>"
def __build_dict(tar_file, dict_size, save_path, lang):
    """Build a frequency-ranked vocabulary file from the training split.

    Reads ``wmt16/train`` from the archive, counts whitespace-separated
    tokens of the requested language, and writes one word per line to
    ``save_path``: the three special marks first, then words in descending
    frequency order.

    Args:
        tar_file (str): Path to the dataset tar archive.
        dict_size (int): Total number of lines to write, special marks
            included. Writing stops once ``dict_size`` lines are emitted.
        save_path (str): Destination file for the dictionary.
        lang (str): ``"en"`` selects the first tab-separated column; any
            other value selects the second column.
    """
    column = 0 if lang == "en" else 1
    frequencies = defaultdict(int)

    with tarfile.open(tar_file, mode="r") as archive:
        for raw_line in archive.extractfile("wmt16/train"):
            fields = raw_line.decode().strip().split("\t")
            # Only well-formed "source<TAB>target" lines are counted.
            if len(fields) == 2:
                for token in fields[column].split():
                    frequencies[token] += 1

    # Stable sort: words with equal counts keep first-seen order.
    ranked = sorted(
        frequencies.items(), key=lambda item: item[1], reverse=True
    )

    with open(save_path, "wb") as fout:
        fout.write(f"{START_MARK}\n{END_MARK}\n{UNK_MARK}\n".encode())
        for rank, (token, _count) in enumerate(ranked):
            # The three special marks already occupy three lines.
            if rank + 3 == dict_size:
                break
            fout.write(token.encode() + b"\n")
def __load_dict(tar_file, dict_size, lang, reverse=False):
    """Load a word dictionary from disk, building it first when necessary.

    The dictionary file lives under ``paddle.dataset.common.DATA_HOME`` at
    ``wmt16/<lang>_<dict_size>.dict``. It is (re)built with ``__build_dict``
    when it is missing or when its line count differs from ``dict_size``.

    Args:
        tar_file (str): Path to the dataset tar archive, used only when the
            dictionary has to be built.
        dict_size (int): Expected number of entries in the dictionary.
        lang (str): Language code used in the dictionary file name.
        reverse (bool): If False, map word -> index; if True, map
            index -> word.

    Returns:
        dict: The word dictionary, keyed as selected by ``reverse``.
    """
    dict_path = os.path.join(
        paddle.dataset.common.DATA_HOME, "wmt16/%s_%d.dict" % (lang, dict_size)
    )

    # Count lines inside a context manager so the handle is always closed,
    # and stream the file instead of materialising every line.
    needs_build = True
    if os.path.exists(dict_path):
        with open(dict_path, "rb") as fdict:
            needs_build = sum(1 for _ in fdict) != dict_size
    if needs_build:
        __build_dict(tar_file, dict_size, dict_path, lang)

    word_dict = {}
    with open(dict_path, "rb") as fdict:
        for idx, line in enumerate(fdict):
            word = line.strip().decode()
            if reverse:
                word_dict[idx] = word
            else:
                word_dict[word] = idx
    return word_dict
def __get_dict_size(src_dict_size, trg_dict_size, src_lang):
    """Clamp the requested dictionary sizes to the per-language maximum.

    Args:
        src_dict_size (int): Requested source dictionary size.
        trg_dict_size (int): Requested target dictionary size.
        src_lang (str): ``"en"`` means English is the source and German the
            target; any other value means the opposite.

    Returns:
        tuple: ``(src_dict_size, trg_dict_size)`` after clamping.
    """
    if src_lang == "en":
        src_limit, trg_limit = TOTAL_EN_WORDS, TOTAL_DE_WORDS
    else:
        src_limit, trg_limit = TOTAL_DE_WORDS, TOTAL_EN_WORDS
    return min(src_dict_size, src_limit), min(trg_dict_size, trg_limit)
def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang):
    """Create a lazy reader over one split of the dataset archive.

    Args:
        tar_file (str): Path to the dataset tar archive.
        file_name (str): Archive member to read (e.g. ``"wmt16/train"``).
        src_dict_size (int): Size of the source language dictionary.
        trg_dict_size (int): Size of the target language dictionary.
        src_lang (str): ``"en"`` if English is the source language;
            otherwise German is treated as the source.

    Returns:
        callable: A zero-argument generator function. Each yielded sample is
        ``(src_ids, trg_ids, trg_ids_next)``: the source ids wrapped in
        start/end marks, the target ids prefixed with the start mark, and
        the target ids suffixed with the end mark.
    """

    def reader():
        trg_lang = "de" if src_lang == "en" else "en"
        src_dict = __load_dict(tar_file, src_dict_size, src_lang)
        trg_dict = __load_dict(tar_file, trg_dict_size, trg_lang)

        # The special marks sit at the same indices in both dictionaries,
        # so the source dictionary is used to look them up.
        start_id, end_id, unk_id = (
            src_dict[mark] for mark in (START_MARK, END_MARK, UNK_MARK)
        )

        # Column 0 holds English, column 1 holds German.
        if src_lang == "en":
            src_col, trg_col = 0, 1
        else:
            src_col, trg_col = 1, 0

        with tarfile.open(tar_file, mode="r") as archive:
            for raw_line in archive.extractfile(file_name):
                columns = raw_line.decode().strip().split("\t")
                if len(columns) != 2:
                    continue

                src_ids = [start_id]
                src_ids.extend(
                    src_dict.get(tok, unk_id)
                    for tok in columns[src_col].split()
                )
                src_ids.append(end_id)

                trg_body = [
                    trg_dict.get(tok, unk_id)
                    for tok in columns[trg_col].split()
                ]
                yield src_ids, [start_id] + trg_body, trg_body + [end_id]

    return reader
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.WMT16",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader",
)
def train(src_dict_size, trg_dict_size, src_lang="en"):
    """
    WMT16 train set reader.

    Returns the reader for the training data. Every sample produced by the
    reader has three fields: the source language word index sequence, the
    target language word index sequence and the next word index sequence.

    NOTE:
        The original link for training data is:
        http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz

        paddle.dataset.wmt16 provides a tokenized version of the original
        dataset by using moses's tokenization script:
        https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl

    Args:
        src_dict_size(int): Size of the source language dictionary. Three
                            special tokens will be added into the dictionary:
                            <s> for start mark, <e> for end mark, and <unk>
                            for unknown word. Values above the total word
                            count of the language are clamped.
        trg_dict_size(int): Size of the target language dictionary, with the
                            same special tokens and clamping as above.
        src_lang(string): A string indicating which language is the source
                          language. Available options are: "en" for English
                          and "de" for German.

    Returns:
        callable: The train reader.

    Raises:
        ValueError: If ``src_lang`` is neither "en" nor "de".
    """
    if src_lang not in ("en", "de"):
        raise ValueError(
            "An error language type. "
            "Only support: en (for English); de(for Germany)."
        )

    src_size, trg_size = __get_dict_size(
        src_dict_size, trg_dict_size, src_lang
    )
    # The archive is fetched when the reader is created, not when it is
    # first iterated.
    archive_path = paddle.dataset.common.download(
        DATA_URL, "wmt16", DATA_MD5, "wmt16.tar.gz"
    )
    return reader_creator(
        archive_path, "wmt16/train", src_size, trg_size, src_lang
    )
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.WMT16",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader",
)
def test(src_dict_size, trg_dict_size, src_lang="en"):
    """
    WMT16 test set reader.

    Returns the reader for the test data. Every sample produced by the
    reader has three fields: the source language word index sequence, the
    target language word index sequence and the next word index sequence.

    NOTE:
        The original link for test data is:
        http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/mmt16_task1_test.tar.gz

        paddle.dataset.wmt16 provides a tokenized version of the original
        dataset by using moses's tokenization script:
        https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl

    Args:
        src_dict_size(int): Size of the source language dictionary. Three
                            special tokens will be added into the dictionary:
                            <s> for start mark, <e> for end mark, and <unk>
                            for unknown word. Values above the total word
                            count of the language are clamped.
        trg_dict_size(int): Size of the target language dictionary, with the
                            same special tokens and clamping as above.
        src_lang(string): A string indicating which language is the source
                          language. Available options are: "en" for English
                          and "de" for German.

    Returns:
        callable: The test reader.

    Raises:
        ValueError: If ``src_lang`` is neither "en" nor "de".
    """
    if src_lang not in ("en", "de"):
        raise ValueError(
            "An error language type. Only support: "
            "en (for English); de(for Germany)."
        )

    src_size, trg_size = __get_dict_size(
        src_dict_size, trg_dict_size, src_lang
    )
    # The archive is fetched when the reader is created, not when it is
    # first iterated.
    archive_path = paddle.dataset.common.download(
        DATA_URL, "wmt16", DATA_MD5, "wmt16.tar.gz"
    )
    return reader_creator(
        archive_path, "wmt16/test", src_size, trg_size, src_lang
    )
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.WMT16",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader",
)
def validation(src_dict_size, trg_dict_size, src_lang="en"):
    """
    WMT16 validation set reader.

    Returns the reader for the validation data. Every sample produced by the
    reader has three fields: the source language word index sequence, the
    target language word index sequence and the next word index sequence.

    NOTE:
        The original link for validation data is:
        http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz

        paddle.dataset.wmt16 provides a tokenized version of the original
        dataset by using moses's tokenization script:
        https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl

    Args:
        src_dict_size(int): Size of the source language dictionary. Three
                            special tokens will be added into the dictionary:
                            <s> for start mark, <e> for end mark, and <unk>
                            for unknown word. Values above the total word
                            count of the language are clamped.
        trg_dict_size(int): Size of the target language dictionary, with the
                            same special tokens and clamping as above.
        src_lang(string): A string indicating which language is the source
                          language. Available options are: "en" for English
                          and "de" for German.

    Returns:
        callable: The validation reader.

    Raises:
        ValueError: If ``src_lang`` is neither "en" nor "de".
    """
    if src_lang not in ("en", "de"):
        raise ValueError(
            "An error language type. Only support: "
            "en (for English); de(for Germany)."
        )

    src_size, trg_size = __get_dict_size(
        src_dict_size, trg_dict_size, src_lang
    )
    # The archive is fetched when the reader is created, not when it is
    # first iterated.
    archive_path = paddle.dataset.common.download(
        DATA_URL, "wmt16", DATA_MD5, "wmt16.tar.gz"
    )
    return reader_creator(
        archive_path, "wmt16/val", src_size, trg_size, src_lang
    )
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.WMT16",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader",
)
def get_dict(lang, dict_size, reverse=False):
    """
    Return the word dictionary for the specified language.

    The dictionary file must already exist on disk; it is created as a side
    effect of iterating the readers returned by train/test/validation.

    Args:
        lang(string): A string indicating which language is the source
                      language. Available options are: "en" for English
                      and "de" for German. Any value other than "en" is
                      treated as German when clamping ``dict_size``.
        dict_size(int): Size of the specified language dictionary. Values
                        above the total word count of the language are
                        clamped.
        reverse(bool): If reverse is set to False, the returned python
                       dictionary will use word as key and use index as value.
                       If reverse is set to True, the returned python
                       dictionary will use index as key and word as value.

    Returns:
        dict: The word dictionary for the specific language.

    Raises:
        AssertionError: If the dictionary file does not exist yet.
    """
    if lang == "en":
        dict_size = min(dict_size, TOTAL_EN_WORDS)
    else:
        dict_size = min(dict_size, TOTAL_DE_WORDS)

    dict_path = os.path.join(
        paddle.dataset.common.DATA_HOME, "wmt16/%s_%d.dict" % (lang, dict_size)
    )
    # The message must be parenthesised: without the parentheses only the
    # first literal belongs to the assert and the remaining lines are
    # discarded as stand-alone expression statements.
    assert os.path.exists(dict_path), (
        "Word dictionary does not exist. "
        "Please invoke paddle.dataset.wmt16.train/test/validation first "
        "to build the dictionary."
    )
    # NOTE(review): this path is only used if __load_dict decides to rebuild
    # the dictionary. The readers get their archive path from
    # paddle.dataset.common.download(DATA_URL, "wmt16", ...); confirm that
    # it resolves to the same location as the path built here.
    tar_file = os.path.join(paddle.dataset.common.DATA_HOME, "wmt16.tar.gz")
    return __load_dict(tar_file, dict_size, lang, reverse)
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.WMT16",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader",
)
def fetch():
    """Download the entire dataset archive.

    Uses the same ``paddle.dataset.common.download`` call and arguments as
    the train/test/validation readers.
    """
    # Previously this went through ``paddle.v4.dataset``, a path used
    # nowhere else in this module.
    paddle.dataset.common.download(
        DATA_URL, "wmt16", DATA_MD5, "wmt16.tar.gz"
    )
|