Source code for mindnlp.modules.crf

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# pylint: disable=C0412

"""crf module"""

import mindspore
from mindspore import nn, ops, Tensor
from mindspore import Parameter
from mindspore.common.initializer import initializer, Uniform
from mindnlp.utils import less_min_pynative_first
if less_min_pynative_first:
    from mindnlp._legacy.functional import full, arange, where
else:
    from mindspore.ops import full, arange, where

def sequence_mask(seq_length, max_length, batch_first=False):
    """Build a validity mask from per-sequence lengths.

    Position ``p`` of a sequence is marked valid when ``p < length``.

    Args:
        seq_length: Tensor of sequence lengths. A trailing axis of size 1 is
            appended so it broadcasts against the position vector.
        max_length: Number of positions to generate (``0 .. max_length - 1``).
        batch_first: When true the comparison result is returned as is, with
            the position axis last; otherwise axes 0 and 1 are swapped so the
            position axis comes first.

    Returns:
        The element-wise ``position < length`` comparison result.
    """
    # Positions share the dtype of the lengths so the comparison needs no cast.
    positions = arange(0, max_length, 1, dtype=seq_length.dtype)
    lengths_column = seq_length.view(seq_length.shape + (1,))
    valid = positions < lengths_column
    return valid if batch_first else valid.swapaxes(0, 1)
