# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# pylint: disable=W0613
"""
Logits process
"""
import math
import inspect
import logging
from typing import List, Iterable, Union, Optional, Callable, Tuple
import mindspore
from mindspore import ops
import numpy as np
class LogitsProcessor:
    """Abstract base class for all logit processors that can be applied during generation."""

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Process `scores` for the sequences in `input_ids`. Concrete subclasses must override this method."""
        message = f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        raise NotImplementedError(message)
class LogitsWarper:
    """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Warp `scores` for the sequences in `input_ids`. Concrete subclasses must override this method."""
        message = f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        raise NotImplementedError(message)
class LogitsProcessorList(list):
    """
    A plain `list` of [`LogitsProcessor`] / [`LogitsWarper`] objects that can itself be called: calling the list
    runs every contained processor, in order, on the `scores` tensor and returns the final result.
    """

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> mindspore.Tensor:
        """
        Apply every processor in the list to `scores`.

        Args:
            input_ids: token ids generated so far, forwarded unchanged to every processor.
            scores: scores to transform; each processor receives the output of the previous one.
            **kwargs: extra arguments, forwarded only to processors whose `__call__` declares more than the two
                standard parameters.

        Returns:
            The scores returned by the last processor (or `scores` unchanged when the list is empty).

        Raises:
            ValueError: if a processor declares extra parameters that are not all present in `kwargs`.
        """
        for processor in self:
            param_names = list(inspect.signature(processor.__call__).parameters.keys())
            if len(param_names) <= 2:
                # Standard processor: only (input_ids, scores).
                scores = processor(input_ids, scores)
                continue
            # Everything after the first two parameters must be supplied by the caller.
            if any(name not in kwargs for name in param_names[2:]):
                raise ValueError(
                    f"Make sure that all the required parameters: {param_names} for "
                    f"{processor.__class__} are passed to the logits processor."
                )
            scores = processor(input_ids, scores, **kwargs)
        return scores
class HammingDiversityLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that enforces diverse beam search. It only has an effect when used with group beam search
    ([`PreTrainedModel.group_beam_search`]). See [Diverse Beam Search: Decoding Diverse Solutions from Neural
    Sequence Models](https://arxiv.org/pdf/1610.02424.pdf) for more details.

    Args:
        diversity_penalty (`float`):
            Amount subtracted from a token's score for every beam of an earlier group that picked the same token at
            the current step. Must be a float strictly larger than 0.
        num_beams (`int`):
            Number of beams used for group beam search. Must be at least 2.
        num_beam_groups (`int`):
            Number of groups `num_beams` is divided into. Must be at least 2 and not larger than `num_beams`.
    """

    def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):
        penalty_ok = isinstance(diversity_penalty, float) and diversity_penalty > 0.0
        if not penalty_ok:
            raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")
        self._diversity_penalty = diversity_penalty

        beams_ok = isinstance(num_beams, int) and num_beams >= 2
        if not beams_ok:
            raise ValueError("`num_beams` should be an integer strictly larger than 1.")
        self._num_beams = num_beams

        groups_ok = isinstance(num_beam_groups, int) and num_beam_groups >= 2
        if not groups_ok:
            raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
        if num_beam_groups > num_beams:
            raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.")
        # Number of beams that belong to each group.
        self._num_sub_beams = num_beams // num_beam_groups

    def __call__(
        self,
        input_ids: mindspore.Tensor,
        scores: mindspore.Tensor,
        current_tokens: mindspore.Tensor,
        beam_group_idx: int,
    ) -> mindspore.Tensor:
        """
        Penalise, for the group `beam_group_idx`, every token already chosen at this step by beams of earlier groups.

        `current_tokens` is indexed as a flat sequence of `batch_size * num_beams` token ids; `scores` is updated in
        place and returned.
        """
        beams_per_batch = self._num_beams
        num_batches = current_tokens.shape[0] // beams_per_batch
        first_beam = beam_group_idx * self._num_sub_beams
        last_beam = min(first_beam + self._num_sub_beams, beams_per_batch)
        beams_in_group = last_beam - first_beam
        vocab_size = scores.shape[-1]

        # The first group has no earlier groups to diverge from.
        if first_beam == 0:
            return scores

        for batch_idx in range(num_batches):
            # Tokens picked at this step by all earlier groups of this batch item.
            batch_offset = batch_idx * beams_per_batch
            earlier_tokens = current_tokens[batch_offset: batch_offset + first_beam]
            frequency = ops.bincount(earlier_tokens, minlength=vocab_size)
            row_start = batch_idx * beams_in_group
            row_stop = row_start + beams_in_group
            scores[row_start:row_stop] -= self._diversity_penalty * frequency
        return scores
