imdb.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
"""
IMDB dataset.

This module downloads the IMDB dataset from
http://ai.stanford.edu/%7Eamaas/data/sentiment/. The dataset contains
25,000 highly polar movie reviews for training and 25,000 for testing.
This module also provides an API for building a word dictionary.
"""
  21. import collections
  22. import re
  23. import string
  24. import tarfile
  25. import paddle.dataset.common
  26. from paddle.utils import deprecated
  27. __all__ = []
  28. # URL = 'http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz'
  29. URL = 'https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz'
  30. MD5 = '7c2ac02c03563afcf9b574c7e56c153a'
  31. def tokenize(pattern):
  32. """
  33. Read files that match the given pattern. Tokenize and yield each file.
  34. """
  35. with tarfile.open(paddle.dataset.common.download(URL, 'imdb', MD5)) as tarf:
  36. # Note that we should use tarfile.next(), which does
  37. # sequential access of member files, other than
  38. # tarfile.extractfile, which does random access and might
  39. # destroy hard disks.
  40. tf = tarf.next()
  41. while tf is not None:
  42. if bool(pattern.match(tf.name)):
  43. # newline and punctuations removal and ad-hoc tokenization.
  44. yield tarf.extractfile(tf).read().rstrip(b'\n\r').translate(
  45. None, string.punctuation.encode('latin-1')
  46. ).lower().split()
  47. tf = tarf.next()
  48. def build_dict(pattern, cutoff):
  49. """
  50. Build a word dictionary from the corpus. Keys of the dictionary are words,
  51. and values are zero-based IDs of these words.
  52. """
  53. word_freq = collections.defaultdict(int)
  54. for doc in tokenize(pattern):
  55. for word in doc:
  56. word_freq[word] += 1
  57. # Not sure if we should prune less-frequent words here.
  58. word_freq = [x for x in word_freq.items() if x[1] > cutoff]
  59. dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
  60. words, _ = list(zip(*dictionary))
  61. word_idx = dict(list(zip(words, range(len(words)))))
  62. word_idx['<unk>'] = len(words)
  63. return word_idx
  64. @deprecated(
  65. since="2.0.0",
  66. update_to="paddle.text.datasets.Imdb",
  67. level=1,
  68. reason="Please use new dataset API which supports paddle.io.DataLoader",
  69. )
  70. def reader_creator(pos_pattern, neg_pattern, word_idx):
  71. UNK = word_idx['<unk>']
  72. INS = []
  73. def load(pattern, out, label):
  74. for doc in tokenize(pattern):
  75. out.append(([word_idx.get(w, UNK) for w in doc], label))
  76. load(pos_pattern, INS, 0)
  77. load(neg_pattern, INS, 1)
  78. def reader():
  79. yield from INS
  80. return reader
  81. @deprecated(
  82. since="2.0.0",
  83. update_to="paddle.text.datasets.Imdb",
  84. level=1,
  85. reason="Please use new dataset API which supports paddle.io.DataLoader",
  86. )
  87. def train(word_idx):
  88. """
  89. IMDB training set creator.
  90. It returns a reader creator, each sample in the reader is an zero-based ID
  91. sequence and label in [0, 1].
  92. :param word_idx: word dictionary
  93. :type word_idx: dict
  94. :return: Training reader creator
  95. :rtype: callable
  96. """
  97. return reader_creator(
  98. re.compile(r"aclImdb/train/pos/.*\.txt$"),
  99. re.compile(r"aclImdb/train/neg/.*\.txt$"),
  100. word_idx,
  101. )
  102. @deprecated(
  103. since="2.0.0",
  104. update_to="paddle.text.datasets.Imdb",
  105. level=1,
  106. reason="Please use new dataset API which supports paddle.io.DataLoader",
  107. )
  108. def test(word_idx):
  109. """
  110. IMDB test set creator.
  111. It returns a reader creator, each sample in the reader is an zero-based ID
  112. sequence and label in [0, 1].
  113. :param word_idx: word dictionary
  114. :type word_idx: dict
  115. :return: Test reader creator
  116. :rtype: callable
  117. """
  118. return reader_creator(
  119. re.compile(r"aclImdb/test/pos/.*\.txt$"),
  120. re.compile(r"aclImdb/test/neg/.*\.txt$"),
  121. word_idx,
  122. )
  123. @deprecated(
  124. since="2.0.0",
  125. update_to="paddle.text.datasets.Imdb",
  126. level=1,
  127. reason="Please use new dataset API which supports paddle.io.DataLoader",
  128. )
  129. def word_dict():
  130. """
  131. Build a word dictionary from the corpus.
  132. :return: Word dictionary
  133. :rtype: dict
  134. """
  135. return build_dict(
  136. re.compile(r"aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$"), 150
  137. )
  138. @deprecated(
  139. since="2.0.0",
  140. update_to="paddle.text.datasets.Imdb",
  141. level=1,
  142. reason="Please use new dataset API which supports paddle.io.DataLoader",
  143. )
  144. def fetch():
  145. paddle.dataset.common.download(URL, 'imdb', MD5)