| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
- #
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Any, Dict, List, Optional
- import torch
- import torch.nn.functional as F
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from transformers.activations import ACT2FN
- from modelscope.metainfo import Heads
- from modelscope.models.base import TorchHead
- from modelscope.models.builder import HEADS
- from modelscope.outputs import (AttentionTokenClassificationModelOutput,
- ModelOutputBase, OutputKeys,
- TokenClassificationModelOutput)
- from modelscope.utils.constant import Tasks
@HEADS.register_module(Tasks.token_classification, module_name=Heads.lstm_crf)
@HEADS.register_module(
    Tasks.named_entity_recognition, module_name=Heads.lstm_crf)
@HEADS.register_module(Tasks.word_segmentation, module_name=Heads.lstm_crf)
@HEADS.register_module(Tasks.part_of_speech, module_name=Heads.lstm_crf)
class LSTMCRFHead(TorchHead):
    """Token-classification head: a linear projection followed by a CRF.

    The linear layer maps features of width ``hidden_size * 2`` to
    ``num_labels`` emission scores. NOTE(review): the doubled width presumably
    matches a bidirectional LSTM encoder; confirm against the backbone config.

    Args:
        hidden_size: Base feature size; the projection expects
            ``hidden_size * 2`` input features.
        num_labels: Number of tags. Required.
        **kwargs: Accepted for config compatibility; not forwarded to the base
            class.

    Raises:
        ValueError: If ``num_labels`` is not given.
    """

    def __init__(self, hidden_size=100, num_labels=None, **kwargs):
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would defer the failure to ``nn.Linear``.
        if num_labels is None:
            raise ValueError('LSTMCRFHead requires `num_labels` to be set')
        super().__init__(hidden_size=hidden_size, num_labels=num_labels)
        self.ffn = nn.Linear(hidden_size * 2, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self,
                inputs: ModelOutputBase,
                attention_mask=None,
                label=None,
                label_mask=None,
                offset_mapping=None,
                **kwargs):
        """Project ``inputs.last_hidden_state`` to per-token emission scores.

        ``attention_mask``, ``label``, ``label_mask``, ``offset_mapping`` and
        ``**kwargs`` are accepted for interface compatibility and not used here.

        Returns:
            ``TokenClassificationModelOutput`` with ``loss=None`` and the
            emission ``logits``.
        """
        logits = self.ffn(inputs.last_hidden_state)
        return TokenClassificationModelOutput(
            loss=None,
            logits=logits,
        )

    def decode(self, logits, label_mask):
        """Viterbi-decode the best tag sequence for each sample.

        Args:
            logits: Batch-first emission scores (the CRF is built with
                ``batch_first=True``).
            label_mask: Per-token mask; only its row sums (number of valid
                tokens per sample) are used.

        Returns:
            Tensor of shape ``(batch_size, seq_length)`` with tag ids; positions
            past each sample's length hold the CRF's default pad tag (0).
        """
        seq_lens = label_mask.sum(-1).long()
        # Left-aligned mask: the first ``seq_lens[b]`` positions of row b are on.
        mask = torch.arange(
            label_mask.shape[1],
            device=seq_lens.device)[None, :] < seq_lens[:, None]
        # CRF.decode returns (nbest=1, batch, seq); drop the nbest axis.
        predicts = self.crf.decode(logits, mask).squeeze(0)
        return predicts
@HEADS.register_module(
    Tasks.transformer_crf, module_name=Heads.transformer_crf)
@HEADS.register_module(
    Tasks.token_classification, module_name=Heads.transformer_crf)
@HEADS.register_module(
    Tasks.named_entity_recognition, module_name=Heads.transformer_crf)
@HEADS.register_module(
    Tasks.word_segmentation, module_name=Heads.transformer_crf)
