# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
XLMTokenizer
"""
# pylint: disable=C0412
# pylint: disable=C0103
# pylint: disable=C0415
# pylint: disable=too-many-instance-attributes
# pylint: disable=C0103
# pylint: disable=R0913
# pylint: disable=R0914
# pylint: disable=R0912
# pylint: disable=R0915
# pylint: disable=W1505
# pylint: disable=R1705
# pylint: disable=R1723
# pylint: disable=R1702
# pylint: disable=R1724
# pylint: disable=W0102
import json
import logging
import re
from enum import Enum
import sys
import unicodedata
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union
from collections.abc import Mapping, Sized
from tokenizers import AddedToken
from mindnlp.abc import PreTrainedTokenizer
from mindnlp.models.xlm.xlm_config import XLM_SUPPORT_LIST
from mindnlp.configs import MINDNLP_TOKENIZER_CONFIG_URL_BASE
# One entry per model name in XLM_SUPPORT_LIST; each value is produced by filling
# MINDNLP_TOKENIZER_CONFIG_URL_BASE with the 'xlm' family name and the model name.
PRETRAINED_VOCAB_MAP = {
    model: MINDNLP_TOKENIZER_CONFIG_URL_BASE.format('xlm', model) for model in XLM_SUPPORT_LIST
}
# NOTE(review): the commented-out sentinel below is not used anywhere in the visible
# code -- candidate for removal.
# class MissingType:
# "MissingType"
# pass
# MISSING = MissingType()

# Type aliases used in the signatures of the encoding methods.
TextInput = str
PreTokenizedInput = List[str]
EncodedInput = List[int]
TextInputPair = Tuple[str, str]
PreTokenizedInputPair = Tuple[List[str], List[str]]
EncodedInputPair = Tuple[List[int], List[int]]

# Used as `model_max_length` when the caller supplies neither `model_max_length` nor `max_len`.
VERY_LARGE_INTEGER = int(1e30)

# Exposed on XLMTokenizer as `max_model_input_sizes`.
# NOTE(review): presumably the number of positions each checkpoint supports -- confirm.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "xlm-clm-ende-1024": 512,
    "xlm-mlm-en-2048": 512,
    "xlm-mlm-xnli15-1024": 512,
    "xlm-mlm-100-1280": 512,
    "xlm-mlm-enfr-1024": 512,
    "xlm-mlm-tlm-xnli15-1024": 512,
    "xlm-clm-enfr-1024": 512,
    "xlm-mlm-17-1280": 512,
    "xlm-mlm-enro-1024": 512,
}
class ExplicitEnum(str, Enum):
    """
    String-valued enum with a more explicit error message for missing values.

    Looking up a value that matches no member raises a `ValueError` that names the
    enum and lists every accepted value.
    """

    @classmethod
    def _missing_(cls, value):
        # Called by `Enum` when `cls(value)` matches no member.
        valid_values = [member.value for member in cls]
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {valid_values}"
        )
class TensorType(ExplicitEnum):
    """
    Possible values for the `return_tensors` argument of the tokenizer's `__call__`. Useful for
    tab-completion in an IDE.
    """

    # Each member's value is the short string code for one tensor framework.
    PYTORCH = "pt"
    TENSORFLOW = "tf"
    NUMPY = "np"
    JAX = "jax"
class TruncationStrategy(ExplicitEnum):
    """
    Possible values for the `truncation` argument of the tokenizer's encoding methods. Useful for
    tab-completion in an IDE.
    """

    # Remove tokens from the first sequence only.
    ONLY_FIRST = "only_first"
    # Remove tokens from the second sequence of a pair only.
    ONLY_SECOND = "only_second"
    # Remove tokens one at a time from whichever sequence is currently longer.
    LONGEST_FIRST = "longest_first"
    # Leave the sequences untouched.
    DO_NOT_TRUNCATE = "do_not_truncate"
class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the `padding` argument of the tokenizer's encoding methods. Useful for
    tab-completion in an IDE.
    """

    # Selected when `padding=True` is requested.
    LONGEST = "longest"
    # Pad up to `max_length` (or the model maximum when `max_length` is not given).
    MAX_LENGTH = "max_length"
    # Selected when `padding=False` (the default).
    DO_NOT_PAD = "do_not_pad"
def get_pairs(word):
    """
    Return the set of adjacent symbol pairs in a word.

    Args:
        word: The word as a sequence of symbols -- a tuple of variable-length strings, or a plain
            string.

    Returns:
        set: Every `(symbol, following_symbol)` pair found in `word`. Words with fewer than two
        symbols (including an empty word) yield an empty set.
    """
    # Pairing the word with itself shifted by one yields each neighbouring pair exactly once and
    # naturally yields nothing for empty or single-symbol words.
    return set(zip(word, word[1:]))
def lowercase_and_remove_accent(text):
    """
    Lowercase a list of tokens and strip their accents, following
    https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py

    Args:
        text: Iterable of token strings.

    Returns:
        list: The processed tokens, obtained by splitting the processed text on single spaces.
    """
    # Work on one space-joined string so normalization runs once for all tokens.
    decomposed = unicodedata.normalize("NFD", " ".join(text).lower())
    # After NFD decomposition accents are separate combining marks (category "Mn"); drop them.
    without_marks = "".join(
        symbol for symbol in decomposed if unicodedata.category(symbol) != "Mn"
    )
    return without_marks.lower().split(" ")
def romanian_preprocessing(text):
    """
    Apply Sennrich's WMT16 Romanian preprocessing, used by model `xlm-mlm-enro-1024`.

    Combines the effect of two scripts in one pass:
    https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py
    (cedilla forms become comma-below forms) and
    https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py
    (diacritics are replaced by the bare ASCII letter).
    """
    diacritics_to_ascii = str.maketrans(
        {
            # s/t with cedilla are normalised to comma-below and then stripped, so they end up
            # as the same plain letters as the comma-below forms.
            "\u015e": "S",
            "\u015f": "s",
            "\u0162": "T",
            "\u0163": "t",
            "\u0218": "S",  # s-comma
            "\u0219": "s",
            "\u021a": "T",  # t-comma
            "\u021b": "t",
            "\u0102": "A",
            "\u0103": "a",
            "\u00C2": "A",
            "\u00E2": "a",
            "\u00CE": "I",
            "\u00EE": "i",
        }
    )
    return text.translate(diacritics_to_ascii)
# Single-character substitutions of the Moses script. Code points are written as escapes so the
# table survives any re-encoding or width-folding of this source file.
_UNICODE_PUNCT_TABLE = str.maketrans(
    {
        "\uff0c": ",",  # fullwidth comma
        "\u3001": ",",  # ideographic comma
        "\u201d": '"',  # right double quotation mark
        "\u201c": '"',  # left double quotation mark
        "\u2236": ":",  # ratio
        "\uff1a": ":",  # fullwidth colon
        "\uff1f": "?",  # fullwidth question mark
        "\u300a": '"',  # left double angle bracket
        "\u300b": '"',  # right double angle bracket
        "\uff09": ")",  # fullwidth right parenthesis
        "\uff01": "!",  # fullwidth exclamation mark
        "\uff08": "(",  # fullwidth left parenthesis
        "\uff1b": ";",  # fullwidth semicolon
        "\u300d": '"',  # right corner bracket
        "\u300c": '"',  # left corner bracket
        "\uff5e": "~",  # fullwidth tilde
        "\u2019": "'",  # right single quotation mark
        "\u2026": "...",  # horizontal ellipsis
        "\u2501": "-",  # box drawings heavy horizontal
        "\u3008": "<",  # left angle bracket
        "\u3009": ">",  # right angle bracket
        "\u3010": "[",  # left black lenticular bracket
        "\u3011": "]",  # right black lenticular bracket
        "\uff05": "%",  # fullwidth percent sign
        # fullwidth digits 0-9
        **{chr(0xFF10 + digit): str(digit) for digit in range(10)},
    }
)