class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing an exponential penalty on tokens that are not in the original input.

    Args:
        penalty (`float`):
            The hallucination penalty. 1.0 means no penalty. Internally the reciprocal `1 / penalty` is stored, so
            that the scores of tokens present in the encoder input are rescaled in the opposite direction of a
            regular repetition penalty.
        encoder_input_ids (`mindspore.Tensor`):
            The encoder input ids whose tokens get their scores rescaled. Used as the index of a gather/scatter
            along axis 1 of `scores`.

    Raises:
        ValueError: if `penalty` is not a strictly positive float.
    """

    def __init__(self, penalty: float, encoder_input_ids: mindspore.Tensor):
        if not isinstance(penalty, float) or (penalty <= 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
        self.penalty = 1 / penalty
        self.encoder_input_ids = encoder_input_ids

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Rescale the scores of every token id found in `encoder_input_ids` and return the updated scores."""
        score = ops.gather_elements(scores, 1, self.encoder_input_ids)
        # Negative scores are multiplied and positive ones divided by the stored factor, so both move in the
        # same direction.
        score = ops.where(score < 0, score * self.penalty, score / self.penalty)
        # `Tensor.scatter` is out-of-place: the result has to be rebound, otherwise the update would be lost.
        scores = scores.scatter(1, self.encoder_input_ids, score)
        return scores
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.

    Args:
        penalty (`float`):
            The parameter for repetition penalty. 1.0 means no penalty. See [this
            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.

    Raises:
        ValueError: if `penalty` is not a strictly positive float.
    """

    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or (penalty <= 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
        self.penalty = penalty

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Rescale the scores of every token id already present in `input_ids` and return the updated scores."""
        score = ops.gather_elements(scores, 1, input_ids)
        # If score < 0 the penalty has to be multiplied (not divided) to reduce the previous token probability.
        score = ops.where(score < 0, score * self.penalty, score / self.penalty)
        # `Tensor.scatter` is out-of-place: the result has to be rebound, otherwise the update would be lost.
        scores = scores.scatter(1, input_ids, score)
        return scores
def _get_ngrams(ngram_size: int, prev_input_ids: mindspore.Tensor, num_hypos: int):
generated_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_input_ids[idx].asnumpy().tolist()
generated_ngram = generated_ngrams[idx]
for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
return generated_ngrams
def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
# Before decoding the next token, prevent decoding of ngrams that have already appeared
start_idx = cur_len + 1 - ngram_size
ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
return banned_ngrams.get(ngram_idx, [])
def _calc_banned_ngram_tokens(
    ngram_size: int, prev_input_ids: mindspore.Tensor, num_hypos: int, cur_len: int
) -> List[Iterable[int]]:
    """
    Compute, per hypothesis, the tokens that would repeat an n-gram (copied from fairseq's no_repeat_ngram).

    Returns:
        A list of `num_hypos` lists of banned token ids; all lists are empty while fewer than `ngram_size`
        tokens would exist after the next step.
    """
    if cur_len + 1 < ngram_size:
        # Not enough tokens generated yet to form a full n-gram: nothing can be banned.
        return [[] for _ in range(num_hypos)]

    seen_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
    banned_per_hypo = []
    for hypo_idx in range(num_hypos):
        banned_per_hypo.append(
            _get_generated_ngrams(seen_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        )
    return banned_per_hypo
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that forbids any n-gram from being generated twice. See
    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).

    Args:
        ngram_size (`int`):
            All ngrams of size `ngram_size` can only occur once.

    Raises:
        ValueError: if `ngram_size` is not a strictly positive integer.
    """

    def __init__(self, ngram_size: int):
        size_ok = isinstance(ngram_size, int) and ngram_size > 0
        if not size_ok:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.ngram_size = ngram_size

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Set to `-inf`, row by row, the score of every token that would complete an already generated n-gram."""
        num_hypos = scores.shape[0]
        seq_len = input_ids.shape[-1]
        banned_per_hypo = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_hypos, seq_len)
        for row, banned in enumerate(banned_per_hypo):
            scores[row, banned] = -float("inf")
        return scores
