# Source code for mindnlp.abc.mixins.cell_mixin

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Cell mixin
"""

from typing import Tuple
from typing import Optional
import mindspore
import numpy as np
from mindspore import ops
from mindspore import Tensor
from mindnlp._legacy.functional import arange

class CellUtilMixin:
    """
    A few utilities to be used as a mixin.

    The helpers build broadcastable attention masks and per-layer head masks.
    `get_extended_attention_mask` reads `self.config.is_decoder`, so the class
    this is mixed into is expected to expose a `config` attribute.
    """

    @property
    def dtype(self) -> mindspore.dtype:
        """
        `mindspore.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).

        As implemented here this is always `mindspore.float32`.
        """
        return mindspore.float32

    @staticmethod
    def create_extended_attention_mask_for_decoder(input_shape, attention_mask):
        """
        Combine a causal (lower-triangular) mask with a padding mask.

        Args:
            input_shape: A pair `(batch_size, seq_length)`.
            attention_mask (`mindspore.Tensor`): Padding mask indexed as
                `[batch, key_position]`; its second dimension may be longer than
                `seq_length` (see the prefix handling below).

        Returns:
            `mindspore.Tensor`: `causal_mask[:, None, :, :] * attention_mask[:, None, None, :]`,
            in the dtype of `attention_mask`.
        """
        batch_size, seq_length = input_shape
        seq_ids = arange(seq_length)
        # causal_mask[b, i, j] is True when j <= i. The tiling goes through NumPy
        # instead of `ops.tile`; the original author tagged this "mindspore 2.0",
        # presumably a compatibility workaround -- TODO confirm.
        causal_mask = Tensor(np.tile(seq_ids[None, None, :].asnumpy(), (batch_size, seq_length, 1))) \
            <= seq_ids[None, :, None]
        # The causal mask and the attention mask must share a dtype before they
        # are multiplied together.
        causal_mask = causal_mask.astype(attention_mask.dtype)

        # In case past_key_values are used, the attention mask is longer than the
        # current input; prepend a block of ones so cached positions stay visible.
        if causal_mask.shape[1] < attention_mask.shape[1]:
            prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
            causal_mask = ops.concat(
                [
                    ops.ones((batch_size, seq_length, prefix_seq_len), causal_mask.dtype),
                    causal_mask,
                ],
                axis=-1,
            )

        extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
        return extended_attention_mask

    def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
        """
        Invert an attention mask (e.g., switches 0. and 1.).

        A 3-D mask gets one extra axis inserted at position 1, a 2-D mask gets two
        (positions 1 and 2); any other rank is used as given. The mask is then cast
        to `self.dtype` and mapped through `(1.0 - mask) * finfo(dtype).min`, so
        entries equal to 1 become 0 and entries equal to 0 become the most negative
        finite value of the dtype.

        Args:
            encoder_attention_mask (`mindspore.Tensor`): An attention mask.

        Returns:
            `mindspore.Tensor`: The inverted attention mask.
        """
        # These branches must be mutually exclusive: with two independent `if`s the
        # trailing `else` would overwrite the expanded 3-D mask with the raw input.
        if encoder_attention_mask.ndim == 3:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
        elif encoder_attention_mask.ndim == 2:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
        else:
            encoder_extended_attention_mask = encoder_attention_mask

        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
        # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
        #                                    encoder_extended_attention_mask.transpose(-1, -2))
        encoder_extended_attention_mask = encoder_extended_attention_mask.astype(dtype=self.dtype)  # fp16 compatibility
        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) \
            * Tensor(np.finfo(mindspore.dtype_to_nptype(self.dtype)).min)

        return encoder_extended_attention_mask

    def get_extended_attention_mask(
        self, attention_mask: Tensor, input_shape: Tuple[int], dtype=None
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`mindspore.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.
            dtype (*optional*):
                dtype of the returned mask. Defaults to `self.dtype` when `None`.

        Returns:
            `mindspore.Tensor`: The extended attention mask, cast to `dtype`.

        Raises:
            ValueError: If `attention_mask` is neither 2-D nor 3-D.
        """
        if dtype is None:
            dtype = self.dtype

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.ndim == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.ndim == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable
            #   to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                extended_attention_mask = CellUtilMixin.create_extended_attention_mask_for_decoder(
                    input_shape, attention_mask
                )
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and the dtype's smallest value for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.astype(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) \
            * Tensor(np.finfo(mindspore.dtype_to_nptype(dtype)).min)
        return extended_attention_mask

    def get_head_mask(
        self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
    ) -> Tensor:
        """
        Prepare the head mask if needed.

        Args:
            head_mask (`mindspore.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
                The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
            num_hidden_layers (`int`):
                The number of hidden layers in the model.
            is_attention_chunked: (`bool`, *optional*, defaults to `False`):
                Whether or not the attentions scores are computed by chunks or not.

        Returns:
            `mindspore.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]`,
            or a tuple holding `None` once per layer when `head_mask` is `None`.
        """
        if head_mask is None:
            # No mask supplied: one `None` placeholder per layer.
            return (None,) * num_hidden_layers

        head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
        if is_attention_chunked is True:
            # Chunked attention gets one extra trailing axis.
            head_mask = head_mask.expand_dims(-1)
        return head_mask

    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
        """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
        if head_mask.dim() == 1:
            # [num_heads] -> [1, 1, num_heads, 1, 1], then repeated for every layer.
            head_mask = head_mask.expand_dims(0).expand_dims(0).expand_dims(-1).expand_dims(-1)
            # `broadcast_to` takes the target shape as a single tuple; -1 keeps
            # the existing size of that dimension.
            head_mask = head_mask.broadcast_to((num_hidden_layers, -1, -1, -1, -1))
        elif head_mask.dim() == 2:
            # [num_hidden_layers, num_heads] -> [num_hidden_layers, 1, num_heads, 1, 1]
            # We can specify head_mask for each layer
            head_mask = head_mask.expand_dims(1).expand_dims(-1).expand_dims(-1)
        assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
        head_mask = head_mask.astype(dtype=self.dtype)  # switch to float if need + fp16 compatibility
        return head_mask