def replace_unicode_punct(text):
    """
    Replace CJK / full-width punctuation and digits with their ASCII counterparts.

    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl

    Args:
        text (str): Text to normalise.

    Returns:
        str: `text` with the listed characters replaced; any other character is left unchanged.
    """
    # The ideographic and the fullwidth full stop also swallow the whitespace that follows them
    # and always leave exactly one space behind.
    text = re.sub(r"\u3002\s*", ". ", text)
    text = re.sub(r"\uff0e\s*", ". ", text)
    # The remaining rules are independent one-character substitutions, so one table pass is enough.
    return text.translate(_UNICODE_PUNCT_TABLE)
def remove_non_printing_char(text):
    """
    Drop every character whose Unicode category starts with "C".

    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl

    Args:
        text (str): Text to filter.

    Returns:
        str: `text` without the characters in any "C*" Unicode category.
    """
    return "".join(
        symbol for symbol in text if not unicodedata.category(symbol).startswith("C")
    )
class Trie:
    """
    Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass
    Loose reference https://en.wikipedia.org/wiki/Trie
    """

    def __init__(self):
        # Nested dicts keyed by character; the special key "" marks the end of a stored word.
        self.data = {}

    def add(self, word: str):
        """
        Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.
        The special key `""` is used to represent termination.
        This function is idempotent, adding twice the same word will leave the trie unchanged

        Example:
        ```python
        >>> trie = Trie()
        >>> trie.add("Hello 友達")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}
        >>> trie.add("Hello")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
        ```
        """
        if not word:
            # Prevent empty string
            return
        ref = self.data
        for char in word:
            # Descend into the child node for `char`, creating it on first use.
            ref = ref.setdefault(char, {})
        ref[""] = 1

    def split(self, text: str) -> List[str]:
        """
        Split `text` on the words stored in the trie, in a single pass over the text.

        Stored words come back as their own items and the text between them is kept as-is. When
        several stored words start at the same position the longest one wins.

        Example:
        ```python
        >>> trie = Trie()
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS] This is a extra_id_100"]
        >>> trie.add("[CLS]")
        >>> trie.add("extra_id_1")
        >>> trie.add("extra_id_100")
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS]", " This is a ", "extra_id_100"]
        ```

        Args:
            text (str): The text to split.

        Returns:
            List[str]: The non-empty pieces of `text`, in order.
        """
        # Partial matches in progress: start index in `text` -> current trie node.
        # Insertion order equals ascending start index, which the lookahead below relies on.
        states = {}
        # Cut positions; `cut_text` turns them into substrings.
        offsets = [0]
        # Characters before this index were already consumed by a lookahead match.
        skip = 0
        for current, current_char in enumerate(text):
            if skip and current < skip:
                continue
            to_remove = set()
            reset = False
            # In this case, we already have partial matches (But unfinished)
            for start, trie_pointer in states.items():
                if "" in trie_pointer:
                    # A stored word ends here. Look ahead for the longest match, also considering
                    # partial matches that started earlier.
                    for lookstart, looktrie_pointer in states.items():
                        if lookstart > start:
                            # This partial match is later, we can stop looking
                            break
                        elif lookstart < start:
                            # Earlier partial match: its pointer was already advanced for the
                            # current character, so the indices are one further.
                            lookahead_index = current + 1
                            end = current + 1
                        else:
                            # Same match: its pointer was not advanced yet.
                            lookahead_index = current
                            end = current
                        next_char = text[lookahead_index] if lookahead_index < len(text) else None
                        if "" in looktrie_pointer:
                            start = lookstart
                            end = lookahead_index
                            skip = lookahead_index
                        while next_char in looktrie_pointer:
                            looktrie_pointer = looktrie_pointer[next_char]
                            lookahead_index += 1
                            if "" in looktrie_pointer:
                                start = lookstart
                                end = lookahead_index
                                skip = lookahead_index
                            if lookahead_index == len(text):
                                # End of string
                                break
                            next_char = text[lookahead_index]
                        # End lookahead
                    # Storing and resetting
                    offsets.append(start)
                    offsets.append(end)
                    reset = True
                    break
                elif current_char in trie_pointer:
                    # The partial match continues: advance its pointer (the key already exists, so
                    # the dict size is unchanged while it is being iterated).
                    trie_pointer = trie_pointer[current_char]
                    states[start] = trie_pointer
                else:
                    # The partial match failed; remove it once the iteration is over.
                    to_remove.add(start)
            if reset:
                states = {}
            else:
                for start in to_remove:
                    del states[start]
            # If this character is a starting character within the trie
            # start keeping track of this partial match.
            if current >= skip and current_char in self.data:
                states[current] = self.data[current_char]
        # We have a cut at the end with states.
        for start, trie_pointer in states.items():
            if "" in trie_pointer:
                end = len(text)
                offsets.append(start)
                offsets.append(end)
                break
        return self.cut_text(text, offsets)

    def cut_text(self, text, offsets):
        """
        Cut `text` at the given offsets.

        Args:
            text (str): The text to cut.
            offsets (List[int]): Cut positions; the end of the text is added implicitly. The list is
                not modified.

        Returns:
            List[str]: The non-empty substrings between consecutive cut positions.
        """
        tokens = []
        start = 0
        # Iterate over a new list so the caller's `offsets` is left untouched.
        for end in [*offsets, len(text)]:
            if start > end:
                logging.error(
                    "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it"
                    " anyway."
                )
                continue
            elif start == end:
                # This might happen if there's a match at index 0
                # we're also preventing zero-width cuts in case of two
                # consecutive matches
                continue
            tokens.append(text[start:end])
            start = end
        return tokens
class XLMTokenizer(PreTrainedTokenizer):
    """
    Tokenizer used for XLM text process.

    Args:
        vocab_file (str): Path to the JSON vocabulary file (token -> id).
        merges_file (str): Path to the BPE merges file, one merge per line.
        unk_token (str): Unknown token. Default: "<unk>".
        bos_token (str): Beginning-of-sequence token. Default: "<s>".
        sep_token (str): Separator token. Default: "</s>".
        pad_token (str): Padding token. Default: "<pad>".
        cls_token (str): Classifier token. Default: "</s>".
        mask_token (str): Mask token. Default: "<special1>".
        additional_special_tokens (List[str]): Extra special tokens.
            Default: "<special0>" ... "<special9>".
        lang2id (dict): Optional mapping from language to id. Default: None.
        id2lang (dict): Optional mapping from id to language. Default: None.
        do_lowercase_and_remove_accent (bool): Stored on the instance and exposed through
            `do_lower_case`. Default: True.

    Examples:
        >>> from mindnlp.transforms import XLMTokenizer
        >>> text = "Believing that faith can triumph over everything is in itself the greatest belief"
        >>> tokenizer = XLMTokenizer.from_pretrained('xlm-clm-ende-1024')
        >>> tokens = tokenizer.encode(text)
    """

    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_vocab_map = PRETRAINED_VOCAB_MAP
    # Class-level defaults; `__init__` lets callers override the three below through keyword arguments.
    model_input_names: List[str] = ["input_ids", "token_type_ids", "attention_mask"]
    padding_side: str = "right"
    truncation_side: str = "right"
    slow_tokenizer_class = None