@HEADS.register_module(Tasks.part_of_speech, module_name=Heads.transformer_crf)
class TransformersCRFHead(TorchHead):
    """Token-classification head for transformer backbones: linear + CRF.

    Args:
        hidden_size: Width of ``last_hidden_state`` produced by the backbone.
        num_labels: Number of tags.
        **kwargs: Forwarded to the base head.
    """

    def __init__(self, hidden_size, num_labels, **kwargs):
        super().__init__(
            hidden_size=hidden_size, num_labels=num_labels, **kwargs)
        self.linear = nn.Linear(hidden_size, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self,
                inputs: ModelOutputBase,
                attention_mask=None,
                label=None,
                label_mask=None,
                offset_mapping=None,
                **kwargs):
        """Compute per-token emission scores.

        When ``label_mask`` is given, the logits at the masked-in positions of
        each sample are packed to the front of that sample's row (keeping their
        order) and the remaining positions are zero. ``decode`` then treats the
        first ``label_mask[b].sum()`` positions as the valid sequence.

        Args:
            inputs: Backbone output providing ``last_hidden_state``,
                ``hidden_states`` and ``attentions``.
            label_mask: Optional ``(batch, seq)`` mask over token positions.
                Any dtype is accepted; it is cast to bool.
            attention_mask, label, offset_mapping, **kwargs: Accepted for
                interface compatibility; not used here.

        Returns:
            ``AttentionTokenClassificationModelOutput`` with ``loss=None``.
        """
        logits = self.linear(inputs.last_hidden_state)
        if label_mask is not None:
            # Boolean indexing needs a bool mask; integer masks would be
            # rejected and uint8 masks are deprecated.
            mask = label_mask.bool()
            lengths = mask.sum(-1)
            # Left-aligned mask with as many True entries per row as ``mask``.
            packed = torch.arange(
                mask.shape[1], device=mask.device)[None, :] < lengths[:, None]
            masked_logits = torch.zeros_like(logits)
            # Boolean indexing enumerates True positions in row-major order, so
            # the k-th selected token of sample b lands in slot k of row b.
            masked_logits[packed] = logits[mask]
            logits = masked_logits
        return AttentionTokenClassificationModelOutput(
            loss=None,
            logits=logits,
            hidden_states=inputs.hidden_states,
            attentions=inputs.attentions,
        )

    def decode(self, logits, label_mask):
        """Viterbi-decode the best tag sequence for each sample.

        Args:
            logits: Batch-first emission scores as returned by ``forward``.
            label_mask: Per-token mask; only its row sums are used, matching the
                left-packing done in ``forward``.

        Returns:
            Tensor of shape ``(batch_size, seq_length)`` with tag ids; positions
            past each sample's length hold the CRF's default pad tag (0).
        """
        seq_lens = label_mask.sum(-1).long()
        mask = torch.arange(
            label_mask.shape[1],
            device=seq_lens.device)[None, :] < seq_lens[:, None]
        # CRF.decode returns (nbest=1, batch, seq); drop the nbest axis.
        predicts = self.crf.decode(logits, mask).squeeze(0)
        return predicts
