imdb.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import collections
  15. import re
  16. import string
  17. import tarfile
  18. import numpy as np
  19. from paddle.dataset.common import _check_exists_and_download
  20. from paddle.io import Dataset
# Nothing from this module is exported through ``from ... import *``.
__all__ = []

# Download location of the aclImdb_v1 tarball and the MD5 value that is passed
# alongside it to ``_check_exists_and_download`` when ``Imdb`` is constructed
# without an explicit ``data_file``.
URL = 'https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz'
MD5 = '7c2ac02c03563afcf9b574c7e56c153a'
  24. class Imdb(Dataset):
  25. """
  26. Implementation of `IMDB <https://www.imdb.com/interfaces/>`_ dataset.
  27. Args:
  28. data_file(str): path to data tar file, can be set None if
  29. :attr:`download` is True. Default None
  30. mode(str): 'train' 'test' mode. Default 'train'.
  31. cutoff(int): cutoff number for building word dictionary. Default 150.
  32. download(bool): whether to download dataset automatically if
  33. :attr:`data_file` is not set. Default True
  34. Returns:
  35. Dataset: instance of IMDB dataset
  36. Examples:
  37. .. code-block:: python
  38. >>> # doctest: +TIMEOUT(75)
  39. >>> import paddle
  40. >>> from paddle.text.datasets import Imdb
  41. >>> class SimpleNet(paddle.nn.Layer):
  42. ... def __init__(self):
  43. ... super().__init__()
  44. ...
  45. ... def forward(self, doc, label):
  46. ... return paddle.sum(doc), label
  47. >>> imdb = Imdb(mode='train')
  48. >>> for i in range(10):
  49. ... doc, label = imdb[i]
  50. ... doc = paddle.to_tensor(doc)
  51. ... label = paddle.to_tensor(label)
  52. ...
  53. ... model = SimpleNet()
  54. ... image, label = model(doc, label)
  55. ... print(doc.shape, label.shape)
  56. [121] [1]
  57. [115] [1]
  58. [386] [1]
  59. [471] [1]
  60. [585] [1]
  61. [206] [1]
  62. [221] [1]
  63. [324] [1]
  64. [166] [1]
  65. [598] [1]
  66. """
  67. def __init__(self, data_file=None, mode='train', cutoff=150, download=True):
  68. assert mode.lower() in [
  69. 'train',
  70. 'test',
  71. ], f"mode should be 'train', 'test', but got {mode}"
  72. self.mode = mode.lower()
  73. self.data_file = data_file
  74. if self.data_file is None:
  75. assert (
  76. download
  77. ), "data_file is not set and downloading automatically is disabled"
  78. self.data_file = _check_exists_and_download(
  79. data_file, URL, MD5, 'imdb', download
  80. )
  81. # Build a word dictionary from the corpus
  82. self.word_idx = self._build_work_dict(cutoff)
  83. # read dataset into memory
  84. self._load_anno()
  85. def _build_work_dict(self, cutoff):
  86. word_freq = collections.defaultdict(int)
  87. pattern = re.compile(r"aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$")
  88. for doc in self._tokenize(pattern):
  89. for word in doc:
  90. word_freq[word] += 1
  91. # Not sure if we should prune less-frequent words here.
  92. word_freq = [x for x in word_freq.items() if x[1] > cutoff]
  93. dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
  94. words, _ = list(zip(*dictionary))
  95. word_idx = dict(list(zip(words, range(len(words)))))
  96. word_idx['<unk>'] = len(words)
  97. return word_idx
  98. def _tokenize(self, pattern):
  99. data = []
  100. with tarfile.open(self.data_file) as tarf:
  101. tf = tarf.next()
  102. while tf is not None:
  103. if bool(pattern.match(tf.name)):
  104. # newline and punctuations removal and ad-hoc tokenization.
  105. data.append(
  106. tarf.extractfile(tf)
  107. .read()
  108. .rstrip(b'\n\r')
  109. .translate(None, string.punctuation.encode('latin-1'))
  110. .lower()
  111. .split()
  112. )
  113. tf = tarf.next()
  114. return data
  115. def _load_anno(self):
  116. pos_pattern = re.compile(fr"aclImdb/{self.mode}/pos/.*\.txt$")
  117. neg_pattern = re.compile(fr"aclImdb/{self.mode}/neg/.*\.txt$")
  118. UNK = self.word_idx['<unk>']
  119. self.docs = []
  120. self.labels = []
  121. for doc in self._tokenize(pos_pattern):
  122. self.docs.append([self.word_idx.get(w, UNK) for w in doc])
  123. self.labels.append(0)
  124. for doc in self._tokenize(neg_pattern):
  125. self.docs.append([self.word_idx.get(w, UNK) for w in doc])
  126. self.labels.append(1)
  127. def __getitem__(self, idx):
  128. return (np.array(self.docs[idx]), np.array([self.labels[idx]]))
  129. def __len__(self):
  130. return len(self.docs)