def __init__(
    self,
    vocab_file,
    merges_file,
    unk_token="<unk>",
    bos_token="<s>",
    sep_token="</s>",
    pad_token="<pad>",
    cls_token="</s>",
    mask_token="<special1>",
    additional_special_tokens=None,
    lang2id=None,
    id2lang=None,
    do_lowercase_and_remove_accent=True,
    **kwargs,
):
    """
    Load the vocabulary and BPE merges and set up the tokenizer state.

    Args:
        vocab_file: Path to a UTF-8 JSON file mapping tokens to ids.
        merges_file: Path to a UTF-8 text file with one BPE merge per line.
        unk_token, bos_token, sep_token, pad_token, cls_token, mask_token: Special tokens,
            forwarded to the base class.
        additional_special_tokens: Extra special tokens. `None` (the default) selects
            "<special0>" ... "<special9>".
        lang2id: Optional mapping from language to id.
        id2lang: Optional mapping from id to language.
        do_lowercase_and_remove_accent: Stored on the instance and exposed through
            `do_lower_case`.
        **kwargs: Forwarded to the base class. `padding_side`, `truncation_side`,
            `model_max_length` (or `max_len`), `model_input_names` and
            `clean_up_tokenization_spaces` are additionally read here.

    Raises:
        ValueError: If `truncation_side` is neither "right" nor "left".
        ImportError: If `sacremoses` is not installed.
    """
    if additional_special_tokens is None:
        # Built per call so instances never share (and cannot mutate) one default list.
        additional_special_tokens = [f"<special{index}>" for index in range(10)]
    super().__init__(
        unk_token=unk_token,
        bos_token=bos_token,
        sep_token=sep_token,
        pad_token=pad_token,
        cls_token=cls_token,
        mask_token=mask_token,
        additional_special_tokens=additional_special_tokens,
        lang2id=lang2id,
        id2lang=id2lang,
        do_lowercase_and_remove_accent=do_lowercase_and_remove_accent,
        **kwargs,
    )
    self.added_tokens_encoder: Dict[str, int] = {}
    self.added_tokens_decoder: Dict[int, str] = {}
    self.unique_no_split_tokens: List[str] = []
    self.tokens_trie = Trie()
    self.padding_side = kwargs.pop("padding_side", self.padding_side)
    # `max_len` is accepted as an alternative spelling of `model_max_length`.
    model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None))
    self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER
    self.truncation_side = kwargs.pop("truncation_side", self.truncation_side)
    if self.truncation_side not in ["right", "left"]:
        raise ValueError(
            "Truncation side should be selected between 'right' and 'left', "
            f"current value: {self.truncation_side}"
        )
    self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
    self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", True)
    # Remembers which one-time warnings were already emitted.
    self.deprecation_warnings = {}
    self._in_target_context_manager = False
    try:
        import sacremoses
    except ImportError as err:
        raise ImportError(
            "You need to install sacremoses to use XLMTokenizer. "
            "See https://pypi.org/project/sacremoses/ for installation."
        ) from err
    self.sm = sacremoses
    # cache of sm.MosesPunctNormalizer instance
    self.cache_moses_punct_normalizer = {}
    # cache of sm.MosesTokenizer instance
    self.cache_moses_tokenizer = {}
    self.lang_with_custom_tokenizer = {"zh", "th", "ja"}
    # True for current supported model (v1.2.0), False for XLM-17 & 100
    self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
    self.lang2id = lang2id
    self.id2lang = id2lang
    if lang2id is not None and id2lang is not None:
        assert len(lang2id) == len(id2lang)
    self.ja_word_tokenizer = None
    self.zh_word_tokenizer = None
    with open(vocab_file, encoding="utf-8") as vocab_handle:
        self.encoder = json.load(vocab_handle)
    # Reverse lookup: id -> token.
    self.decoder = {v: k for k, v in self.encoder.items()}
    with open(merges_file, encoding="utf-8") as merges_handle:
        # The element after the final newline is dropped.
        merges = merges_handle.read().split("\n")[:-1]
    # Only the first two whitespace-separated fields of each line form the merge pair.
    merges = [tuple(merge.split()[:2]) for merge in merges]
    # Rank of a merge = its line position.
    self.bpe_ranks = dict(zip(merges, range(len(merges))))
    self.cache = {}
def _convert_id_to_token(self, index: int) -> str:
    """
    Convert a vocabulary id into its token string.

    Args:
        index (int): Id to look up in `self.decoder`.

    Returns:
        str: The token stored for `index`, or the unknown token when the id is not in the
        vocabulary.
    """
    # NOTE(review): assumes the base class exposes the configured `unk_token` as
    # `self.unk_token` (as it does for `self.pad_token`) -- confirm.
    return self.decoder.get(index, self.unk_token)
@property
def do_lower_case(self):
    """Read-only view of the `do_lowercase_and_remove_accent` setting."""
    lowercase_setting = self.do_lowercase_and_remove_accent
    return lowercase_setting
def moses_punct_norm(self, text, lang):
    """
    Normalize the punctuation of `text` with the `MosesPunctNormalizer` for `lang`.

    One normalizer per language is created on first use and kept in
    `self.cache_moses_punct_normalizer`.
    """
    normalizers = self.cache_moses_punct_normalizer
    if lang in normalizers:
        normalizer = normalizers[lang]
    else:
        normalizer = self.sm.MosesPunctNormalizer(lang=lang)
        normalizers[lang] = normalizer
    return normalizer.normalize(text)
def moses_tokenize(self, text, lang):
    """
    Tokenize `text` with the `MosesTokenizer` for `lang` (no escaping, result not joined).

    One tokenizer per language is created on first use and kept in `self.cache_moses_tokenizer`.
    """
    tokenizers = self.cache_moses_tokenizer
    if lang in tokenizers:
        word_tokenizer = tokenizers[lang]
    else:
        word_tokenizer = self.sm.MosesTokenizer(lang=lang)
        tokenizers[lang] = word_tokenizer
    return word_tokenizer.tokenize(text, return_str=False, escape=False)
def moses_pipeline(self, text, lang):
    """
    Run the three-step Moses clean-up on `text`: replace unicode punctuation, normalize the
    punctuation for `lang`, then remove non-printing characters.
    """
    return remove_non_printing_char(
        self.moses_punct_norm(replace_unicode_punct(text), lang)
    )
def __call__(
    self,
    text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
    text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
    text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
    text_pair_target: Optional[
        Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
    ] = None,
    add_special_tokens: bool = True,
    padding = False,
    truncation = None,
    max_length: Optional[int] = None,
    stride: int = 0,
    is_split_into_words: bool = False,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs,
):
    """
    Encode `text` (optionally paired with `text_pair`) and/or `text_target` (optionally paired
    with `text_pair_target`).

    Every other argument is passed unchanged to `_call_one` for both the source and the target.

    Returns:
        The encoding of `text` when only `text` is given, the encoding of `text_target` when only
        `text_target` is given, and otherwise the encoding of `text` with the target's
        `"input_ids"` stored under `"labels"`.

    Raises:
        ValueError: If both `text` and `text_target` are `None`.
    """
    if text is None and text_target is None:
        raise ValueError("You need to specify either `text` or `text_target`.")
    # One option set shared by the source and the target call.
    shared_options = {
        "add_special_tokens": add_special_tokens,
        "padding": padding,
        "truncation": truncation,
        "max_length": max_length,
        "stride": stride,
        "is_split_into_words": is_split_into_words,
        "pad_to_multiple_of": pad_to_multiple_of,
        "return_tensors": return_tensors,
        "return_token_type_ids": return_token_type_ids,
        "return_attention_mask": return_attention_mask,
        "return_overflowing_tokens": return_overflowing_tokens,
        "return_special_tokens_mask": return_special_tokens_mask,
        "return_offsets_mapping": return_offsets_mapping,
        "return_length": return_length,
        "verbose": verbose,
        **kwargs,
    }
    source_encodings = None
    if text is not None:
        source_encodings = self._call_one(text=text, text_pair=text_pair, **shared_options)
    if text_target is None:
        return source_encodings
    target_encodings = self._call_one(text=text_target, text_pair=text_pair_target, **shared_options)
    if text is None:
        return target_encodings
    source_encodings["labels"] = target_encodings["input_ids"]
    return source_encodings
def _call_one(
    self,
    text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
    text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
    add_special_tokens: bool = True,
    padding = False,
    truncation = None,
    max_length: Optional[int] = None,
    stride: int = 0,
    is_split_into_words: bool = False,
    pad_to_multiple_of: Optional[int] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs,
):
    """
    Validate the input and encode a single example (or a single pair) through `encode_plus`.

    Args:
        text: The example to encode.
        text_pair: Optional second sequence.
        is_split_into_words: When true, a list of strings is one pre-tokenized example and only a
            list of lists counts as a batch; otherwise any list or tuple counts as a batch.
        **kwargs: Forwarded unchanged to `encode_plus`, as are the remaining named arguments
            (except `is_split_into_words`).

    Returns:
        Whatever `encode_plus` returns.

    Raises:
        ValueError: If `text` or `text_pair` is not a string, a list of strings or a list of
            lists of strings.
        NotImplementedError: If the input is a batch; batches are not supported here.
    """

    # Input type checking for clearer error
    def _is_valid_text_input(candidate):
        if isinstance(candidate, str):
            # Strings are fine
            return True
        if not isinstance(candidate, (list, tuple)):
            return False
        # Lists are fine as long as they are...
        if len(candidate) == 0:
            # ... empty
            return True
        first = candidate[0]
        if isinstance(first, str):
            # ... list of strings
            return True
        if isinstance(first, (list, tuple)):
            # ... list with an empty list or with a list of strings
            return len(first) == 0 or isinstance(first[0], str)
        return False

    if not _is_valid_text_input(text):
        raise ValueError(
            "text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
            "or `List[List[str]]` (batch of pretokenized examples)."
        )
    if text_pair is not None and not _is_valid_text_input(text_pair):
        raise ValueError(
            "text_pair input must of type `str` (single example), `List[str]` (batch or single pretokenized "
            "example) or `List[List[str]]` (batch of pretokenized examples)."
        )
    if is_split_into_words:
        is_batched = bool(isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)))
    else:
        is_batched = isinstance(text, (list, tuple))
    if is_batched:
        # Fail loudly instead of silently returning None: there is no batch encoding path.
        raise NotImplementedError(
            "Batched inputs are not supported by this tokenizer; encode one example (or one pair) at a time."
        )
    return self.encode_plus(
        text=text,
        text_pair=text_pair,
        add_special_tokens=add_special_tokens,
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        stride=stride,
        pad_to_multiple_of=pad_to_multiple_of,
        return_token_type_ids=return_token_type_ids,
        return_attention_mask=return_attention_mask,
        return_overflowing_tokens=return_overflowing_tokens,
        return_special_tokens_mask=return_special_tokens_mask,
        return_length=return_length,
        verbose=verbose,
        **kwargs,
    )
def encode_plus(
    self,
    text: Union[TextInput, PreTokenizedInput, EncodedInput],
    text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
    add_special_tokens: bool = True,
    padding = False,
    truncation = None,
    max_length: Optional[int] = None,
    stride: int = 0,
    pad_to_multiple_of: Optional[int] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs,
):
    """
    Resolve the padding/truncation options, then encode one sequence (or one pair) through
    `_encode_plus`.

    The legacy `truncation_strategy` and `pad_to_max_length` keyword arguments are handled by
    `_get_padding_truncation_strategies`; whatever keyword arguments it returns are forwarded to
    `_encode_plus`.

    Returns:
        Whatever `_encode_plus` returns.
    """
    resolved = self._get_padding_truncation_strategies(
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        verbose=verbose,
        **kwargs,
    )
    padding_strategy, truncation_strategy, max_length, kwargs = resolved
    encode_options = {
        "text": text,
        "text_pair": text_pair,
        "add_special_tokens": add_special_tokens,
        "padding_strategy": padding_strategy,
        "truncation_strategy": truncation_strategy,
        "max_length": max_length,
        "stride": stride,
        "pad_to_multiple_of": pad_to_multiple_of,
        "return_token_type_ids": return_token_type_ids,
        "return_attention_mask": return_attention_mask,
        "return_overflowing_tokens": return_overflowing_tokens,
        "return_special_tokens_mask": return_special_tokens_mask,
        "return_length": return_length,
        "verbose": verbose,
    }
    return self._encode_plus(**encode_options, **kwargs)
def _encode_plus(
    self,
    text: Union[TextInput, PreTokenizedInput, EncodedInput],
    text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
    add_special_tokens: bool = True,
    padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
    truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
    max_length: Optional[int] = None,
    stride: int = 0,
    pad_to_multiple_of: Optional[int] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs
):
    """
    Tokenize `text` (and `text_pair` when given), convert the tokens to ids and hand the ids to
    `prepare_for_model`.

    `**kwargs` are passed to `self.tokenize_` only. The padding and truncation strategies are
    passed on by their string value.

    Returns:
        Whatever `prepare_for_model` returns.
    """

    def _to_ids(raw_text):
        return self.convert_tokens_to_ids(self.tokenize_(raw_text, **kwargs))

    ids = _to_ids(text)
    pair_ids = None if text_pair is None else _to_ids(text_pair)
    model_options = {
        "pair_ids": pair_ids,
        "add_special_tokens": add_special_tokens,
        "padding": padding_strategy.value,
        "truncation": truncation_strategy.value,
        "max_length": max_length,
        "stride": stride,
        "pad_to_multiple_of": pad_to_multiple_of,
        "return_attention_mask": return_attention_mask,
        "return_token_type_ids": return_token_type_ids,
        "return_overflowing_tokens": return_overflowing_tokens,
        "return_special_tokens_mask": return_special_tokens_mask,
        "return_length": return_length,
        "verbose": verbose,
    }
    return self.prepare_for_model(ids, **model_options)