class CRF(nn.Cell):
    """Conditional random field.

    This module implements a conditional random field [LMP01]_. Calling the
    cell with ``tags`` returns the log partition function minus the score of
    the given tag sequence (i.e. the negative log likelihood), reduced
    according to ``reduction``. Calling it without ``tags`` runs the forward
    pass of the `Viterbi algorithm`_ and returns ``(score, history)``; pass
    those to `~CRF.post_decode` to obtain the best tag sequences.

    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a
            minibatch.
        reduction: Specifies the reduction to apply to the output:
            ``none|sum|mean|token_mean``. ``none``: no reduction will be
            applied. ``sum``: the output will be summed over batches.
            ``mean``: the output will be averaged over batches.
            ``token_mean``: the output will be averaged over tokens.

    Raises:
        ValueError: If ``num_tags`` is not positive or ``reduction`` is not one
            of the supported values.

    Attributes:
        start_transitions (`~Parameter`): Start transition score tensor of size
            ``(num_tags,)``.
        end_transitions (`~Parameter`): End transition score tensor of size
            ``(num_tags,)``.
        transitions (`~Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``.

    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282–289.

    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:
        super().__init__()
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')

        self.num_tags = num_tags
        self.batch_first = batch_first
        self.reduction = reduction

        # Parameters are created in a fixed order: start, end, pairwise.
        vector_shape = (num_tags,)
        matrix_shape = (num_tags, num_tags)
        self.start_transitions = Parameter(
            initializer(Uniform(0.1), vector_shape), name='start_transitions')
        self.end_transitions = Parameter(
            initializer(Uniform(0.1), vector_shape), name='end_transitions')
        self.transitions = Parameter(
            initializer(Uniform(0.1), matrix_shape), name='transitions')

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def construct(self, emissions, tags=None, seq_length=None):
        """Dispatch to scoring (``tags`` given) or Viterbi decoding (no ``tags``).

        Args:
            emissions: Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is
                ``False``, ``(batch_size, seq_length, num_tags)`` otherwise.
            tags: Optional tag tensor laid out like the first two axes of
                ``emissions``.
            seq_length: Optional tensor with the valid length of each sequence;
                when omitted every sequence is treated as full length.

        Returns:
            The reduced loss from ``_construct`` when ``tags`` is given,
            otherwise the ``(score, history)`` pair from ``_decode``.
        """
        if tags is not None:
            return self._construct(emissions, tags, seq_length)
        return self._decode(emissions, seq_length)

    def _construct(self, emissions, tags=None, seq_length=None):
        """Compute ``log Z - score(tags)`` per sequence and apply the reduction."""
        # Internally everything is time-major: (seq_length, batch_size, ...).
        if self.batch_first:
            emissions = emissions.swapaxes(0, 1)
            tags = tags.swapaxes(0, 1)
        max_length, batch_size = tags.shape

        if seq_length is None:
            seq_length = full((batch_size,), max_length, dtype=mindspore.int64)
        mask = sequence_mask(seq_length, max_length)

        # Score of the provided tag path; shape: (batch_size,)
        path_score = self._compute_score(emissions, tags, seq_length - 1, mask)
        # Log partition function over all tag paths; shape: (batch_size,)
        log_partition = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,)
        llh = log_partition - path_score

        if self.reduction == 'none':
            return llh
        if self.reduction == 'sum':
            return llh.sum()
        if self.reduction == 'mean':
            return llh.mean()
        # 'token_mean': normalise by the number of valid timesteps.
        token_count = mask.astype(emissions.dtype).sum()
        return llh.sum() / token_count

    def _decode(self, emissions, seq_length=None):
        """Run the forward pass of the Viterbi algorithm.

        Args:
            emissions: Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is
                ``False``, ``(batch_size, seq_length, num_tags)`` otherwise.
            seq_length: Optional tensor with the valid length of each sequence;
                when omitted every sequence is treated as full length.

        Returns:
            The ``(score, history)`` pair produced by ``_viterbi_decode``. The
            best tag sequences themselves are recovered with ``post_decode``.
        """
        if self.batch_first:
            emissions = emissions.swapaxes(0, 1)
        max_length, batch_size = emissions.shape[:2]

        if seq_length is None:
            seq_length = full((batch_size,), max_length, dtype=mindspore.int64)
        mask = sequence_mask(seq_length, max_length)
        return self._viterbi_decode(emissions, mask)

    def _compute_score(self, emissions, tags, seq_ends, mask):
        """Score the given tag path for every sequence in the batch.

        Args:
            emissions: ``(seq_length, batch_size, num_tags)``
            tags: ``(seq_length, batch_size)``
            seq_ends: ``(batch_size,)`` index of the last valid timestep.
            mask: ``(seq_length, batch_size)``

        Returns:
            Tensor of shape ``(batch_size,)``.
        """
        seq_length, batch_size = tags.shape
        weights = mask.astype(emissions.dtype)
        batch_index = arange(batch_size)

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        # Rows are (timestep, batch index, tag); transposed they become the
        # per-sample coordinates gather_nd expects.
        first_coords = ops.stack(
            [ops.zeros((batch_size,), mindspore.int64), batch_index, tags[0]])
        score += ops.gather_nd(emissions, first_coords.T)

        # NOTE(review): the Tensor counter (rather than a Python ``range``)
        # looks deliberate, presumably so the loop stays a graph-level loop;
        # confirm before changing it.
        step = Tensor(1, mindspore.int64)
        while step < seq_length:
            step_weight = weights[step]

            # Transition score to next tag, only added if next timestep is
            # valid (mask == 1)
            # shape: (batch_size,)
            pair_coords = ops.stack([tags[step - 1], tags[step]])
            score += ops.gather_nd(self.transitions, pair_coords.T) * step_weight

            # Emission score for next tag, only added if next timestep is
            # valid (mask == 1)
            # shape: (batch_size,)
            emit_coords = ops.stack(
                [ops.tile(step, (batch_size,)), batch_index, tags[step]])
            score += ops.gather_nd(emissions, emit_coords.T) * step_weight

            step += 1

        # End transition score, looked up at the last valid tag of each sample
        # shape: (batch_size,)
        end_coords = ops.stack([seq_ends, batch_index])
        final_tags = ops.gather_nd(tags, end_coords.T)
        score += self.end_transitions[final_tags]
        return score

    def _compute_normalizer(self, emissions, mask):
        """Compute the log partition function with the forward algorithm.

        Args:
            emissions: ``(seq_length, batch_size, num_tags)``
            mask: ``(seq_length, batch_size)``

        Returns:
            Tensor of shape ``(batch_size,)``.
        """
        seq_length = emissions.shape[0]
        weights = mask.astype(emissions.dtype)

        # Start transition score and first emission; column j holds the score
        # of the first timestep having tag j
        # shape: (batch_size, num_tags)
        alpha = self.start_transitions + emissions[0]

        step = Tensor(1, mindspore.int32)
        while step < seq_length:
            # shape: (batch_size, num_tags, 1) -- one row per previous tag
            previous = alpha.expand_dims(2)
            # shape: (batch_size, 1, num_tags) -- one column per next tag
            emitted = emissions[step].expand_dims(1)

            # Entry (i, j) is the accumulated score of all paths so far that
            # end by moving from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            candidate = previous + self.transitions + emitted

            # Summing over the previous tag is a log-sum-exp in score space
            # shape: (batch_size, num_tags)
            candidate = ops.logsumexp(candidate, axis=1)

            # Keep the old score wherever this timestep is padding (mask == 0)
            # shape: (batch_size, num_tags)
            is_valid = weights[step].astype(mindspore.bool_).expand_dims(1)
            alpha = where(is_valid, candidate, alpha)

            step += 1

        # End transition score
        # shape: (batch_size, num_tags)
        alpha += self.end_transitions

        # Sum (log-sum-exp) over all possible final tags
        # shape: (batch_size,)
        return ops.logsumexp(alpha, axis=1)

    def _viterbi_decode(self, emissions, mask):
        """Run the Viterbi recursion and collect back-pointers.

        Args:
            emissions: ``(seq_length, batch_size, num_tags)``
            mask: ``(seq_length, batch_size)``

        Returns:
            ``(score, history)`` where ``score`` has shape
            ``(batch_size, num_tags)`` and holds the best path score ending in
            each tag (end transitions included), and ``history`` is a tuple
            with one ``(batch_size, num_tags)`` back-pointer tensor for every
            timestep from 1 onwards.
        """
        seq_length = mask.shape[0]

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        best = self.start_transitions + emissions[0]
        # Back-pointers are recorded for every timestep, padded or not; the
        # trace-back in ``post_decode`` truncates them per sample.
        history = ()

        for step in range(1, seq_length):
            # shape: (batch_size, num_tags, 1) -- one row per previous tag
            previous = best.expand_dims(2)
            # shape: (batch_size, 1, num_tags) -- one column per next tag
            emitted = emissions[step].expand_dims(1)

            # Entry (i, j) is the score of the best path so far that ends by
            # moving from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            candidate = previous + self.transitions + emitted

            # Best previous tag for every next tag, and the score it yields
            # shape: (batch_size, num_tags)
            pointers = candidate.argmax(axis=1)
            candidate = candidate.max(axis=1)

            # Keep the old score wherever this timestep is padding
            # shape: (batch_size, num_tags)
            best = where(mask[step].expand_dims(1), candidate, best)
            history += (pointers,)

        # End transition score
        # shape: (batch_size, num_tags)
        best += self.end_transitions
        return best, history

    @staticmethod
    def post_decode(score, history, seq_length):
        """Trace back the best tag sequence based on the score and history tensors.

        Args:
            score: ``(batch_size, num_tags)`` final Viterbi scores.
            history: Sequence of ``(batch_size, num_tags)`` back-pointer
                tensors, one per timestep from 1 onwards.
            seq_length: ``(batch_size,)`` valid length of each sequence.

        Returns:
            A list with one entry per sample; each entry is the list of tags of
            the best path in forward order. Elements are whatever ``argmax``
            and indexing return; no conversion to Python ``int`` is done.
        """
        # Only the back-pointers before the last valid timestep are followed.
        last_steps = seq_length - 1
        decoded = []
        for sample in range(seq_length.shape[0]):
            # The tag with the highest final score ends the best path.
            tag = score[sample].argmax(axis=0)
            backward_path = [tag]

            # Follow the back-pointers from the last valid step to the first.
            for pointers in reversed(history[:last_steps[sample]]):
                tag = pointers[sample][tag]
                backward_path.append(tag)

            # The path was collected end-to-start; flip it.
            decoded.append(backward_path[::-1])
        return decoded
# Public names exported by a star-import of this module.
__all__ = ["CRF", "sequence_mask"]