conll05.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import gzip
  15. import tarfile
  16. import numpy as np
  17. from paddle.dataset.common import _check_exists_and_download
  18. from paddle.io import Dataset
# Nothing is exported through `from <module> import *`.
__all__ = []

# Each URL/MD5 pair below is handed to `_check_exists_and_download` by
# `Conll05st.__init__` when the corresponding file argument is not supplied.
# NOTE(review): the MD5 is presumably used by that helper to validate the
# downloaded file -- confirm against paddle.dataset.common.

# Tar archive that holds the test-set words/props files.
DATA_URL = 'http://paddlemodels.bj.bcebos.com/conll05st/conll05st-tests.tar.gz'
DATA_MD5 = '387719152ae52d60422c016e92a742fc'
# Word dictionary: one entry per line, the line number is the word index.
WORDDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FwordDict.txt'
WORDDICT_MD5 = 'ea7fb7d4c75cc6254716f0177a506baa'
# Verb (predicate) dictionary: one entry per line, the line number is the index.
VERBDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FverbDict.txt'
VERBDICT_MD5 = '0d2977293bbb6cbefab5b0f97db1e77c'
# Target (label) dictionary: lines starting with "B-"/"I-" define the tag set.
TRGDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FtargetDict.txt'
TRGDICT_MD5 = 'd8c7f03ceb5fc2e5a0fa7503a4353751'
# Embedding file; its local path is what `Conll05st.get_embedding` returns.
EMB_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2Femb'
EMB_MD5 = 'bf436eb0faa1f6f9103017f8be57cdb7'
# Index substituted for words that are missing from the word dictionary.
UNK_IDX = 0
  31. class Conll05st(Dataset):
  32. """
  33. Implementation of `Conll05st <https://www.cs.upc.edu/~srlconll/soft.html>`_
  34. test dataset.
  35. Note: only support download test dataset automatically for that
  36. only test dataset of Conll05st is public.
  37. Args:
  38. data_file(str): path to data tar file, can be set None if
  39. :attr:`download` is True. Default None
  40. word_dict_file(str): path to word dictionary file, can be set None if
  41. :attr:`download` is True. Default None
  42. verb_dict_file(str): path to verb dictionary file, can be set None if
  43. :attr:`download` is True. Default None
  44. target_dict_file(str): path to target dictionary file, can be set None if
  45. :attr:`download` is True. Default None
  46. emb_file(str): path to embedding dictionary file, only used for
  47. :code:`get_embedding` can be set None if :attr:`download` is
  48. True. Default None
  49. download(bool): whether to download dataset automatically if
  50. :attr:`data_file` :attr:`word_dict_file` :attr:`verb_dict_file`
  51. :attr:`target_dict_file` is not set. Default True
  52. Returns:
  53. Dataset: instance of conll05st dataset
  54. Examples:
  55. .. code-block:: python
  56. >>> import paddle
  57. >>> from paddle.text.datasets import Conll05st
  58. >>> class SimpleNet(paddle.nn.Layer):
  59. ... def __init__(self):
  60. ... super().__init__()
  61. ...
  62. ... def forward(self, pred_idx, mark, label):
  63. ... return paddle.sum(pred_idx), paddle.sum(mark), paddle.sum(label)
  64. >>> conll05st = Conll05st()
  65. >>> for i in range(10):
  66. ... pred_idx, mark, label= conll05st[i][-3:]
  67. ... pred_idx = paddle.to_tensor(pred_idx)
  68. ... mark = paddle.to_tensor(mark)
  69. ... label = paddle.to_tensor(label)
  70. ...
  71. ... model = SimpleNet()
  72. ... pred_idx, mark, label= model(pred_idx, mark, label)
  73. ... print(pred_idx.item(), mark.item(), label.item())
  74. >>> # doctest: +SKIP('label will change')
  75. 65840 5 1991
  76. 92560 5 3686
  77. 99120 5 457
  78. 121960 5 3945
  79. 4774 5 2378
  80. 14973 5 1938
  81. 36921 5 1090
  82. 26908 5 2329
  83. 62965 5 2968
  84. 97755 5 2674
  85. """
  86. def __init__(
  87. self,
  88. data_file=None,
  89. word_dict_file=None,
  90. verb_dict_file=None,
  91. target_dict_file=None,
  92. emb_file=None,
  93. download=True,
  94. ):
  95. self.data_file = data_file
  96. if self.data_file is None:
  97. assert (
  98. download
  99. ), "data_file is not set and downloading automatically is disabled"
  100. self.data_file = _check_exists_and_download(
  101. data_file, DATA_URL, DATA_MD5, 'conll05st', download
  102. )
  103. self.word_dict_file = word_dict_file
  104. if self.word_dict_file is None:
  105. assert (
  106. download
  107. ), "word_dict_file is not set and downloading automatically is disabled"
  108. self.word_dict_file = _check_exists_and_download(
  109. word_dict_file,
  110. WORDDICT_URL,
  111. WORDDICT_MD5,
  112. 'conll05st',
  113. download,
  114. )
  115. self.verb_dict_file = verb_dict_file
  116. if self.verb_dict_file is None:
  117. assert (
  118. download
  119. ), "verb_dict_file is not set and downloading automatically is disabled"
  120. self.verb_dict_file = _check_exists_and_download(
  121. verb_dict_file,
  122. VERBDICT_URL,
  123. VERBDICT_MD5,
  124. 'conll05st',
  125. download,
  126. )
  127. self.target_dict_file = target_dict_file
  128. if self.target_dict_file is None:
  129. assert (
  130. download
  131. ), "target_dict_file is not set and downloading automatically is disabled"
  132. self.target_dict_file = _check_exists_and_download(
  133. target_dict_file,
  134. TRGDICT_URL,
  135. TRGDICT_MD5,
  136. 'conll05st',
  137. download,
  138. )
  139. self.emb_file = emb_file
  140. if self.emb_file is None:
  141. assert (
  142. download
  143. ), "emb_file is not set and downloading automatically is disabled"
  144. self.emb_file = _check_exists_and_download(
  145. emb_file, EMB_URL, EMB_MD5, 'conll05st', download
  146. )
  147. self.word_dict = self._load_dict(self.word_dict_file)
  148. self.predicate_dict = self._load_dict(self.verb_dict_file)
  149. self.label_dict = self._load_label_dict(self.target_dict_file)
  150. # read dataset into memory
  151. self._load_anno()
  152. def _load_label_dict(self, filename):
  153. d = {}
  154. tag_dict = set()
  155. with open(filename, 'r') as f:
  156. for i, line in enumerate(f):
  157. line = line.strip()
  158. if line.startswith("B-"):
  159. tag_dict.add(line[2:])
  160. elif line.startswith("I-"):
  161. tag_dict.add(line[2:])
  162. index = 0
  163. for tag in tag_dict:
  164. d["B-" + tag] = index
  165. index += 1
  166. d["I-" + tag] = index
  167. index += 1
  168. d["O"] = index
  169. return d
  170. def _load_dict(self, filename):
  171. d = {}
  172. with open(filename, 'r') as f:
  173. for i, line in enumerate(f):
  174. d[line.strip()] = i
  175. return d
  176. def _load_anno(self):
  177. tf = tarfile.open(self.data_file)
  178. wf = tf.extractfile(
  179. "conll05st-release/test.wsj/words/test.wsj.words.gz"
  180. )
  181. pf = tf.extractfile(
  182. "conll05st-release/test.wsj/props/test.wsj.props.gz"
  183. )
  184. self.sentences = []
  185. self.predicates = []
  186. self.labels = []
  187. with gzip.GzipFile(fileobj=wf) as words_file, gzip.GzipFile(
  188. fileobj=pf
  189. ) as props_file:
  190. sentences = []
  191. labels = []
  192. one_seg = []
  193. for word, label in zip(words_file, props_file):
  194. word = word.strip().decode()
  195. label = label.strip().decode().split()
  196. if len(label) == 0: # end of sentence
  197. for i in range(len(one_seg[0])):
  198. a_kind_lable = [x[i] for x in one_seg]
  199. labels.append(a_kind_lable)
  200. if len(labels) >= 1:
  201. verb_list = []
  202. for x in labels[0]:
  203. if x != '-':
  204. verb_list.append(x)
  205. for i, lbl in enumerate(labels[1:]):
  206. cur_tag = 'O'
  207. is_in_bracket = False
  208. lbl_seq = []
  209. verb_word = ''
  210. for l in lbl:
  211. if l == '*' and not is_in_bracket:
  212. lbl_seq.append('O')
  213. elif l == '*' and is_in_bracket:
  214. lbl_seq.append('I-' + cur_tag)
  215. elif l == '*)':
  216. lbl_seq.append('I-' + cur_tag)
  217. is_in_bracket = False
  218. elif l.find('(') != -1 and l.find(')') != -1:
  219. cur_tag = l[1 : l.find('*')]
  220. lbl_seq.append('B-' + cur_tag)
  221. is_in_bracket = False
  222. elif l.find('(') != -1 and l.find(')') == -1:
  223. cur_tag = l[1 : l.find('*')]
  224. lbl_seq.append('B-' + cur_tag)
  225. is_in_bracket = True
  226. else:
  227. raise RuntimeError(
  228. 'Unexpected label: %s' % l
  229. )
  230. self.sentences.append(sentences)
  231. self.predicates.append(verb_list[i])
  232. self.labels.append(lbl_seq)
  233. sentences = []
  234. labels = []
  235. one_seg = []
  236. else:
  237. sentences.append(word)
  238. one_seg.append(label)
  239. pf.close()
  240. wf.close()
  241. tf.close()
  242. def __getitem__(self, idx):
  243. sentence = self.sentences[idx]
  244. predicate = self.predicates[idx]
  245. labels = self.labels[idx]
  246. sen_len = len(sentence)
  247. verb_index = labels.index('B-V')
  248. mark = [0] * len(labels)
  249. if verb_index > 0:
  250. mark[verb_index - 1] = 1
  251. ctx_n1 = sentence[verb_index - 1]
  252. else:
  253. ctx_n1 = 'bos'
  254. if verb_index > 1:
  255. mark[verb_index - 2] = 1
  256. ctx_n2 = sentence[verb_index - 2]
  257. else:
  258. ctx_n2 = 'bos'
  259. mark[verb_index] = 1
  260. ctx_0 = sentence[verb_index]
  261. if verb_index < len(labels) - 1:
  262. mark[verb_index + 1] = 1
  263. ctx_p1 = sentence[verb_index + 1]
  264. else:
  265. ctx_p1 = 'eos'
  266. if verb_index < len(labels) - 2:
  267. mark[verb_index + 2] = 1
  268. ctx_p2 = sentence[verb_index + 2]
  269. else:
  270. ctx_p2 = 'eos'
  271. word_idx = [self.word_dict.get(w, UNK_IDX) for w in sentence]
  272. ctx_n2_idx = [self.word_dict.get(ctx_n2, UNK_IDX)] * sen_len
  273. ctx_n1_idx = [self.word_dict.get(ctx_n1, UNK_IDX)] * sen_len
  274. ctx_0_idx = [self.word_dict.get(ctx_0, UNK_IDX)] * sen_len
  275. ctx_p1_idx = [self.word_dict.get(ctx_p1, UNK_IDX)] * sen_len
  276. ctx_p2_idx = [self.word_dict.get(ctx_p2, UNK_IDX)] * sen_len
  277. pred_idx = [self.predicate_dict.get(predicate)] * sen_len
  278. label_idx = [self.label_dict.get(w) for w in labels]
  279. return (
  280. np.array(word_idx),
  281. np.array(ctx_n2_idx),
  282. np.array(ctx_n1_idx),
  283. np.array(ctx_0_idx),
  284. np.array(ctx_p1_idx),
  285. np.array(ctx_p2_idx),
  286. np.array(pred_idx),
  287. np.array(mark),
  288. np.array(label_idx),
  289. )
  290. def __len__(self):
  291. return len(self.sentences)
  292. def get_dict(self):
  293. """
  294. Get the word, verb and label dictionary of Wikipedia corpus.
  295. Examples:
  296. .. code-block:: python
  297. >>> from paddle.text.datasets import Conll05st
  298. >>> conll05st = Conll05st()
  299. >>> word_dict, predicate_dict, label_dict = conll05st.get_dict()
  300. """
  301. return self.word_dict, self.predicate_dict, self.label_dict
  302. def get_embedding(self):
  303. """
  304. Get the embedding dictionary file.
  305. Examples:
  306. .. code-block:: python
  307. >>> from paddle.text.datasets import Conll05st
  308. >>> conll05st = Conll05st()
  309. >>> emb_file = conll05st.get_embedding()
  310. """
  311. return self.emb_file