def _get_padding_truncation_strategies(
    self, padding=False, truncation=None, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
):
    """
    Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
    and pad_to_max_length) and behaviors.

    Args:
        padding: `False`, `True`, a `PaddingStrategy` or one of its string values.
        truncation: `None`, `False`, `True`, a `TruncationStrategy` or one of its string values.
        max_length: Requested maximum length, or `None` to fall back to `self.model_max_length`
            when padding to max length or truncating.
        pad_to_multiple_of: Optional multiple the padded length must respect.
        verbose: Whether warnings are emitted.
        **kwargs: The legacy `truncation_strategy` and `pad_to_max_length` entries are consumed
            here; everything else is returned untouched.

    Returns:
        tuple: `(padding_strategy, truncation_strategy, max_length, kwargs)`.

    Raises:
        ValueError: If padding is requested without a usable padding token, or if truncation and
            padding are both active and `max_length` is not a multiple of `pad_to_multiple_of`.
    """
    old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
    old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)
    # Backward compatibility for previous behavior, maybe we should deprecate it:
    # If you only set max_length, it activates truncation for max_length
    if max_length is not None and padding is False and truncation is None:
        if verbose:
            if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
                logging.warning(
                    "Truncation was not explicitly activated but `max_length` is provided a specific value, please"
                    " use `truncation=True` to explicitly truncate examples to max length. Defaulting to"
                    " 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the"
                    " tokenizer you can select this strategy more precisely by providing a specific strategy to"
                    " `truncation`."
                )
            self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
        truncation = "longest_first"
    # Get padding strategy
    if padding is False and old_pad_to_max_length:
        if verbose:
            logging.warning(
                "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
                "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
                "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
                "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
                "maximal input size of the model (e.g. 512 for Bert)."
            )
        if max_length is None:
            padding_strategy = PaddingStrategy.LONGEST
        else:
            padding_strategy = PaddingStrategy.MAX_LENGTH
    elif padding is not False:
        if padding is True:
            if verbose:
                if max_length is not None and (
                    truncation is None or truncation is False or truncation == "do_not_truncate"
                ):
                    logging.warning(
                        "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. "
                        "To pad to max length, use `padding='max_length'`."
                    )
                if old_pad_to_max_length is not False:
                    logging.warning("Though `pad_to_max_length` = `True`, it is ignored because `padding`=`True`.")
            padding_strategy = PaddingStrategy.LONGEST  # Default to pad to the longest sequence in the batch
        elif isinstance(padding, PaddingStrategy):
            padding_strategy = padding
        else:
            # A string value; an unknown value raises ValueError from the enum lookup.
            padding_strategy = PaddingStrategy(padding)
    else:
        padding_strategy = PaddingStrategy.DO_NOT_PAD
    # Get truncation strategy
    if truncation is None and old_truncation_strategy != "do_not_truncate":
        if verbose:
            logging.warning(
                "The `truncation_strategy` argument is deprecated and will be removed in a future version, use"
                " `truncation=True` to truncate examples to a max length. You can give a specific length with"
                " `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the maximal input"
                " size of the model (e.g. 512 for Bert). If you have pairs of inputs, you can give a specific"
                " truncation strategy selected among `truncation='only_first'` (will only truncate the first"
                " sentence in the pairs) `truncation='only_second'` (will only truncate the second sentence in the"
                " pairs) or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence"
                " in the pairs)."
            )
        truncation_strategy = TruncationStrategy(old_truncation_strategy)
    elif truncation is not False and truncation is not None:
        if truncation is True:
            truncation_strategy = (
                TruncationStrategy.LONGEST_FIRST
            )  # Default to truncate the longest sequences in pairs of inputs
        elif isinstance(truncation, TruncationStrategy):
            truncation_strategy = truncation
        else:
            truncation_strategy = TruncationStrategy(truncation)
    else:
        truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
    # Set max length if needed
    if max_length is None:
        # A `model_max_length` above this threshold is treated as "no predefined maximum".
        # Bound before both branches: the truncation branch needs it even when the padding
        # branch is not taken.
        large_integer = int(1e20)
        if padding_strategy == PaddingStrategy.MAX_LENGTH:
            if self.model_max_length > large_integer:
                if verbose:
                    if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
                        logging.warning(
                            "Asking to pad to max_length but no maximum length is provided and the model has no"
                            " predefined maximum length. Default to no padding."
                        )
                    self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
                padding_strategy = PaddingStrategy.DO_NOT_PAD
            else:
                max_length = self.model_max_length
        if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
            if self.model_max_length > large_integer:
                if verbose:
                    if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
                        logging.warning(
                            "Asking to truncate to max_length but no maximum length is provided and the model has"
                            " no predefined maximum length. Default to no truncation."
                        )
                    self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
                truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
            else:
                max_length = self.model_max_length
    # Test if we have a padding token
    if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
        not self.pad_token or self.pad_token_id is None or self.pad_token_id < 0
    ):
        raise ValueError(
            "Asking to pad but the tokenizer does not have a padding token. "
            "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
            "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
        )
    # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
    if (
        truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
        and padding_strategy != PaddingStrategy.DO_NOT_PAD
        and pad_to_multiple_of is not None
        and max_length is not None
        and (max_length % pad_to_multiple_of != 0)
    ):
        raise ValueError(
            "Truncation and padding are both activated but "
            f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
        )
    return padding_strategy, truncation_strategy, max_length, kwargs
def truncate_sequences(
    self,
    ids: List[int],
    pair_ids: Optional[List[int]] = None,
    num_tokens_to_remove: int = 0,
    truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
    stride: int = 0,
) -> Tuple[List[int], List[int], List[int]]:
    """
    Remove `num_tokens_to_remove` tokens from a sequence or a pair of sequences.

    Tokens are removed from the end when `self.truncation_side` is "right" and from the start when
    it is "left".

    Args:
        ids: Token ids of the first sequence.
        pair_ids: Token ids of the optional second sequence.
        num_tokens_to_remove: Number of tokens to remove; nothing happens when it is <= 0.
        truncation_strategy: A `TruncationStrategy` or its string value.
        stride: Number of kept tokens that are repeated at the edge of the overflowing window, in
            addition to the removed tokens.

    Returns:
        tuple: `(ids, pair_ids, overflowing_tokens)`. `overflowing_tokens` is always empty when a
        pair is truncated with the longest-first strategy.

    Raises:
        ValueError: If `self.truncation_side` is neither "left" nor "right".
    """
    if num_tokens_to_remove <= 0:
        return ids, pair_ids, []
    if not isinstance(truncation_strategy, TruncationStrategy):
        truncation_strategy = TruncationStrategy(truncation_strategy)
    overflowing_tokens = []
    if truncation_strategy == TruncationStrategy.ONLY_FIRST or (
        truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None
    ):
        if len(ids) > num_tokens_to_remove:
            window_len = min(len(ids), stride + num_tokens_to_remove)
            if self.truncation_side == "left":
                overflowing_tokens = ids[:window_len]
                ids = ids[num_tokens_to_remove:]
            elif self.truncation_side == "right":
                overflowing_tokens = ids[-window_len:]
                ids = ids[:-num_tokens_to_remove]
            else:
                raise ValueError(f"invalid truncation strategy: {self.truncation_side}, use 'left' or 'right'.")
        else:
            # The sequence is too short to absorb the removal: report it and leave it untouched.
            error_msg = (
                f"We need to remove {num_tokens_to_remove} to truncate the input "
                f"but the first sequence has a length {len(ids)}. "
            )
            if truncation_strategy == TruncationStrategy.ONLY_FIRST:
                error_msg = (
                    error_msg + "Please select another truncation strategy than "
                    f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
                )
            logging.error(error_msg)
    elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
        logging.warning(
            "Be aware, overflowing tokens are not returned for the setting you have chosen,"
            f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
            "truncation strategy. So the returned list will always be empty even if some "
            "tokens have been removed."
        )
        # Remove one token at a time from whichever sequence is currently longer.
        for _ in range(num_tokens_to_remove):
            if pair_ids is None or len(ids) > len(pair_ids):
                if self.truncation_side == "right":
                    ids = ids[:-1]
                elif self.truncation_side == "left":
                    ids = ids[1:]
                else:
                    raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
            else:
                if self.truncation_side == "right":
                    pair_ids = pair_ids[:-1]
                elif self.truncation_side == "left":
                    pair_ids = pair_ids[1:]
                else:
                    raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
    elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
        if len(pair_ids) > num_tokens_to_remove:
            window_len = min(len(pair_ids), stride + num_tokens_to_remove)
            if self.truncation_side == "right":
                overflowing_tokens = pair_ids[-window_len:]
                pair_ids = pair_ids[:-num_tokens_to_remove]
            elif self.truncation_side == "left":
                overflowing_tokens = pair_ids[:window_len]
                pair_ids = pair_ids[num_tokens_to_remove:]
            else:
                raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
        else:
            logging.error(
                f"We need to remove {num_tokens_to_remove} to truncate the input "
                f"but the second sequence has a length {len(pair_ids)}. "
                f"Please select another truncation strategy than {truncation_strategy}, "
                "for instance 'longest_first' or 'only_first'."
            )
    return (ids, pair_ids, overflowing_tokens)