class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that prevents the decoder from reproducing n-grams of the encoder input ids. See
    [ParlAI](https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/torch_generator_agent.py#L1350).

    Args:
        encoder_ngram_size (`int`):
            Size of the encoder n-grams that must not be repeated in the decoder ids.
        encoder_input_ids (`mindspore.Tensor`):
            The encoder input ids that should not be repeated within the decoder ids. A 1-D tensor is treated as
            a batch of one sequence.

    Raises:
        ValueError: if `encoder_ngram_size` is not a strictly positive integer.
    """

    def __init__(self, encoder_ngram_size: int, encoder_input_ids: mindspore.Tensor):
        size_ok = isinstance(encoder_ngram_size, int) and encoder_ngram_size > 0
        if not size_ok:
            raise ValueError(
                f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}"
            )
        self.ngram_size = encoder_ngram_size
        # Promote a single sequence to a batch of size one.
        if len(encoder_input_ids.shape) == 1:
            encoder_input_ids = encoder_input_ids.unsqueeze(0)
        self.batch_size = encoder_input_ids.shape[0]
        # The encoder n-grams never change, so they are collected once up front.
        self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size)

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Set to `-inf` the score of every token that would complete an n-gram of the matching encoder input."""
        # `scores` has one row per hypothesis: B x num_beams rows in total.
        num_hypos = scores.shape[0]
        num_beams = num_hypos // self.batch_size
        seq_len = input_ids.shape[-1]

        banned_per_hypo = []
        for hypo_idx in range(num_hypos):
            # Consecutive blocks of `num_beams` hypotheses share the same encoder input.
            encoder_ngrams = self.generated_ngrams[hypo_idx // num_beams]
            banned_per_hypo.append(
                _get_generated_ngrams(encoder_ngrams, input_ids[hypo_idx], self.ngram_size, seq_len)
            )

        for row, banned in enumerate(banned_per_hypo):
            scores[row, banned] = -float("inf")
        return scores
class NoBadWordsLogitsProcessor(LogitsProcessor):
    """
    [`LogitsProcessor`] that enforces that specified sequences will never be sampled.

    Args:
        bad_words_ids (`List[List[int]]`):
            List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
            that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
            add_special_tokens=False).input_ids`.
        eos_token_id (`Union[int, List[int]]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
            Single-token bad words equal to an EOS token are ignored. `None` means there is no EOS token.

    Raises:
        ValueError: if `bad_words_ids` is not a non-empty list of lists of non-negative integers, or if it
            contains an empty sequence.
    """

    def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
        if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0:
            raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
            raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
        if any(
            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
            for bad_word_ids in bad_words_ids
        ):
            raise ValueError(
                f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
            )

        if eos_token_id is None:
            eos_token_id = []
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]

        # Never ban a sequence that consists of exactly one EOS token.
        bad_words_ids = [
            bad_token_seq for bad_token_seq in bad_words_ids if all(bad_token_seq != [i] for i in eos_token_id)
        ]

        # Single-token bad words are always banned and handled with a cached static mask; longer sequences
        # only ban their last token once the preceding tokens have been generated.
        self.bad_words_id_length_1 = []
        self.bad_words_id_length_greater_than_1 = []
        for word in bad_words_ids:
            if len(word) == 1:
                self.bad_words_id_length_1.append(word[0])
            else:
                self.bad_words_id_length_greater_than_1.append(word)

        # Built lazily on the first call, because the vocabulary size is only known from `scores`.
        self.static_bad_words_mask: Optional[mindspore.Tensor] = None

        for banned_token_seq in self.bad_words_id_length_greater_than_1:
            if len(banned_token_seq) == 0:
                raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Return `scores` with every currently banned token set to `-inf`."""
        if self.static_bad_words_mask is None and len(self.bad_words_id_length_1) > 0:
            self.static_bad_words_mask = self._calc_static_bad_word_mask(scores)

        dynamic_banned_tokens = self._calc_banned_bad_words_ids(input_ids.tolist())
        scores = self._set_scores_to_inf_for_banned_tokens(scores, dynamic_banned_tokens)
        return scores

    def _calc_static_bad_word_mask(self, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Build a boolean mask with a leading axis of size 1 that is True at every single-token bad word id."""
        static_bad_words_mask = ops.zeros(scores.shape[1])
        static_bad_words_mask[self.bad_words_id_length_1] = 1
        return static_bad_words_mask.unsqueeze(0).bool()

    def _tokens_match(self, prev_tokens: List[int], tokens: List[int]) -> bool:
        """Return True when `prev_tokens` ends with `tokens`."""
        if len(tokens) == 0:
            # if bad word tokens is just one token always ban it
            return True
        if len(tokens) > len(prev_tokens):
            # if bad word tokens are longer then prev input_ids they can't be equal
            return False
        return prev_tokens[-len(tokens):] == tokens

    def _calc_banned_bad_words_ids(self, prev_input_ids: List[List[int]]) -> List[List[int]]:
        """For each sequence, list the last tokens of the multi-token bad words whose prefix was just generated."""
        banned_tokens = []
        for prev_input_ids_slice in prev_input_ids:
            banned_tokens_slice = []
            for banned_token_seq in self.bad_words_id_length_greater_than_1:
                if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]):
                    banned_tokens_slice.append(banned_token_seq[-1])
            banned_tokens.append(banned_tokens_slice)
        return banned_tokens

    def _set_scores_to_inf_for_banned_tokens(
        self, scores: mindspore.Tensor, banned_tokens: List[List[int]]
    ) -> mindspore.Tensor:
        """
        Return `scores` with the banned token positions set to `-inf`.

        Args:
            scores: logits distribution of shape (batch size, vocabulary size)
            banned_tokens: one list of banned token ids per batch row (length batch_size)
        """
        vocab_size = scores.shape[1]
        # Coordinates [row, token] of every dynamically banned position.
        banned_mask_list = []
        for idx, batch_banned_tokens in enumerate(banned_tokens):
            for token in batch_banned_tokens:
                # Eliminates invalid bad word IDs that are over the vocabulary size. Valid ids are
                # 0 .. vocab_size - 1, so `vocab_size` itself is already out of range.
                if token < vocab_size:
                    banned_mask_list.append([idx, token])
                else:
                    logging.error(
                        "An invalid bad word ID is defined: %d. This ID is not contained in the "
                        "vocabulary, and is therefore ignored.", token
                    )

        if not banned_mask_list and self.static_bad_words_mask is None:
            return scores

        if banned_mask_list:
            banned_mask = mindspore.Tensor(banned_mask_list)
            indices = ops.ones(len(banned_mask))
            # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]].
            # A conversion to dense tensor generates:
            # [ 0  1  1 ]
            # [ 0  0  0 ]
            # [ 1  0  0 ]
            banned_mask = (
                mindspore.COOTensor(banned_mask, indices, scores.shape)
                .to_dense()
                .bool()
            )
            if self.static_bad_words_mask is not None:
                banned_mask = ops.bitwise_or(banned_mask, self.static_bad_words_mask)
        else:
            banned_mask = self.static_bad_words_mask

        scores = scores.masked_fill(banned_mask, -float("inf"))
        return scores
class MinLengthLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing a minimum length by giving the EOS token(s) a score of `-inf` (probability 0)
    until that length is reached.

    Args:
        min_length (`int`):
            The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
        eos_token_id (`Union[int, List[int]]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.

    Raises:
        ValueError: if `min_length` is not a non-negative integer or `eos_token_id` contains anything other than
            non-negative integers.
    """

    def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
        length_ok = isinstance(min_length, int) and min_length >= 0
        if not length_ok:
            raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")

        # A single id is stored as a one-element list so that `__call__` can always iterate.
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        ids_are_ints = all(isinstance(i, int) for i in eos_token_id)
        if not ids_are_ints or any(i < 0 for i in eos_token_id):
            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Set the EOS columns of `scores` to `-inf` while `input_ids` is shorter than `min_length`."""
        if input_ids.shape[-1] >= self.min_length:
            return scores
        for eos_id in self.eos_token_id:
            scores[:, eos_id] = -float("inf")
        return scores
class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0.

    Args:
        prompt_length_to_skip (`int`):
            The input tokens length.
        min_new_tokens (`int`):
            The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
        eos_token_id (`Union[int, List[int]]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens,
            as the other length-related processors of this module allow.

    Raises:
        ValueError: if any argument is not a non-negative integer (or, for `eos_token_id`, a list of them).
    """

    def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]):
        for arg_name, arg_value in [
            ("prompt_length_to_skip", prompt_length_to_skip),
            ("min_new_tokens", min_new_tokens),
        ]:
            if not isinstance(arg_value, int) or arg_value < 0:
                raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")

        if isinstance(eos_token_id, (list, tuple)):
            if not all(isinstance(i, int) and i >= 0 for i in eos_token_id):
                raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
            # Always store a real list: a tuple would be interpreted as a multi-axis index.
            eos_token_id = list(eos_token_id)
        elif not isinstance(eos_token_id, int) or eos_token_id < 0:
            raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")

        self.prompt_length_to_skip = prompt_length_to_skip
        self.min_new_tokens = min_new_tokens
        # A scalar id is kept as given, so existing readers of this attribute see no change.
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Set the EOS column(s) of `scores` to `-inf` while fewer than `min_new_tokens` tokens were generated."""
        new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
        if new_tokens_length < self.min_new_tokens:
            eos_ids = self.eos_token_id if isinstance(self.eos_token_id, list) else [self.eos_token_id]
            for eos_id in eos_ids:
                scores[:, eos_id] = -float("inf")
        return scores
class PrefixConstrainedLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that restricts generation to an allowed set of tokens at every step, which is useful for
    prefix-conditioned constrained generation. See [Autoregressive Entity
    Retrieval](https://arxiv.org/abs/2010.00904) for more information.

    Args:
        prefix_allowed_tokens_fn: (`Callable[[int, mindspore.Tensor], List[int]]`):
            Function called as `prefix_allowed_tokens_fn(batch_id, input_ids)` for every beam at every step. It
            has to return the list of token ids allowed for the next generation step, conditioned on the
            previously generated tokens `input_ids` of that beam and on the batch ID `batch_id`.
        num_beams (`int`):
            Number of beams per batch item; used to regroup the flat hypotheses into batches.
    """

    def __init__(self, prefix_allowed_tokens_fn: Callable[[int, mindspore.Tensor], List[int]], num_beams: int):
        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self._num_beams = num_beams

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Return `scores` plus a mask that is 0 on allowed tokens and `-inf` everywhere else."""
        # Start with everything forbidden, then open up the allowed tokens row by row.
        mask = ops.full_like(scores, -math.inf)
        beams = self._num_beams
        grouped_ids = input_ids.view(-1, beams, input_ids.shape[-1])
        for batch_id, beam_sent in enumerate(grouped_ids):
            row_base = batch_id * beams
            for beam_id, sent in enumerate(beam_sent):
                allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
                mask[row_base + beam_id, allowed_tokens] = 0
        return scores + mask
class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that enforces the specified token as the first generated token.

    Args:
        bos_token_id (`int`):
            The id of the token to force as the first generated token.
    """

    def __init__(self, bos_token_id: int):
        self.bos_token_id = bos_token_id

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """When exactly one token exists so far, make `bos_token_id` the only token with a finite score (0)."""
        cur_len = input_ids.shape[-1]
        if cur_len == 1:
            # Blank out the whole row with one slice assignment, then re-enable the forced token. This gives
            # the same result as indexing every other column explicitly, without building a vocabulary-sized
            # Python list on every call.
            scores[:, :] = -float("inf")
            scores[:, self.bos_token_id] = 0
        return scores
class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.

    Args:
        max_length (`int`):
            The maximum length of the sequence to be generated.
        eos_token_id (`Union[int, List[int]]`):
            The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
            list to set multiple *end-of-sequence* tokens.
    """

    def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
        self.max_length = max_length
        # A single id is stored as a one-element list so that `__call__` can always iterate.
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """On the last step before `max_length`, make the EOS token(s) the only ones with a finite score (0)."""
        cur_len = input_ids.shape[-1]
        if cur_len == self.max_length - 1:
            # Blank out the whole row with one slice assignment, then re-enable the EOS tokens. This gives the
            # same result as indexing every non-EOS column explicitly, without a vocabulary-sized Python list
            # and a membership scan per index.
            scores[:, :] = -float("inf")
            for i in self.eos_token_id:
                scores[:, i] = 0
        return scores
class InfNanRemoveLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that removes all `nan` and `inf` values to avoid the generation method to fail. Note that using
    the logits processor should only be used if necessary since it can slow down the generation method.
    """

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Return `scores` with `nan` replaced by 0.0 and `+inf` replaced by the largest finite value of its dtype."""
        # Build the masks with tensor ops instead of NumPy so that the tensor is never indexed with a NumPy array.
        # set all nan values to 0.0
        scores = scores.masked_fill(ops.isnan(scores), 0.0)
        # set all inf values to max possible value
        max_value = float(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).max)
        scores = scores.masked_fill(scores == float("inf"), max_value)
        return scores
class ExponentialDecayLengthPenalty(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that exponentially increases the score of the eos_token_id after regulation_start has been
    reached.

    Args:
        exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
            This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty
            starts and `decay_factor` represents the factor of exponential decay
        eos_token_id (`Union[int, List[int]]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
        input_ids_seq_length (`int`):
            The length of the input sequence.
    """

    def __init__(
        self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]],
        input_ids_seq_length: int
    ):
        # The start index is relative to the generated part, so shift it by the prompt length.
        self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
        self.regulation_factor = exponential_decay_length_penalty[1]
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Boost the EOS score(s) by a factor that grows exponentially with the tokens past `regulation_start`."""
        cur_len = input_ids.shape[-1]
        if cur_len > self.regulation_start:
            growth = pow(self.regulation_factor, cur_len - self.regulation_start)
            for i in self.eos_token_id:
                eos_scores = scores[:, i]
                # Multiplying a negative logit by a factor > 1 would lower it. Adding the penalty computed on
                # the absolute value raises the score regardless of its sign, and equals
                # `eos_scores * growth` for non-negative scores.
                scores[:, i] = eos_scores + ops.abs(eos_scores) * (growth - 1)
        return scores
class SuppressTokensLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that suppresses a fixed list of tokens at every step by setting their log probs to `-inf`,
    so that they are never sampled.

    Args:
        suppress_tokens:
            Iterable of token ids to suppress; it is copied into a list.
    """

    def __init__(self, suppress_tokens):
        self.suppress_tokens = list(suppress_tokens)

    def __call__(self, input_ids, scores):
        """Set the columns of `scores` listed in `suppress_tokens` to `-inf` in place and return `scores`."""
        suppressed_columns = self.suppress_tokens
        scores[:, suppressed_columns] = -float("inf")
        return scores
class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
    r"""
    [`SuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens at the step where `input_ids` holds exactly
    `begin_index` tokens. This ensures that the tokens defined by `begin_suppress_tokens` are not sampled at the
    beginning of the generation.

    Args:
        begin_suppress_tokens:
            Iterable of token ids to suppress; it is copied into a list.
        begin_index:
            Length of `input_ids` (second axis) at which the suppression is applied.
    """

    def __init__(self, begin_suppress_tokens, begin_index):
        self.begin_suppress_tokens = list(begin_suppress_tokens)
        self.begin_index = begin_index

    def __call__(self, input_ids, scores):
        """Set the suppressed columns of `scores` to `-inf` only when `input_ids.shape[1] == begin_index`."""
        if input_ids.shape[1] != self.begin_index:
            return scores
        scores[:, self.begin_suppress_tokens] = -float("inf")
        return scores
class ForceTokensLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that takes a list of `[generation index, token id]` pairs and forces the given token at the
    given index: at such a step every other token gets a log prob of `-inf` and the forced token gets 0.

    Args:
        force_token_map (`List[List[int]]`):
            Pairs mapping a generation index to the token id that must be produced there.
    """

    def __init__(self, force_token_map: List[List[int]]):
        self.force_token_map = dict(force_token_map)

    def __call__(self, input_ids, scores):
        """If a token is registered for the current length of `input_ids`, make it the only possible token."""
        forced_token = self.force_token_map.get(input_ids.shape[-1], None)
        if forced_token is None:
            return scores
        scores[:, :] = -float("inf")
        scores[:, forced_token] = 0
        return scores
class LogitNormalization(LogitsProcessor, LogitsWarper):
    r"""
    [`LogitsWarper`] and [`LogitsProcessor`] that normalizes the scores with a log-softmax. Normalizing matters
    during beam search after the logits processors or warpers have run: the search algorithm of this library only
    normalizes before them, yet it still assumes normalized scores when comparing hypotheses.
    """

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Return the log-softmax of `scores` over the last axis."""
        return ops.log_softmax(scores, axis=-1)
class TemperatureLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] for temperature (exponential scaling of the output probability distribution).

    Args:
        temperature (:obj:`float`):
            The value used to modulate the logits distribution; the scores are divided by it.

    Raises:
        ValueError: if `temperature` is not a strictly positive float.
    """

    def __init__(self, temperature: float):
        temperature_ok = isinstance(temperature, float) and temperature > 0
        if not temperature_ok:
            raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
        self.temperature = temperature

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Return `scores` divided by the temperature."""
        return scores / self.temperature
class TopPLogitsWarper(LogitsWarper):
    """
    [`LogitsWarper`] that performs top-p (nucleus) filtering, i.e. keeps only the most probable tokens whose
    cumulative probability reaches `top_p`.

    Args:
        top_p (`float`):
            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
            higher are kept for generation.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.

    Raises:
        ValueError: if `top_p` is below 0 or above 1, or `min_tokens_to_keep` is not an integer >= 1.
    """

    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
        keep_ok = isinstance(min_tokens_to_keep, int) and min_tokens_to_keep >= 1
        if not keep_ok:
            raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")

        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Return `scores` with every token outside the top-p nucleus set to `filter_value`."""
        # Sort ascending so that the least probable tokens come first in the cumulative sum.
        ascending_logits, ascending_indices = ops.sort(scores, descending=False)
        probabilities = ops.softmax(ascending_logits, axis=-1)
        cumulative_probs = probabilities.cumsum(axis=-1)

        # Drop the low-probability tail whose cumulative mass stays within 1 - top_p.
        remove_sorted = cumulative_probs <= (1 - self.top_p)
        # Keep at least min_tokens_to_keep (the last positions hold the most probable tokens).
        remove_sorted[..., -self.min_tokens_to_keep:] = 0

        # Scatter the mask from sorted order back to the original vocabulary order.
        remove_mask = remove_sorted.scatter(1, ascending_indices, remove_sorted)
        return scores.masked_fill(remove_mask, self.filter_value)
class TopKLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.

    Args:
        top_k (`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.

    Raises:
        ValueError: if `top_k` is not a strictly positive integer.
    """

    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        # Never keep fewer than `min_tokens_to_keep` tokens.
        self.top_k = max(top_k, min_tokens_to_keep)
        self.filter_value = filter_value

    def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
        """Return `scores` with every token scoring below the k-th best one set to `filter_value`."""
        # Safety check: k cannot exceed the vocabulary size. `shape[-1]` is used because `Tensor.size` is an
        # integer property in MindSpore, not a method taking an axis.
        top_k = min(self.top_k, scores.shape[-1])
        # Remove all tokens with a probability less than the last token of the top-k
        kth_best = mindspore.Tensor.topk(scores, top_k)[0][..., -1, None]
        indices_to_remove = scores < kth_best
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores