crf_head.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. #
  4. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from typing import Any, Dict, List, Optional
  17. import torch
  18. import torch.nn.functional as F
  19. from torch import nn
  20. from torch.nn import CrossEntropyLoss
  21. from transformers.activations import ACT2FN
  22. from modelscope.metainfo import Heads
  23. from modelscope.models.base import TorchHead
  24. from modelscope.models.builder import HEADS
  25. from modelscope.outputs import (AttentionTokenClassificationModelOutput,
  26. ModelOutputBase, OutputKeys,
  27. TokenClassificationModelOutput)
  28. from modelscope.utils.constant import Tasks
  29. @HEADS.register_module(Tasks.token_classification, module_name=Heads.lstm_crf)
  30. @HEADS.register_module(
  31. Tasks.named_entity_recognition, module_name=Heads.lstm_crf)
  32. @HEADS.register_module(Tasks.word_segmentation, module_name=Heads.lstm_crf)
  33. @HEADS.register_module(Tasks.part_of_speech, module_name=Heads.lstm_crf)
  34. class LSTMCRFHead(TorchHead):
  35. def __init__(self, hidden_size=100, num_labels=None, **kwargs):
  36. super().__init__(hidden_size=hidden_size, num_labels=num_labels)
  37. assert num_labels is not None
  38. self.ffn = nn.Linear(hidden_size * 2, num_labels)
  39. self.crf = CRF(num_labels, batch_first=True)
  40. def forward(self,
  41. inputs: ModelOutputBase,
  42. attention_mask=None,
  43. label=None,
  44. label_mask=None,
  45. offset_mapping=None,
  46. **kwargs):
  47. logits = self.ffn(inputs.last_hidden_state)
  48. return TokenClassificationModelOutput(
  49. loss=None,
  50. logits=logits,
  51. )
  52. def decode(self, logits, label_mask):
  53. seq_lens = label_mask.sum(-1).long()
  54. mask = torch.arange(
  55. label_mask.shape[1],
  56. device=seq_lens.device)[None, :] < seq_lens[:, None]
  57. predicts = self.crf.decode(logits, mask).squeeze(0)
  58. return predicts
  59. @HEADS.register_module(
  60. Tasks.transformer_crf, module_name=Heads.transformer_crf)
  61. @HEADS.register_module(
  62. Tasks.token_classification, module_name=Heads.transformer_crf)
  63. @HEADS.register_module(
  64. Tasks.named_entity_recognition, module_name=Heads.transformer_crf)
  65. @HEADS.register_module(
  66. Tasks.word_segmentation, module_name=Heads.transformer_crf)
  67. @HEADS.register_module(Tasks.part_of_speech, module_name=Heads.transformer_crf)
  68. class TransformersCRFHead(TorchHead):
  69. def __init__(self, hidden_size, num_labels, **kwargs):
  70. super().__init__(
  71. hidden_size=hidden_size, num_labels=num_labels, **kwargs)
  72. self.linear = nn.Linear(hidden_size, num_labels)
  73. self.crf = CRF(num_labels, batch_first=True)
  74. def forward(self,
  75. inputs: ModelOutputBase,
  76. attention_mask=None,
  77. label=None,
  78. label_mask=None,
  79. offset_mapping=None,
  80. **kwargs):
  81. logits = self.linear(inputs.last_hidden_state)
  82. if label_mask is not None:
  83. mask = label_mask
  84. masked_lengths = mask.sum(-1).long()
  85. masked_logits = torch.zeros_like(logits)
  86. for i in range(mask.shape[0]):
  87. masked_logits[
  88. i, :masked_lengths[i], :] = logits[i].masked_select(
  89. mask[i].unsqueeze(-1)).view(masked_lengths[i], -1)
  90. logits = masked_logits
  91. return AttentionTokenClassificationModelOutput(
  92. loss=None,
  93. logits=logits,
  94. hidden_states=inputs.hidden_states,
  95. attentions=inputs.attentions,
  96. )
  97. def decode(self, logits, label_mask):
  98. seq_lens = label_mask.sum(-1).long()
  99. mask = torch.arange(
  100. label_mask.shape[1],
  101. device=seq_lens.device)[None, :] < seq_lens[:, None]
  102. predicts = self.crf.decode(logits, mask).squeeze(0)
  103. return predicts
  104. class CRF(nn.Module):
  105. """Conditional random field.
  106. This module implements a conditional random field [LMP01]_. The forward computation
  107. of this class computes the log likelihood of the given sequence of tags and
  108. emission score tensor. This class also has `~CRF.decode` method which finds
  109. the best tag sequence given an emission score tensor using `Viterbi algorithm`_.
  110. Args:
  111. num_tags: Number of tags.
  112. batch_first: Whether the first dimension corresponds to the size of a minibatch.
  113. Attributes:
  114. start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
  115. ``(num_tags,)``.
  116. end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
  117. ``(num_tags,)``.
  118. transitions (`~torch.nn.Parameter`): Transition score tensor of size
  119. ``(num_tags, num_tags)``.
  120. .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
  121. "Conditional random fields: Probabilistic models for segmenting and
  122. labeling sequence data". *Proc. 18th International Conf. on Machine
  123. Learning*. Morgan Kaufmann. pp. 282–289.
  124. .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
  125. """
  126. def __init__(self, num_tags: int, batch_first: bool = False) -> None:
  127. if num_tags <= 0:
  128. raise ValueError(f'invalid number of tags: {num_tags}')
  129. super().__init__()
  130. self.num_tags = num_tags
  131. self.batch_first = batch_first
  132. self.start_transitions = nn.Parameter(torch.empty(num_tags))
  133. self.end_transitions = nn.Parameter(torch.empty(num_tags))
  134. self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
  135. self.reset_parameters()
  136. def reset_parameters(self) -> None:
  137. """Initialize the transition parameters.
  138. The parameters will be initialized randomly from a uniform distribution
  139. between -0.1 and 0.1.
  140. """
  141. nn.init.uniform_(self.start_transitions, -0.1, 0.1)
  142. nn.init.uniform_(self.end_transitions, -0.1, 0.1)
  143. nn.init.uniform_(self.transitions, -0.1, 0.1)
  144. def __repr__(self) -> str:
  145. return f'{self.__class__.__name__}(num_tags={self.num_tags})'
  146. def forward(self,
  147. emissions: torch.Tensor,
  148. tags: torch.LongTensor,
  149. mask: Optional[torch.ByteTensor] = None,
  150. reduction: str = 'mean') -> torch.Tensor:
  151. """Compute the conditional log likelihood of a sequence of tags given emission scores.
  152. Args:
  153. emissions (`~torch.Tensor`): Emission score tensor of size
  154. ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
  155. ``(batch_size, seq_length, num_tags)`` otherwise.
  156. tags (`~torch.LongTensor`): Sequence of tags tensor of size
  157. ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
  158. ``(batch_size, seq_length)`` otherwise.
  159. mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
  160. if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
  161. reduction: Specifies the reduction to apply to the output:
  162. ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
  163. ``sum``: the output will be summed over batches. ``mean``: the output will be
  164. averaged over batches. ``token_mean``: the output will be averaged over tokens.
  165. Returns:
  166. `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
  167. reduction is ``none``, ``()`` otherwise.
  168. """
  169. if reduction not in ('none', 'sum', 'mean', 'token_mean'):
  170. raise ValueError(f'invalid reduction: {reduction}')
  171. if mask is None:
  172. mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device)
  173. if mask.dtype != torch.uint8:
  174. mask = mask.byte()
  175. self._validate(emissions, tags=tags, mask=mask)
  176. if self.batch_first:
  177. emissions = emissions.transpose(0, 1)
  178. tags = tags.transpose(0, 1)
  179. mask = mask.transpose(0, 1)
  180. # shape: (batch_size,)
  181. numerator = self._compute_score(emissions, tags, mask)
  182. # shape: (batch_size,)
  183. denominator = self._compute_normalizer(emissions, mask)
  184. # shape: (batch_size,)
  185. llh = numerator - denominator
  186. if reduction == 'none':
  187. return llh
  188. if reduction == 'sum':
  189. return llh.sum()
  190. if reduction == 'mean':
  191. return llh.mean()
  192. return llh.sum() / mask.float().sum()
  193. def decode(self,
  194. emissions: torch.Tensor,
  195. mask: Optional[torch.ByteTensor] = None,
  196. nbest: Optional[int] = None,
  197. pad_tag: Optional[int] = None) -> List[List[List[int]]]:
  198. """Find the most likely tag sequence using Viterbi algorithm.
  199. Args:
  200. emissions (`~torch.Tensor`): Emission score tensor of size
  201. ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
  202. ``(batch_size, seq_length, num_tags)`` otherwise.
  203. mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
  204. if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
  205. nbest (`int`): Number of most probable paths for each sequence
  206. pad_tag (`int`): Tag at padded positions. Often input varies in length and
  207. the length will be padded to the maximum length in the batch. Tags at
  208. the padded positions will be assigned with a padding tag, i.e. `pad_tag`
  209. Returns:
  210. A PyTorch tensor of the best tag sequence for each batch of shape
  211. (nbest, batch_size, seq_length)
  212. """
  213. if nbest is None:
  214. nbest = 1
  215. if mask is None:
  216. mask = torch.ones(
  217. emissions.shape[:2],
  218. dtype=torch.uint8,
  219. device=emissions.device)
  220. if mask.dtype != torch.uint8:
  221. mask = mask.byte()
  222. self._validate(emissions, mask=mask)
  223. if self.batch_first:
  224. emissions = emissions.transpose(0, 1)
  225. mask = mask.transpose(0, 1)
  226. if nbest == 1:
  227. return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0)
  228. return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag)
  229. def _validate(self,
  230. emissions: torch.Tensor,
  231. tags: Optional[torch.LongTensor] = None,
  232. mask: Optional[torch.ByteTensor] = None) -> None:
  233. if emissions.dim() != 3:
  234. raise ValueError(
  235. f'emissions must have dimension of 3, got {emissions.dim()}')
  236. if emissions.size(2) != self.num_tags:
  237. raise ValueError(
  238. f'expected last dimension of emissions is {self.num_tags}, '
  239. f'got {emissions.size(2)}')
  240. if tags is not None:
  241. if emissions.shape[:2] != tags.shape:
  242. raise ValueError(
  243. 'the first two dimensions of emissions and tags must match, '
  244. f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}'
  245. )
  246. if mask is not None:
  247. if emissions.shape[:2] != mask.shape:
  248. raise ValueError(
  249. 'the first two dimensions of emissions and mask must match, '
  250. f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}'
  251. )
  252. no_empty_seq = not self.batch_first and mask[0].all()
  253. no_empty_seq_bf = self.batch_first and mask[:, 0].all()
  254. if not no_empty_seq and not no_empty_seq_bf:
  255. raise ValueError('mask of the first timestep must all be on')
  256. def _compute_score(self, emissions: torch.Tensor, tags: torch.LongTensor,
  257. mask: torch.ByteTensor) -> torch.Tensor:
  258. # emissions: (seq_length, batch_size, num_tags)
  259. # tags: (seq_length, batch_size)
  260. # mask: (seq_length, batch_size)
  261. seq_length, batch_size = tags.shape
  262. mask = mask.float()
  263. # Start transition score and first emission
  264. # shape: (batch_size,)
  265. score = self.start_transitions[tags[0]]
  266. score += emissions[0, torch.arange(batch_size), tags[0]]
  267. for i in range(1, seq_length):
  268. # Transition score to next tag, only added if next timestep is valid (mask == 1)
  269. # shape: (batch_size,)
  270. score += self.transitions[tags[i - 1], tags[i]] * mask[i]
  271. # Emission score for next tag, only added if next timestep is valid (mask == 1)
  272. # shape: (batch_size,)
  273. score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]
  274. # End transition score
  275. # shape: (batch_size,)
  276. seq_ends = mask.long().sum(dim=0) - 1
  277. # shape: (batch_size,)
  278. last_tags = tags[seq_ends, torch.arange(batch_size)]
  279. # shape: (batch_size,)
  280. score += self.end_transitions[last_tags]
  281. return score
  282. def _compute_normalizer(self, emissions: torch.Tensor,
  283. mask: torch.ByteTensor) -> torch.Tensor:
  284. # emissions: (seq_length, batch_size, num_tags)
  285. # mask: (seq_length, batch_size)
  286. seq_length = emissions.size(0)
  287. # Start transition score and first emission; score has size of
  288. # (batch_size, num_tags) where for each batch, the j-th column stores
  289. # the score that the first timestep has tag j
  290. # shape: (batch_size, num_tags)
  291. score = self.start_transitions + emissions[0]
  292. for i in range(1, seq_length):
  293. # Broadcast score for every possible next tag
  294. # shape: (batch_size, num_tags, 1)
  295. broadcast_score = score.unsqueeze(2)
  296. # Broadcast emission score for every possible current tag
  297. # shape: (batch_size, 1, num_tags)
  298. broadcast_emissions = emissions[i].unsqueeze(1)
  299. # Compute the score tensor of size (batch_size, num_tags, num_tags) where
  300. # for each sample, entry at row i and column j stores the sum of scores of all
  301. # possible tag sequences so far that end with transitioning from tag i to tag j
  302. # and emitting
  303. # shape: (batch_size, num_tags, num_tags)
  304. next_score = broadcast_score + self.transitions + broadcast_emissions
  305. # Sum over all possible current tags, but we're in score space, so a sum
  306. # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
  307. # all possible tag sequences so far, that end in tag i
  308. # shape: (batch_size, num_tags)
  309. next_score = torch.logsumexp(next_score, dim=1)
  310. # Set score to the next score if this timestep is valid (mask == 1)
  311. # shape: (batch_size, num_tags)
  312. score = torch.where(mask[i].unsqueeze(1), next_score, score)
  313. # End transition score
  314. # shape: (batch_size, num_tags)
  315. score += self.end_transitions
  316. # Sum (log-sum-exp) over all possible tags
  317. # shape: (batch_size,)
  318. return torch.logsumexp(score, dim=1)
  319. def _viterbi_decode(self,
  320. emissions: torch.FloatTensor,
  321. mask: torch.ByteTensor,
  322. pad_tag: Optional[int] = None) -> List[List[int]]:
  323. # emissions: (seq_length, batch_size, num_tags)
  324. # mask: (seq_length, batch_size)
  325. # return: (batch_size, seq_length)
  326. if pad_tag is None:
  327. pad_tag = 0
  328. device = emissions.device
  329. seq_length, batch_size = mask.shape
  330. # Start transition and first emission
  331. # shape: (batch_size, num_tags)
  332. score = self.start_transitions + emissions[0]
  333. history_idx = torch.zeros((seq_length, batch_size, self.num_tags),
  334. dtype=torch.long,
  335. device=device)
  336. oor_idx = torch.zeros((batch_size, self.num_tags),
  337. dtype=torch.long,
  338. device=device)
  339. oor_tag = torch.full((seq_length, batch_size),
  340. pad_tag,
  341. dtype=torch.long,
  342. device=device)
  343. # - score is a tensor of size (batch_size, num_tags) where for every batch,
  344. # value at column j stores the score of the best tag sequence so far that ends
  345. # with tag j
  346. # - history_idx saves where the best tags candidate transitioned from; this is used
  347. # when we trace back the best tag sequence
  348. # - oor_idx saves the best tags candidate transitioned from at the positions
  349. # where mask is 0, i.e. out of range (oor)
  350. # Viterbi algorithm recursive case: we compute the score of the best tag sequence
  351. # for every possible next tag
  352. for i in range(1, seq_length):
  353. # Broadcast viterbi score for every possible next tag
  354. # shape: (batch_size, num_tags, 1)
  355. broadcast_score = score.unsqueeze(2)
  356. # Broadcast emission score for every possible current tag
  357. # shape: (batch_size, 1, num_tags)
  358. broadcast_emission = emissions[i].unsqueeze(1)
  359. # Compute the score tensor of size (batch_size, num_tags, num_tags) where
  360. # for each sample, entry at row i and column j stores the score of the best
  361. # tag sequence so far that ends with transitioning from tag i to tag j and emitting
  362. # shape: (batch_size, num_tags, num_tags)
  363. next_score = broadcast_score + self.transitions + broadcast_emission
  364. # Find the maximum score over all possible current tag
  365. # shape: (batch_size, num_tags)
  366. next_score, indices = next_score.max(dim=1)
  367. # Set score to the next score if this timestep is valid (mask == 1)
  368. # and save the index that produces the next score
  369. # shape: (batch_size, num_tags)
  370. score = torch.where(mask[i].unsqueeze(-1).bool(), next_score,
  371. score)
  372. indices = torch.where(mask[i].unsqueeze(-1).bool(), indices,
  373. oor_idx)
  374. history_idx[i - 1] = indices
  375. # End transition score
  376. # shape: (batch_size, num_tags)
  377. end_score = score + self.end_transitions
  378. _, end_tag = end_score.max(dim=1)
  379. # shape: (batch_size,)
  380. seq_ends = mask.long().sum(dim=0) - 1
  381. # insert the best tag at each sequence end (last position with mask == 1)
  382. history_idx = history_idx.transpose(1, 0).contiguous()
  383. history_idx.scatter_(
  384. 1,
  385. seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags),
  386. end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags))
  387. history_idx = history_idx.transpose(1, 0).contiguous()
  388. # The most probable path for each sequence
  389. best_tags_arr = torch.zeros((seq_length, batch_size),
  390. dtype=torch.long,
  391. device=device)
  392. best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
  393. for idx in range(seq_length - 1, -1, -1):
  394. best_tags = torch.gather(history_idx[idx], 1, best_tags)
  395. best_tags_arr[idx] = best_tags.data.view(batch_size)
  396. return torch.where(mask.bool(), best_tags_arr, oor_tag).transpose(0, 1)
  397. def _viterbi_decode_nbest(
  398. self,
  399. emissions: torch.FloatTensor,
  400. mask: torch.ByteTensor,
  401. nbest: int,
  402. pad_tag: Optional[int] = None) -> List[List[List[int]]]:
  403. # emissions: (seq_length, batch_size, num_tags)
  404. # mask: (seq_length, batch_size)
  405. # return: (nbest, batch_size, seq_length)
  406. if pad_tag is None:
  407. pad_tag = 0
  408. device = emissions.device
  409. seq_length, batch_size = mask.shape
  410. # Start transition and first emission
  411. # shape: (batch_size, num_tags)
  412. score = self.start_transitions + emissions[0]
  413. history_idx = torch.zeros(
  414. (seq_length, batch_size, self.num_tags, nbest),
  415. dtype=torch.long,
  416. device=device)
  417. oor_idx = torch.zeros((batch_size, self.num_tags, nbest),
  418. dtype=torch.long,
  419. device=device)
  420. oor_tag = torch.full((seq_length, batch_size, nbest),
  421. pad_tag,
  422. dtype=torch.long,
  423. device=device)
  424. # + score is a tensor of size (batch_size, num_tags) where for every batch,
  425. # value at column j stores the score of the best tag sequence so far that ends
  426. # with tag j
  427. # + history_idx saves where the best tags candidate transitioned from; this is used
  428. # when we trace back the best tag sequence
  429. # - oor_idx saves the best tags candidate transitioned from at the positions
  430. # where mask is 0, i.e. out of range (oor)
  431. # Viterbi algorithm recursive case: we compute the score of the best tag sequence
  432. # for every possible next tag
  433. for i in range(1, seq_length):
  434. if i == 1:
  435. broadcast_score = score.unsqueeze(-1)
  436. broadcast_emission = emissions[i].unsqueeze(1)
  437. # shape: (batch_size, num_tags, num_tags)
  438. next_score = broadcast_score + self.transitions + broadcast_emission
  439. else:
  440. broadcast_score = score.unsqueeze(-1)
  441. broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2)
  442. # shape: (batch_size, num_tags, nbest, num_tags)
  443. next_score = broadcast_score + self.transitions.unsqueeze(
  444. 1) + broadcast_emission
  445. # Find the top `nbest` maximum score over all possible current tag
  446. # shape: (batch_size, nbest, num_tags)
  447. next_score, indices = next_score.view(batch_size, -1,
  448. self.num_tags).topk(
  449. nbest, dim=1)
  450. if i == 1:
  451. score = score.unsqueeze(-1).expand(-1, -1, nbest)
  452. indices = indices * nbest
  453. # convert to shape: (batch_size, num_tags, nbest)
  454. next_score = next_score.transpose(2, 1)
  455. indices = indices.transpose(2, 1)
  456. # Set score to the next score if this timestep is valid (mask == 1)
  457. # and save the index that produces the next score
  458. # shape: (batch_size, num_tags, nbest)
  459. score = torch.where(mask[i].unsqueeze(-1).bool().unsqueeze(-1),
  460. next_score, score)
  461. indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1).bool(),
  462. indices, oor_idx)
  463. history_idx[i - 1] = indices
  464. # End transition score shape: (batch_size, num_tags, nbest)
  465. end_score = score + self.end_transitions.unsqueeze(-1)
  466. _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1)
  467. # shape: (batch_size,)
  468. seq_ends = mask.long().sum(dim=0) - 1
  469. # insert the best tag at each sequence end (last position with mask == 1)
  470. history_idx = history_idx.transpose(1, 0).contiguous()
  471. history_idx.scatter_(
  472. 1,
  473. seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest),
  474. end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest))
  475. history_idx = history_idx.transpose(1, 0).contiguous()
  476. # The most probable path for each sequence
  477. best_tags_arr = torch.zeros((seq_length, batch_size, nbest),
  478. dtype=torch.long,
  479. device=device)
  480. best_tags = torch.arange(nbest, dtype=torch.long, device=device) \
  481. .view(1, -1).expand(batch_size, -1)
  482. for idx in range(seq_length - 1, -1, -1):
  483. best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1,
  484. best_tags)
  485. best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbest
  486. return torch.where(mask.unsqueeze(-1), best_tags_arr,
  487. oor_tag).permute(2, 1, 0)