def create_token_type_ids_from_sequences(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Build the token type ids for a sequence or a pair of sequences.

    The first segment (classifier token, first sequence, separator) is marked with 0 and the
    second segment (second sequence, separator) with 1:

    ```
    0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
    | first sequence    | second sequence |
    ```

    Args:
        token_ids_0 (`List[int]`):
            Ids of the first sequence.
        token_ids_1 (`List[int]`, *optional*):
            Ids of the second sequence, for sequence pairs.

    Returns:
        `List[int]`: The token type ids; only the 0s when `token_ids_1` is `None`.
    """
    sep_ids = [self.sep_token_id]
    cls_ids = [self.cls_token_id]
    type_ids = [0] * len(cls_ids + token_ids_0 + sep_ids)
    if token_ids_1 is not None:
        type_ids += [1] * len(token_ids_1 + sep_ids)
    return type_ids
def get_special_tokens_mask(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
    """
    Return a mask with 1 for special tokens and 0 for sequence tokens.

    Args:
        token_ids_0 (`List[int]`):
            Ids of the first sequence.
        token_ids_1 (`List[int]`, *optional*):
            Ids of the second sequence, for sequence pairs.
        already_has_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether `token_ids_0` already contains the special tokens.

    Returns:
        `List[int]`: One entry per token. With `already_has_special_tokens=True` the mask marks the
        ids found in `self.all_special_ids`; otherwise it describes the layout used by
        `create_token_type_ids_from_sequences`: one special token before the first sequence and
        one after each sequence.

    Raises:
        ValueError: If `token_ids_1` is given together with `already_has_special_tokens=True`.
    """
    if already_has_special_tokens:
        if token_ids_1 is not None:
            raise ValueError(
                "You should not supply a second sequence if the provided sequence of "
                "ids is already formatted with special tokens for the model."
            )
        special_ids = set(self.all_special_ids)  # read the property once; set gives O(1) lookups
        return [1 if token in special_ids else 0 for token in token_ids_0]
    special_tokens_mask = [1] + [0] * len(token_ids_0) + [1]
    if token_ids_1 is not None:
        special_tokens_mask += [0] * len(token_ids_1) + [1]
    return special_tokens_mask
def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool):
"""
Depending on the input and internal state we might trigger a warning about a sequence that is too long for its
corresponding model
Args:
ids (`List[str]`): The ids produced by the tokenization
max_length (`int`, *optional*): The max_length desired (does not trigger a warning if it is set)
verbose (`bool`): Whether or not to print more information and warnings.
"""
if max_length is None and len(ids) > self.model_max_length and verbose:
if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
logging.warning(
"Token indices sequence length is longer than the specified maximum sequence length "
"for this model ({len(ids)} > {self.model_max_length}). Running this sequence through the model "
"will result in indexing errors"
)
self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
def prepare_for_model(
    self,
    ids: List[int],
    pair_ids: Optional[List[int]] = None,
    add_special_tokens: bool = True,
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Union[bool, str, TruncationStrategy] = None,
    max_length: Optional[int] = None,
    stride: int = 0,
    pad_to_multiple_of: Optional[int] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs,
):
    """
    Turn a sequence of ids (or a pair of sequences) into model inputs.

    The steps are: resolve the padding/truncation strategies, truncate if the total length
    exceeds `max_length`, add special tokens, build the requested masks and finally pad.

    Args:
        ids (`List[int]`): Ids of the first sequence.
        pair_ids (`List[int]`, *optional*): Ids of the second sequence.
        add_special_tokens (`bool`): Whether to wrap the sequence(s) with special tokens.
        padding, truncation, max_length, pad_to_multiple_of, verbose, **kwargs:
            Forwarded to `self._get_padding_truncation_strategies`, which also handles the
            backward-compatible 'truncation_strategy' and 'pad_to_max_length' arguments.
        stride (`int`): Forwarded to `self.truncate_sequences`.
        return_token_type_ids (`bool`, *optional*): Defaults to whether `"token_type_ids"` is in
            `self.model_input_names`.
        return_attention_mask (`bool`, *optional*): Defaults to whether `"attention_mask"` is in
            `self.model_input_names`.
        return_overflowing_tokens (`bool`): Whether to include the truncated tokens.
        return_special_tokens_mask (`bool`): Whether to include the special tokens mask.
        return_length (`bool`): Whether to include the length of `input_ids`.

    Returns:
        The dict of encoded inputs, or, when padding or an attention mask is requested,
        whatever `self.pad` returns for that dict.

    Raises:
        ValueError: if token type ids are requested without special tokens, or if overflowing
            tokens are requested for a pair with the `longest_first` truncation strategy.
    """
    padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        verbose=verbose,
        **kwargs,
    )
    pair = pair_ids is not None
    len_ids = len(ids)
    len_pair_ids = len(pair_ids) if pair else 0
    if return_token_type_ids and not add_special_tokens:
        raise ValueError(
            "Asking to return token_type_ids while setting add_special_tokens to False "
            "results in an undefined behavior. Please set add_special_tokens to True or "
            "set return_token_type_ids to None."
        )
    if (
        return_overflowing_tokens
        and truncation_strategy == TruncationStrategy.LONGEST_FIRST
        and pair_ids is not None
    ):
        raise ValueError(
            "Not possible to return overflowing tokens for pair of sequences with the "
            "`longest_first`. Please select another truncation strategy than `longest_first`, "
            "for instance `only_second` or `only_first`."
        )
    # Load from model defaults
    if return_token_type_ids is None:
        return_token_type_ids = "token_type_ids" in self.model_input_names
    if return_attention_mask is None:
        return_attention_mask = "attention_mask" in self.model_input_names
    encoded_inputs = {}
    # Compute the total size of the returned encodings
    total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
    # Truncation: Handle max sequence length
    overflowing_tokens = []
    if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
        ids, pair_ids, overflowing_tokens = self.truncate_sequences(
            ids,
            pair_ids=pair_ids,
            num_tokens_to_remove=total_len - max_length,
            truncation_strategy=truncation_strategy,
            stride=stride,
        )
    if return_overflowing_tokens:
        encoded_inputs["overflowing_tokens"] = overflowing_tokens
        # Without a max_length nothing can be truncated; subtracting None used to raise TypeError.
        encoded_inputs["num_truncated_tokens"] = total_len - max_length if max_length is not None else 0
    # Add special tokens
    if add_special_tokens:
        sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
        token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
    else:
        sequence = ids + pair_ids if pair else ids
        token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
    # Build output dictionary
    encoded_inputs["input_ids"] = sequence
    if return_token_type_ids:
        encoded_inputs["token_type_ids"] = token_type_ids
    if return_special_tokens_mask:
        if add_special_tokens:
            encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
        else:
            encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
    # Padding
    if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
        encoded_inputs = self.pad(
            encoded_inputs,
            max_length=max_length,
            padding=padding_strategy.value,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )
    if return_length:
        # NOTE(review): `self.pad` in this file returns only the `input_ids` list for non-empty
        # input, in which case this item assignment cannot succeed. Confirm the intended
        # return type before relying on `return_length` together with padding.
        encoded_inputs["length"] = len(encoded_inputs["input_ids"])
    return encoded_inputs
def bpe(self, token):
    """
    Apply the learned byte-pair merges to one token.

    The token is split into characters with `</w>` appended to the last one, then the
    best-ranked adjacent pair (lowest value in `self.bpe_ranks`) is merged repeatedly until no
    known pair remains or a single symbol is left. Results are memoised in `self.cache`.

    Args:
        token (`str`): A non-empty token.

    Returns:
        `str`: The resulting symbols joined by single spaces.
    """
    symbols = tuple(token[:-1]) + (token[-1] + "</w>",)
    if token in self.cache:
        return self.cache[token]
    # `get_pairs` is defined elsewhere in this file; presumably it yields the adjacent symbol pairs.
    pairs = get_pairs(symbols)
    if not pairs:
        return token + "</w>"
    ranks = self.bpe_ranks
    while True:
        best = min(pairs, key=lambda pair: ranks.get(pair, float("inf")))
        if best not in ranks:
            break
        first, second = best
        merged = []
        pos = 0
        last = len(symbols) - 1
        # Single left-to-right scan: fuse every non-overlapping (first, second) occurrence.
        while pos <= last:
            if pos < last and symbols[pos] == first and symbols[pos + 1] == second:
                merged.append(first + second)
                pos += 2
            else:
                merged.append(symbols[pos])
                pos += 1
        symbols = tuple(merged)
        if len(symbols) == 1:
            break
        pairs = get_pairs(symbols)
    result = " ".join(symbols)
    if result == "\n </w>":
        result = "\n</w>"
    self.cache[token] = result
    return result
def tokenize_(self, text: TextInput, **kwargs) -> List[str]:
    """
    Split `text` into tokens while keeping the no-split tokens intact.

    The text is optionally lower-cased (special tokens excluded), cut with
    `self.tokens_trie.split`, whitespace next to no-split tokens is stripped, and every
    remaining fragment is tokenized with `self._tokenize`.

    Args:
        text (`str`): The text to tokenize.
        **kwargs: Not used; a warning is logged if any are passed.

    Returns:
        `List[str]`: The list of tokens.
    """
    # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
    all_special_tokens_extended = {
        str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
    }
    if kwargs:
        # Lazy %-style argument: the previous literal contained a `{kwargs}` placeholder that
        # was never interpolated.
        logging.warning("Keyword arguments %s not recognized.", kwargs)
    if hasattr(self, "do_lower_case") and self.do_lower_case:
        # convert non-special tokens to lowercase
        escaped_special_toks = [
            re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_special_tokens)
        ]
        pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
        text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
    no_split_token = set(self.unique_no_split_tokens)
    tokens = self.tokens_trie.split(text)
    # ["This is something", "<special_token_1>", "  else"]
    for i, token in enumerate(tokens):
        if token in no_split_token:
            tok_extended = all_special_tokens_extended.get(token, None)
            left = tokens[i - 1] if i > 0 else None
            right = tokens[i + 1] if i < len(tokens) - 1 else None
            if isinstance(tok_extended, AddedToken):
                if tok_extended.rstrip and right:
                    # A bit counter-intuitive but we strip the left of the string
                    # since tok_extended.rstrip means the special token is eating all white spaces on its right
                    tokens[i + 1] = right.lstrip()
                # Strip white spaces on the left
                if tok_extended.lstrip and left:
                    tokens[i - 1] = left.rstrip()  # Opposite here
            else:
                # We strip left and right by default
                if right:
                    tokens[i + 1] = right.lstrip()
                if left:
                    tokens[i - 1] = left.rstrip()
    # ["This is something", "<special_token_1>", "else"]
    tokenized_text = []
    for token in tokens:
        # Need to skip eventual empty (fully stripped) tokens
        if not token:
            continue
        if token in no_split_token:
            tokenized_text.append(token)
        else:
            tokenized_text.extend(self._tokenize(token))
    # ["This", " is", " something", "<special_token_1>", "else"]
    return tokenized_text
def _tokenize(self, text, lang="en", bypass_tokenizer=False):
if lang and self.lang2id and lang not in self.lang2id:
logging.error(
"Supplied language code not found in lang2id mapping. Please check that your language is supported by"
" the loaded pretrained model."
)
if bypass_tokenizer:
text = text.split()
elif lang not in self.lang_with_custom_tokenizer:
text = self.moses_pipeline(text, lang=lang)
if lang == "ro":
text = romanian_preprocessing(text)
text = self.moses_tokenize(text, lang=lang)
elif lang == "zh":
try:
if "jieba" not in sys.modules:
import jieba
else:
jieba = sys.modules["jieba"]
except (AttributeError, ImportError):
logging.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
logging.error("1. pip install jieba")
raise
text = " ".join(jieba.cut(text))
text = self.moses_pipeline(text, lang=lang)
text = text.split()
elif lang == "ja":
text = self.moses_pipeline(text, lang=lang)
text = text.split()
else:
raise ValueError("It should not reach here")
if self.do_lowercase_and_remove_accent and not bypass_tokenizer:
text = lowercase_and_remove_accent(text)
split_tokens = []
for token in text:
if token:
split_tokens.extend(list(self.bpe(token).split(" ")))
return split_tokens
def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
    """
    Map a single token, or each token of a sequence, to its vocabulary id.

    Args:
        tokens (`str` or `List[str]`): One or several token(s) to convert to token id(s).

    Returns:
        `int` or `List[int]`: The token id for a string input, a list of ids for a sequence,
        or `None` when `tokens` is `None`.
    """
    if tokens is None:
        return None
    lookup = self._convert_token_to_id_with_added_voc
    if isinstance(tokens, str):
        return lookup(tokens)
    return [lookup(token) for token in tokens]
def _convert_token_to_id_with_added_voc(self, token):
if token is None:
return None
if token in self.added_tokens_encoder:
return self.added_tokens_encoder[token]
return self._convert_token_to_id(token)
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.encoder.get(token, self.encoder.get(self.unk_token))
def num_special_tokens_to_add(self, pair: bool = False) -> int:
    """
    Count the special tokens added when encoding a single sequence or a pair.

    The count is obtained by building the model input for empty sequences with
    `self.build_inputs_with_special_tokens` and measuring its length.

    Args:
        pair (`bool`): Whether to count for a pair of sequences instead of a single one.

    Returns:
        `int`: Number of special tokens added.
    """
    second_sequence = [] if pair else None
    with_special_tokens = self.build_inputs_with_special_tokens([], second_sequence)
    return len(with_special_tokens)
def pad(
    self,
    encoded_inputs,
    padding: Union[bool, str, PaddingStrategy] = True,
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
    return_attention_mask: Optional[bool] = None,
    verbose: bool = True,
):
    """
    Pad a single encoding or a batch of encodings and return the padded `input_ids`.

    Args:
        encoded_inputs: A dict of encoded inputs (one example, or a dict of lists for a batch),
            or a list/tuple of such dicts, which is first converted to a dict of lists.
        padding: Padding strategy, resolved by `self._get_padding_truncation_strategies`.
        max_length (`int`, *optional*): Target length.
        pad_to_multiple_of (`int`, *optional*): Forwarded to `self._pad`.
        return_attention_mask (`bool`, *optional*): Forwarded to `self._pad`.
        verbose (`bool`): Forwarded to `self._get_padding_truncation_strategies`.

    Returns:
        The padded `input_ids`: a list of ids for a single example, a list of lists for a
        batch. NOTE(review): when the main input is `None` or empty, the `encoded_inputs` dict
        itself is returned instead. That inconsistency is kept unchanged for backward
        compatibility; confirm against callers.

    Raises:
        ValueError: if `encoded_inputs` lacks the model's main input name.
        AssertionError: if the batched values do not all have the same batch size.
    """
    if self.__class__.__name__.endswith("Fast"):
        if not self.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False):
            # Lazy %-style argument: the class name placeholder used to be logged verbatim.
            logging.warning(
                "You're using a %s tokenizer. Please note that with a fast tokenizer,"
                " using the `__call__` method is faster than using a method to encode the text followed by"
                " a call to the `pad` method to get a padded encoding.",
                self.__class__.__name__,
            )
            self.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
    # If we have a list of dicts, let's convert it in a dict of lists
    if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
        encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
    # The model's main input name, usually `input_ids`, has to be passed for padding
    main_input_name = self.model_input_names[0]
    if main_input_name not in encoded_inputs:
        if isinstance(encoded_inputs, Mapping):
            provided = list(encoded_inputs.keys())
        else:
            provided = type(encoded_inputs).__name__
        raise ValueError(
            "You should supply an encoding or a list of encodings to this method "
            "that includes {}, but you provided {}".format(main_input_name, provided)
        )
    required_input = encoded_inputs[main_input_name]
    if required_input is None or (isinstance(required_input, Sized) and len(required_input) == 0):
        if return_attention_mask:
            encoded_inputs["attention_mask"] = []
        return encoded_inputs
    # Convert padding_strategy in PaddingStrategy
    padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
        padding=padding, max_length=max_length, verbose=verbose
    )
    # Single example: the main input is a flat sequence of ids.
    if required_input and not isinstance(required_input[0], (list, tuple)):
        encoded_inputs = self._pad(
            encoded_inputs,
            max_length=max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )
        return encoded_inputs['input_ids']
    # Batch: pad every example separately.
    batch_size = len(required_input)
    assert all(
        len(v) == batch_size for v in encoded_inputs.values()
    ), "Some items in the output dictionary have a different batch size than others."
    if padding_strategy == PaddingStrategy.LONGEST:
        # Resolve "longest" once for the whole batch so every example gets the same length.
        max_length = max(len(inputs) for inputs in required_input)
        padding_strategy = PaddingStrategy.MAX_LENGTH
    batch_outputs = {}
    for i in range(batch_size):
        inputs = {k: v[i] for k, v in encoded_inputs.items()}
        outputs = self._pad(
            inputs,
            max_length=max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )
        for key, value in outputs.items():
            batch_outputs.setdefault(key, []).append(value)
    # Return the padded ids; the unpadded `encoded_inputs['input_ids']` used to be returned
    # here, which discarded the padding computed above.
    return batch_outputs['input_ids']
def _pad(
    self,
    encoded_inputs,
    max_length: Optional[int] = None,
    padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
    pad_to_multiple_of: Optional[int] = None,
    return_attention_mask: Optional[bool] = None,
) -> dict:
    """
    Pad the encoding of a single example, modifying and returning `encoded_inputs`.

    Args:
        encoded_inputs: Dict holding the main model input and, optionally,
            `attention_mask`, `token_type_ids` and `special_tokens_mask`.
        max_length (`int`, *optional*): Target length; rounded up to a multiple of
            `pad_to_multiple_of` when both are given.
        padding_strategy: `LONGEST` uses the length of the input itself as target;
            `DO_NOT_PAD` disables padding.
        pad_to_multiple_of (`int`, *optional*): See `max_length`.
        return_attention_mask (`bool`, *optional*): Defaults to whether `"attention_mask"`
            is in `self.model_input_names`.

    Returns:
        `dict`: The same `encoded_inputs` object, padded on `self.padding_side`.

    Raises:
        ValueError: if padding is needed and `self.padding_side` is neither "right" nor "left".
    """
    # Load from model defaults
    if return_attention_mask is None:
        return_attention_mask = "attention_mask" in self.model_input_names
    main_input_name = self.model_input_names[0]
    required_input = encoded_inputs[main_input_name]

    if padding_strategy == PaddingStrategy.LONGEST:
        max_length = len(required_input)
    if max_length is not None and pad_to_multiple_of is not None:
        if max_length % pad_to_multiple_of != 0:
            max_length = (max_length // pad_to_multiple_of + 1) * pad_to_multiple_of

    # Initialize attention mask if not present.
    if return_attention_mask and "attention_mask" not in encoded_inputs:
        encoded_inputs["attention_mask"] = [1] * len(required_input)

    if padding_strategy == PaddingStrategy.DO_NOT_PAD or len(required_input) == max_length:
        return encoded_inputs

    difference = max_length - len(required_input)
    side = self.padding_side
    if side not in ("right", "left"):
        raise ValueError("Invalid padding strategy:" + str(self.padding_side))
    pad_on_right = side == "right"

    def _padded(values, fill_value):
        """Return `values` extended with `difference` copies of `fill_value` on the padding side."""
        filler = [fill_value] * difference
        return values + filler if pad_on_right else filler + values

    if return_attention_mask:
        encoded_inputs["attention_mask"] = _padded(encoded_inputs["attention_mask"], 0)
    if "token_type_ids" in encoded_inputs:
        encoded_inputs["token_type_ids"] = _padded(encoded_inputs["token_type_ids"], self.pad_token_type_id)
    if "special_tokens_mask" in encoded_inputs:
        # Padding positions are flagged as special tokens.
        encoded_inputs["special_tokens_mask"] = _padded(encoded_inputs["special_tokens_mask"], 1)
    encoded_inputs[main_input_name] = _padded(required_input, self.pad_token_id)
    return encoded_inputs