# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Cell mixin
"""
from typing import Tuple
from typing import Optional
import mindspore
import numpy as np
from mindspore import ops
from mindspore import Tensor
from mindnlp._legacy.functional import arange
class CellUtilMixin:
    """
    A few utilities to be used as a mixin.

    The helpers build additive attention masks and head masks. Host classes must
    provide any attributes the helpers read; for example,
    `get_extended_attention_mask` reads `self.config.is_decoder`.
    """

    @property
    def dtype(self) -> mindspore.dtype:
        """
        `mindspore.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).

        This implementation always reports `mindspore.float32`.
        """
        return mindspore.float32

    @staticmethod
    def create_extended_attention_mask_for_decoder(input_shape, attention_mask):
        """
        Combine a causal (lower-triangular) mask with a 2D padding mask.

        Args:
            input_shape (`Tuple[int, int]`):
                `(batch_size, seq_length)` of the decoder input.
            attention_mask (`mindspore.Tensor`):
                2D padding mask of shape `(batch_size, key_length)`. When `key_length`
                exceeds `seq_length` (e.g. cached past key/values are in use), the extra
                leading positions are treated as always visible.

        Returns:
            `mindspore.Tensor` of shape `(batch_size, 1, seq_length, key_length)` with the
            dtype of `attention_mask`: the product of the causal and padding masks.
        """
        batch_size, seq_length = input_shape
        seq_ids = arange(seq_length)
        # Query position i may attend to key positions j <= i. The tiling is done in numpy;
        # the MindSpore 2.0 equivalent would be:
        #   ops.tile(seq_ids[None, None, :].astype(mindspore.int32),
        #            (batch_size, seq_length, 1)) <= seq_ids[None, :, None]
        tiled_ids = Tensor(np.tile(seq_ids[None, None, :].asnumpy(), (batch_size, seq_length, 1)))
        causal_mask = tiled_ids <= seq_ids[None, :, None]
        # Causal and padding masks must share a dtype before they are multiplied.
        causal_mask = causal_mask.astype(attention_mask.dtype)

        # If the padding mask covers more positions than the current query length
        # (past key/values in use), prepend ones so those cached positions stay visible.
        if causal_mask.shape[1] < attention_mask.shape[1]:
            prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
            causal_mask = ops.concat(
                [
                    ops.ones((batch_size, seq_length, prefix_seq_len), causal_mask.dtype),
                    causal_mask,
                ],
                axis=-1,
            )

        extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
        return extended_attention_mask

    def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
        """
        Invert an attention mask (e.g., switches 0. and 1.).

        The mask is made broadcastable over attention heads, cast to `self.dtype`, and
        turned into an additive mask. Positions with value 1 map to 0. Positions with
        value 0 map to the dtype's most negative finite value.

        Args:
            encoder_attention_mask (`mindspore.Tensor`):
                An attention mask. A 3D mask gets a head axis inserted at position 1.
                A 2D mask gets head and query axes inserted at positions 1 and 2.
                Any other rank is used unchanged.

        Returns:
            `mindspore.Tensor`: The inverted (additive) attention mask.
        """
        # The rank checks must be mutually exclusive; otherwise a 3D mask would fall
        # through to the pass-through branch and lose its head axis.
        if encoder_attention_mask.ndim == 3:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
        elif encoder_attention_mask.ndim == 2:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
        else:
            encoder_extended_attention_mask = encoder_attention_mask
        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
        # /transformer/transformer_layers.py#L270
        # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
        # encoder_extended_attention_mask.transpose(-1, -2))
        encoder_extended_attention_mask = encoder_extended_attention_mask.astype(dtype=self.dtype)  # fp16 compatibility
        min_value = Tensor(np.finfo(mindspore.dtype_to_nptype(self.dtype)).min)
        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * min_value
        return encoder_extended_attention_mask

    def get_extended_attention_mask(
        self, attention_mask: Tensor, input_shape: Tuple[int, ...], dtype=None
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`mindspore.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
                Must be 2D `(batch_size, seq_length)` or 3D
                `(batch_size, from_seq_length, to_seq_length)`.
            input_shape (`Tuple[int]`):
                The shape of the input to the model. It is used to build the causal mask
                when `self.config.is_decoder` is true.
            dtype (`mindspore.dtype`, *optional*):
                Dtype of the returned mask. Defaults to `self.dtype`.

        Returns:
            `mindspore.Tensor`: The additive extended attention mask, cast to `dtype`.
            It is 0 for positions to attend to and the dtype's most negative finite
            value for masked positions.

        Raises:
            ValueError: If `attention_mask` is neither 2D nor 3D.
        """
        if dtype is None:
            dtype = self.dtype
        # A 3D self-attention mask [batch_size, from_seq_length, to_seq_length] only
        # needs a head axis to become broadcastable over all heads.
        if attention_mask.ndim == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.ndim == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable
            #   to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                extended_attention_mask = CellUtilMixin.create_extended_attention_mask_for_decoder(
                    input_shape, attention_mask
                )
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
            )
        # attention_mask is 1.0 for positions to attend to and 0.0 for masked positions.
        # The expression below therefore yields 0.0 for attended positions and the dtype's
        # smallest value for masked ones. Adding it to the raw scores before softmax
        # effectively removes the masked positions.
        extended_attention_mask = extended_attention_mask.astype(dtype=dtype)  # fp16 compatibility
        min_value = Tensor(np.finfo(mindspore.dtype_to_nptype(dtype)).min)
        extended_attention_mask = (1.0 - extended_attention_mask) * min_value
        return extended_attention_mask

    def get_head_mask(
        self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
    ) -> Tensor:
        """
        Prepare the head mask if needed.

        Args:
            head_mask (`mindspore.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
                The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
            num_hidden_layers (`int`):
                The number of hidden layers in the model.
            is_attention_chunked (`bool`, *optional*, defaults to `False`):
                Whether or not the attentions scores are computed by chunks or not. If set,
                a trailing axis is appended to the 5D mask.

        Returns:
            `mindspore.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]`,
            or a tuple with `None` for each layer when `head_mask` is `None`.
        """
        if head_mask is None:
            return (None,) * num_hidden_layers

        head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
        if is_attention_chunked:
            head_mask = head_mask.expand_dims(-1)
        return head_mask

    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
        """
        Reshape a 1D or 2D head mask to 5D.

        The target shape is `[num_hidden_layers x batch x num_heads x seq_length x seq_length]`.
        Singleton axes are used for the batch and sequence dimensions. A 1D
        `[num_heads]` mask is broadcast across all layers. A 2D
        `[num_hidden_layers, num_heads]` mask gives each layer its own mask.
        The result is cast to `self.dtype`.
        """
        if head_mask.ndim == 1:
            head_mask = head_mask.expand_dims(0).expand_dims(0).expand_dims(-1).expand_dims(-1)
            # `Tensor.broadcast_to` takes a single shape tuple; -1 keeps the existing size.
            head_mask = head_mask.broadcast_to((num_hidden_layers, -1, -1, -1, -1))
        elif head_mask.ndim == 2:
            # We can specify head_mask for each layer
            head_mask = head_mask.expand_dims(1).expand_dims(-1).expand_dims(-1)
        assert head_mask.ndim == 5, f"head_mask.dim != 5, instead {head_mask.ndim}"
        head_mask = head_mask.astype(dtype=self.dtype)  # switch to float if need + fp16 compatibility
        return head_mask