class CRF(nn.Module):
    """Conditional random field.

    This module implements a conditional random field [LMP01]_. The forward
    computation of this class computes the log likelihood of the given sequence
    of tags and emission score tensor. This class also has `~CRF.decode` method
    which finds the best tag sequence (or the ``nbest`` best sequences) given an
    emission score tensor using the `Viterbi algorithm`_.

    Masks are assumed to be left-aligned: the valid positions of each sequence
    are contiguous from the first timestep (the last valid index is computed as
    ``mask.sum() - 1``).

    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a
            minibatch.

    Raises:
        ValueError: If ``num_tags`` is not positive.

    Attributes:
        start_transitions (`~torch.nn.Parameter`): Start transition score tensor
            of size ``(num_tags,)``.
        end_transitions (`~torch.nn.Parameter`): End transition score tensor of
            size ``(num_tags,)``.
        transitions (`~torch.nn.Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``; entry ``[i, j]`` scores moving from tag
            ``i`` to tag ``j``.

    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282–289.

    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.

        The parameters will be initialized randomly from a uniform distribution
        between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(self,
                emissions: torch.Tensor,
                tags: torch.LongTensor,
                mask: Optional[torch.ByteTensor] = None,
                reduction: str = 'mean') -> torch.Tensor:
        """Compute the conditional log likelihood of a sequence of tags given emission scores.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is
                ``False``, ``(batch_size, seq_length, num_tags)`` otherwise.
            tags (`~torch.LongTensor`): Sequence of tags tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise. Any dtype is accepted
                and converted to uint8. The first timestep must be on for
                every sequence.
            reduction: Specifies the reduction to apply to the output:
                ``none|sum|mean|token_mean``. ``none``: no reduction will be
                applied. ``sum``: the output will be summed over batches.
                ``mean``: the output will be averaged over batches.
                ``token_mean``: the output will be averaged over tokens.

        Returns:
            `~torch.Tensor`: The log likelihood. This will have size
            ``(batch_size,)`` if reduction is ``none``, ``()`` otherwise.

        Raises:
            ValueError: On an unknown ``reduction`` or inconsistent input
                shapes / first-timestep mask.
        """
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device)
        if mask.dtype != torch.uint8:
            mask = mask.byte()
        self._validate(emissions, tags=tags, mask=mask)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,)
        llh = numerator - denominator

        if reduction == 'none':
            return llh
        if reduction == 'sum':
            return llh.sum()
        if reduction == 'mean':
            return llh.mean()
        # token_mean: normalize by the number of valid (unmasked) tokens
        return llh.sum() / mask.float().sum()

    def decode(self,
               emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None,
               nbest: Optional[int] = None,
               pad_tag: Optional[int] = None) -> torch.Tensor:
        """Find the most likely tag sequence(s) using the Viterbi algorithm.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is
                ``False``, ``(batch_size, seq_length, num_tags)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise.
            nbest (`int`): Number of most probable paths for each sequence
                (default 1).
            pad_tag (`int`): Tag at padded positions. Often input varies in
                length and the length will be padded to the maximum length in
                the batch. Tags at the padded positions will be assigned with a
                padding tag, i.e. `pad_tag` (default 0).

        Returns:
            `~torch.LongTensor` of shape ``(nbest, batch_size, seq_length)``
            (regardless of ``batch_first``) holding the best tag sequences in
            descending order of score.

        Raises:
            ValueError: If ``nbest`` is smaller than 1, or on inconsistent input
                shapes / first-timestep mask.
        """
        if nbest is None:
            nbest = 1
        if nbest < 1:
            raise ValueError(f'nbest must be a positive integer, got {nbest}')
        if mask is None:
            mask = torch.ones(
                emissions.shape[:2],
                dtype=torch.uint8,
                device=emissions.device)
        if mask.dtype != torch.uint8:
            mask = mask.byte()
        self._validate(emissions, mask=mask)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        if nbest == 1:
            return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0)
        return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag)

    def _validate(self,
                  emissions: torch.Tensor,
                  tags: Optional[torch.LongTensor] = None,
                  mask: Optional[torch.ByteTensor] = None) -> None:
        """Raise ``ValueError`` if the inputs have inconsistent shapes.

        Also requires the first timestep of ``mask`` to be on for every
        sequence (checked along the time axis implied by ``batch_first``).
        """
        if emissions.dim() != 3:
            raise ValueError(
                f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}'
                )

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}'
                )
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(self, emissions: torch.Tensor, tags: torch.LongTensor,
                       mask: torch.ByteTensor) -> torch.Tensor:
        """Score of the given tag path (the log-likelihood numerator)."""
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        seq_length, batch_size = tags.shape
        mask = mask.float()
        # Create the batch index on the same device as the indexed tensors.
        batch_idx = torch.arange(batch_size, device=tags.device)

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, batch_idx, tags[0]]

        for i in range(1, seq_length):
            # Transition and emission scores for the next tag, only added when
            # the next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]
            score += emissions[i, batch_idx, tags[i]] * mask[i]

        # End transition score from the last valid tag of each sequence
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, batch_idx]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(self, emissions: torch.Tensor,
                            mask: torch.ByteTensor) -> torch.Tensor:
        """Log partition function (forward algorithm) for each sequence."""
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        seq_length = emissions.size(0)

        # Start transition score and first emission; for each batch, the j-th
        # column stores the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Entry [b, i, j] is the log-sum of scores of all tag sequences so
            # far that end with transitioning from tag i to tag j and emitting.
            # shape: (batch_size, num_tags, num_tags)
            next_score = (score.unsqueeze(2) + self.transitions
                          + emissions[i].unsqueeze(1))
            # Log-sum-exp over the previous tag
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)
            # Advance only where this timestep is valid; torch.where expects a
            # bool condition (uint8 conditions are deprecated).
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].bool().unsqueeze(1), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self,
                        emissions: torch.FloatTensor,
                        mask: torch.ByteTensor,
                        pad_tag: Optional[int] = None) -> torch.Tensor:
        """1-best Viterbi decode; returns a ``(batch_size, seq_length)`` tensor."""
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        if pad_tag is None:
            pad_tag = 0

        device = emissions.device
        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        # history_idx[i - 1][b, j]: previous tag of the best sequence that ends
        # in tag j at timestep i (back-pointers used by the traceback)
        history_idx = torch.zeros((seq_length, batch_size, self.num_tags),
                                  dtype=torch.long,
                                  device=device)
        # back-pointers stored at out-of-range (mask == 0) positions
        oor_idx = torch.zeros((batch_size, self.num_tags),
                              dtype=torch.long,
                              device=device)
        oor_tag = torch.full((seq_length, batch_size),
                             pad_tag,
                             dtype=torch.long,
                             device=device)

        for i in range(1, seq_length):
            # Entry [b, i, j]: best score so far ending with a transition from
            # tag i to tag j plus the emission of tag j
            # shape: (batch_size, num_tags, num_tags)
            next_score = (score.unsqueeze(2) + self.transitions
                          + emissions[i].unsqueeze(1))
            # Best previous tag for every current tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)
            # Advance only where this timestep is valid
            valid = mask[i].unsqueeze(-1).bool()
            score = torch.where(valid, next_score, score)
            indices = torch.where(valid, indices, oor_idx)
            history_idx[i - 1] = indices

        # End transition score
        # shape: (batch_size, num_tags)
        end_score = score + self.end_transitions
        _, end_tag = end_score.max(dim=1)

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1

        # Insert the best end tag at each sequence end (last position with
        # mask == 1) so the traceback starts from it.
        history_idx = history_idx.transpose(1, 0).contiguous()
        history_idx.scatter_(
            1,
            seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags),
            end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags))
        history_idx = history_idx.transpose(1, 0).contiguous()

        # Trace back the most probable path for each sequence
        best_tags_arr = torch.zeros((seq_length, batch_size),
                                    dtype=torch.long,
                                    device=device)
        best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
        for idx in range(seq_length - 1, -1, -1):
            best_tags = torch.gather(history_idx[idx], 1, best_tags)
            best_tags_arr[idx] = best_tags.view(batch_size)

        return torch.where(mask.bool(), best_tags_arr, oor_tag).transpose(0, 1)

    def _viterbi_decode_nbest(self,
                              emissions: torch.FloatTensor,
                              mask: torch.ByteTensor,
                              nbest: int,
                              pad_tag: Optional[int] = None) -> torch.Tensor:
        """n-best Viterbi decode; returns ``(nbest, batch_size, seq_length)``."""
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        if pad_tag is None:
            pad_tag = 0

        device = emissions.device
        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        first_score = self.start_transitions + emissions[0]
        # score[b, j, k]: k-th best score of tag sequences so far ending in tag
        # j. At the first timestep there is a single sequence per tag, so the
        # extra slots are -inf; otherwise a sequence that ends there would
        # yield duplicate paths.
        # shape: (batch_size, num_tags, nbest)
        score = first_score.new_full((batch_size, self.num_tags, nbest),
                                     float('-inf'))
        score[:, :, 0] = first_score

        # history_idx[i - 1][b, j, k]: flattened index (prev_tag * nbest +
        # prev_slot) of the candidate that the k-th best sequence ending in tag
        # j at timestep i extends. The default value k (prev_tag 0, slot k) is
        # used for rows that are never written and for out-of-range positions,
        # so the traceback carries each slot index unchanged until it reaches
        # that sequence's end tag.
        keep_slot = torch.arange(nbest, dtype=torch.long, device=device)
        history_idx = keep_slot.repeat(seq_length, batch_size, self.num_tags, 1)
        oor_idx = keep_slot.expand(batch_size, self.num_tags, nbest)
        oor_tag = torch.full((seq_length, batch_size, nbest),
                             pad_tag,
                             dtype=torch.long,
                             device=device)

        for i in range(1, seq_length):
            if i == 1:
                # Only slot 0 holds a real sequence: extend the first-step
                # scores directly.
                # shape: (batch_size, num_tags, num_tags)
                next_score = (first_score.unsqueeze(-1) + self.transitions
                              + emissions[i].unsqueeze(1))
            else:
                # shape: (batch_size, num_tags, nbest, num_tags)
                next_score = (score.unsqueeze(-1)
                              + self.transitions.unsqueeze(1)
                              + emissions[i].unsqueeze(1).unsqueeze(2))
            # Top `nbest` over all (previous tag, previous slot) candidates.
            # reshape rather than view: the operands may carry a permuted
            # memory layout inherited from the transposes below.
            # shape: (batch_size, nbest, num_tags)
            next_score, indices = next_score.reshape(
                batch_size, -1, self.num_tags).topk(
                    nbest, dim=1)
            if i == 1:
                # Candidates were indexed by previous tag only; convert to the
                # flattened (prev_tag * nbest + prev_slot) layout, slot 0.
                indices = indices * nbest
            # shape: (batch_size, num_tags, nbest)
            next_score = next_score.transpose(2, 1)
            indices = indices.transpose(2, 1)
            # Advance only where this timestep is valid
            valid = mask[i].bool().view(batch_size, 1, 1)
            score = torch.where(valid, next_score, score)
            indices = torch.where(valid, indices, oor_idx)
            history_idx[i - 1] = indices

        # End transition score
        # shape: (batch_size, num_tags, nbest)
        end_score = score + self.end_transitions.unsqueeze(-1)
        # Flattened (tag * nbest + slot) index of the nbest best endings,
        # in descending order of score
        # shape: (batch_size, nbest)
        _, end_tag = end_score.reshape(batch_size, -1).topk(nbest, dim=1)

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1

        # Insert the best end candidates at each sequence end (last position
        # with mask == 1) so the traceback starts from them.
        history_idx = history_idx.transpose(1, 0).contiguous()
        history_idx.scatter_(
            1,
            seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest),
            end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest))
        history_idx = history_idx.transpose(1, 0).contiguous()

        # Trace back the nbest most probable paths for each sequence
        best_tags_arr = torch.zeros((seq_length, batch_size, nbest),
                                    dtype=torch.long,
                                    device=device)
        best_tags = keep_slot.view(1, -1).expand(batch_size, -1)
        for idx in range(seq_length - 1, -1, -1):
            best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1,
                                     best_tags)
            best_tags_arr[idx] = best_tags.view(batch_size, -1) // nbest

        return torch.where(mask.bool().unsqueeze(-1), best_tags_arr,
                           oor_tag).permute(2, 1, 0)
|