| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import collections
- import re
- import string
- import tarfile
- import numpy as np
- from paddle.dataset.common import _check_exists_and_download
- from paddle.io import Dataset
- # `__all__` is empty, so `from <this module> import *` exports no names.
- __all__ = []
- # Location of the aclImdb_v1 tar.gz archive; handed to
- # `_check_exists_and_download` when no local `data_file` is supplied.
- URL = 'https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz'
- # MD5 hex digest passed alongside URL to `_check_exists_and_download`
- # (presumably used to verify the downloaded archive -- confirm in that helper).
- MD5 = '7c2ac02c03563afcf9b574c7e56c153a'
class Imdb(Dataset):
    """
    Implementation of `IMDB <https://www.imdb.com/interfaces/>`_ dataset.

    Every review is lower-cased, stripped of ASCII punctuation and split on
    whitespace. Tokens are kept as ``bytes``. Each sample is the sequence of
    vocabulary ids of one review together with its label: ``0`` for reviews
    stored under ``pos/`` and ``1`` for reviews stored under ``neg/``.

    Args:
        data_file(str): path to data tar file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train' or 'test' mode (case-insensitive). Default 'train'.
        cutoff(int): cutoff number for building word dictionary; only words
            occurring strictly more than ``cutoff`` times in the whole corpus
            (train and test) get their own id. Default 150.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of IMDB dataset

    Examples:

        .. code-block:: python

            >>> # doctest: +TIMEOUT(75)
            >>> import paddle
            >>> from paddle.text.datasets import Imdb

            >>> class SimpleNet(paddle.nn.Layer):
            ...     def __init__(self):
            ...         super().__init__()
            ...
            ...     def forward(self, doc, label):
            ...         return paddle.sum(doc), label

            >>> imdb = Imdb(mode='train')

            >>> for i in range(10):
            ...     doc, label = imdb[i]
            ...     doc = paddle.to_tensor(doc)
            ...     label = paddle.to_tensor(label)
            ...
            ...     model = SimpleNet()
            ...     image, label = model(doc, label)
            ...     print(doc.shape, label.shape)
            [121] [1]
            [115] [1]
            [386] [1]
            [471] [1]
            [585] [1]
            [206] [1]
            [221] [1]
            [324] [1]
            [166] [1]
            [598] [1]

    """

    def __init__(self, data_file=None, mode='train', cutoff=150, download=True):
        assert mode.lower() in [
            'train',
            'test',
        ], f"mode should be 'train', 'test', but got {mode}"
        self.mode = mode.lower()

        self.data_file = data_file
        if self.data_file is None:
            assert (
                download
            ), "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(
                data_file, URL, MD5, 'imdb', download
            )

        # Build a word dictionary from the corpus
        self.word_idx = self._build_work_dict(cutoff)

        # read dataset into memory
        self._load_anno()

    def _build_work_dict(self, cutoff):
        """Build the token -> id mapping from every review in the archive.

        Args:
            cutoff(int): tokens whose corpus frequency is not strictly
                greater than this value are left out of the mapping.

        Returns:
            dict: kept tokens (``bytes``) mapped to consecutive ids ordered
            by descending frequency, ties broken by token value, plus the
            ``str`` key ``'<unk>'`` mapped to the next free id.
        """
        pattern = re.compile(r"aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$")
        word_freq = collections.Counter()
        for doc in self._tokenize(pattern):
            word_freq.update(doc)

        # Not sure if we should prune less-frequent words here.
        kept = [item for item in word_freq.items() if item[1] > cutoff]
        kept.sort(key=lambda item: (-item[1], item[0]))

        # Enumerating (rather than unpacking ``zip(*kept)``) keeps this
        # working when no token passes the cutoff: the result is then just
        # ``{'<unk>': 0}`` instead of a ValueError.
        word_idx = {word: idx for idx, (word, _) in enumerate(kept)}
        word_idx['<unk>'] = len(kept)
        return word_idx

    def _tokenize(self, pattern):
        """Tokenize every archive member whose name matches ``pattern``.

        Args:
            pattern(re.Pattern): compiled regex applied with ``match`` to the
                member name.

        Returns:
            list: one list of ``bytes`` tokens per matching member, in
            archive order.
        """
        # Bytes to delete from the raw text; computed once for all members.
        punctuation = string.punctuation.encode('latin-1')
        data = []
        with tarfile.open(self.data_file) as tarf:
            # Iterating the TarFile walks the members sequentially, exactly
            # like repeated ``tarf.next()`` calls.
            for member in tarf:
                if not pattern.match(member.name):
                    continue
                fileobj = tarf.extractfile(member)
                if fileobj is None:
                    # Not a regular file (e.g. a directory); nothing to read.
                    continue
                with fileobj:
                    raw = fileobj.read()
                # newline and punctuations removal and ad-hoc tokenization.
                data.append(
                    raw.rstrip(b'\n\r')
                    .translate(None, punctuation)
                    .lower()
                    .split()
                )
        return data

    def _load_anno(self):
        """Load the reviews of ``self.mode`` into ``self.docs``/``self.labels``.

        Positive reviews come first with label 0, followed by negative
        reviews with label 1. Tokens missing from ``self.word_idx`` are
        mapped to the ``'<unk>'`` id.
        """
        UNK = self.word_idx['<unk>']

        self.docs = []
        self.labels = []
        for label, sentiment in enumerate(('pos', 'neg')):
            pattern = re.compile(fr"aclImdb/{self.mode}/{sentiment}/.*\.txt$")
            for doc in self._tokenize(pattern):
                self.docs.append([self.word_idx.get(w, UNK) for w in doc])
                self.labels.append(label)

    def __getitem__(self, idx):
        """Return ``(ids, label)`` for sample ``idx`` as two numpy arrays.

        ``ids`` holds the token ids of the review; ``label`` is a
        one-element array.
        """
        return (np.array(self.docs[idx]), np.array([self.labels[idx]]))

    def __len__(self):
        """Return the number of loaded reviews."""
        return len(self.docs)
|