Source code for mindnlp.transforms.tokenizers.xlm_tokenizer

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
XLMTokenizer
"""
# pylint: disable=C0412
# pylint: disable=C0103
# pylint: disable=C0415
# pylint: disable=too-many-instance-attributes
# pylint: disable=C0103
# pylint: disable=R0913
# pylint: disable=R0914
# pylint: disable=R0912
# pylint: disable=R0915
# pylint: disable=W1505
# pylint: disable=R1705
# pylint: disable=R1723
# pylint: disable=R1702
# pylint: disable=R1724
# pylint: disable=W0102
import json
import logging
import re
from enum import Enum
import sys
import unicodedata
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union
from collections.abc import Mapping, Sized
from tokenizers import AddedToken
from mindnlp.abc import PreTrainedTokenizer
from mindnlp.models.xlm.xlm_config import XLM_SUPPORT_LIST
from mindnlp.configs import MINDNLP_TOKENIZER_CONFIG_URL_BASE


# Maps every supported XLM checkpoint name to a location built from the MindNLP
# tokenizer URL template (`MINDNLP_TOKENIZER_CONFIG_URL_BASE`), formatted with the
# model family "xlm" and the checkpoint name.
PRETRAINED_VOCAB_MAP = {
    model: MINDNLP_TOKENIZER_CONFIG_URL_BASE.format('xlm', model) for model in XLM_SUPPORT_LIST
}

# Type aliases describing the input forms accepted by the tokenizer entry points.
TextInput = str
PreTokenizedInput = List[str]
EncodedInput = List[int]
TextInputPair = Tuple[str, str]
PreTokenizedInputPair = Tuple[List[str], List[str]]
EncodedInputPair = Tuple[List[int], List[int]]
# Sentinel stored as `model_max_length` when no maximum length is configured; it is
# far above the 1e20 threshold the padding/truncation logic treats as "no predefined
# maximum length".
VERY_LARGE_INTEGER = int(1e30)

# Maximum input length (positional-embedding size) of each pretrained checkpoint;
# exposed to the tokenizer class as `max_model_input_sizes`.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "xlm-clm-ende-1024": 512,
    "xlm-mlm-en-2048": 512,
    "xlm-mlm-xnli15-1024": 512,
    "xlm-mlm-100-1280": 512,
    "xlm-mlm-enfr-1024": 512,
    "xlm-mlm-tlm-xnli15-1024": 512,
    "xlm-clm-enfr-1024": 512,
    "xlm-mlm-17-1280": 512,
    "xlm-mlm-enro-1024": 512,
}

class ExplicitEnum(str, Enum):
    """
    String-valued ``Enum`` with a more explicit error message for missing values.

    Looking up an unknown value (e.g. ``PaddingStrategy("foo")``) raises a
    ``ValueError`` that names the enum class and lists every accepted value.
    """

    @classmethod
    def _missing_(cls, value):
        # Enum calls this hook when ``cls(value)`` matches no member; raising here
        # replaces the generic lookup failure with an actionable message.
        valid_values = [member.value for member in cls]
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {valid_values}"
        )


class TensorType(ExplicitEnum):
    """
    Possible values for the `return_tensors` argument of the tokenizer `__call__`. Useful for
    tab-completion in an IDE.

    NOTE(review): in this module `return_tensors` is only forwarded through `**kwargs`;
    confirm which of these values are actually honoured downstream.
    """

    PYTORCH = "pt"  # PyTorch tensors
    TENSORFLOW = "tf"  # TensorFlow tensors
    NUMPY = "np"  # NumPy arrays
    JAX = "jax"  # JAX arrays

class TruncationStrategy(ExplicitEnum):
    """
    Possible values for the `truncation` argument of the tokenizer `__call__`. Useful for tab-completion in
    an IDE. An unknown string raises `ValueError` via `ExplicitEnum`.
    """

    ONLY_FIRST = "only_first"  # truncate only the first sequence
    ONLY_SECOND = "only_second"  # truncate only the second sequence of a pair
    # With a pair, remove tokens one at a time from the longer sequence;
    # with a single sequence, trim that sequence.
    LONGEST_FIRST = "longest_first"
    DO_NOT_TRUNCATE = "do_not_truncate"  # leave sequences untouched

class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the `padding` argument of the tokenizer `__call__`. Useful for tab-completion in an
    IDE. An unknown string raises `ValueError` via `ExplicitEnum`.
    """

    LONGEST = "longest"  # pad to the longest sequence in the batch
    # Pad to `max_length`; when it is not given, the model's maximum length is used.
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"  # no padding

def get_pairs(word):
    """
    Return the set of adjacent symbol pairs in a word.

    Args:
        word: Sequence of symbols (typically a tuple of variable-length strings).

    Returns:
        set: ``(left, right)`` tuples for every pair of neighbouring symbols.
        Empty and single-symbol words yield an empty set instead of raising.
    """
    # Pair each symbol with its successor; zip stops at the shorter sequence,
    # so sequences of length 0 or 1 simply produce no pairs.
    return set(zip(word, word[1:]))

def lowercase_and_remove_accent(text):
    """
    Lowercase a list of tokens and strip their accents.

    Follows https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py:
    the tokens are joined with single spaces, lowercased, NFD-decomposed, every
    nonspacing mark (Unicode category ``Mn``) is dropped, and the result is
    lowercased again and split back on single spaces.

    Args:
        text: Iterable of token strings.

    Returns:
        list: The processed tokens.
    """
    decomposed = unicodedata.normalize("NFD", " ".join(text).lower())
    without_marks = "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")
    return without_marks.lower().split(" ")

# Stage 1 - https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py
# Cedilla letters are rewritten as their comma-below counterparts.
_RO_CEDILLA_TO_COMMA = str.maketrans({
    "\u015e": "\u0218",
    "\u015f": "\u0219",
    "\u0162": "\u021a",
    "\u0163": "\u021b",
})

# Stage 2 - https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py
# Comma-below letters and the remaining diacritics are reduced to plain ASCII.
_RO_STRIP_DIACRITICS = str.maketrans({
    "\u0218": "S", "\u0219": "s",  # s-comma
    "\u021a": "T", "\u021b": "t",  # t-comma
    "\u0102": "A", "\u0103": "a",  # a-breve
    "\u00C2": "A", "\u00E2": "a",  # a-circumflex
    "\u00CE": "I", "\u00EE": "i",  # i-circumflex
})


def romanian_preprocessing(text):
    """
    Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`.

    First normalises cedilla letters to comma-below letters, then strips all
    Romanian diacritics to ASCII. Every other character is left untouched.
    """
    return text.translate(_RO_CEDILLA_TO_COMMA).translate(_RO_STRIP_DIACRITICS)


def replace_unicode_punct(text):
    """
    Map CJK / full-width punctuation and digits to their ASCII equivalents.

    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl

    Non-ASCII characters are written as ``\\u`` escapes so the source survives
    tooling that normalises full-width characters to ASCII. Such normalisation
    previously turned most rules into no-ops and made the full-width-period rule
    a regex (``.``) that matched every character.

    Args:
        text (str): Input text.

    Returns:
        str: Text with the substitutions applied, in the Moses order.
    """
    text = text.replace("\uff0c", ",")  # FULLWIDTH COMMA
    text = re.sub(r"\u3002\s*", ". ", text)  # IDEOGRAPHIC FULL STOP (eats following whitespace)
    text = text.replace("\u3001", ",")  # IDEOGRAPHIC COMMA
    text = text.replace("\u201d", '"')  # RIGHT DOUBLE QUOTATION MARK
    text = text.replace("\u201c", '"')  # LEFT DOUBLE QUOTATION MARK
    text = text.replace("\u2236", ":")  # RATIO
    text = text.replace("\uff1a", ":")  # FULLWIDTH COLON
    text = text.replace("\uff1f", "?")  # FULLWIDTH QUESTION MARK
    text = text.replace("\u300a", '"')  # LEFT DOUBLE ANGLE BRACKET
    text = text.replace("\u300b", '"')  # RIGHT DOUBLE ANGLE BRACKET
    text = text.replace("\uff09", ")")  # FULLWIDTH RIGHT PARENTHESIS
    text = text.replace("\uff01", "!")  # FULLWIDTH EXCLAMATION MARK
    text = text.replace("\uff08", "(")  # FULLWIDTH LEFT PARENTHESIS
    text = text.replace("\uff1b", ";")  # FULLWIDTH SEMICOLON
    text = text.replace("\u300d", '"')  # RIGHT CORNER BRACKET
    text = text.replace("\u300c", '"')  # LEFT CORNER BRACKET
    # FULLWIDTH DIGIT ZERO..NINE (U+FF10..U+FF19) -> ASCII digits.
    text = text.translate({0xFF10 + digit: str(digit) for digit in range(10)})
    text = re.sub(r"\uff0e\s*", ". ", text)  # FULLWIDTH FULL STOP (eats following whitespace)
    text = text.replace("\uff5e", "~")  # FULLWIDTH TILDE
    text = text.replace("\u2019", "'")  # RIGHT SINGLE QUOTATION MARK
    text = text.replace("\u2026", "...")  # HORIZONTAL ELLIPSIS
    text = text.replace("\u2501", "-")  # BOX DRAWINGS HEAVY HORIZONTAL
    text = text.replace("\u3008", "<")  # LEFT ANGLE BRACKET
    text = text.replace("\u3009", ">")  # RIGHT ANGLE BRACKET
    text = text.replace("\u3010", "[")  # LEFT BLACK LENTICULAR BRACKET
    text = text.replace("\u3011", "]")  # RIGHT BLACK LENTICULAR BRACKET
    text = text.replace("\uff05", "%")  # FULLWIDTH PERCENT SIGN
    return text


def remove_non_printing_char(text):
    """
    Drop every character whose Unicode general category starts with ``C``.

    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
    The ``C*`` categories are control, format, surrogate, private-use and
    unassigned code points.
    """
    return "".join(ch for ch in text if not unicodedata.category(ch).startswith("C"))


class Trie:
    """
    Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass
    Loose reference https://en.wikipedia.org/wiki/Trie

    Each node is a ``dict`` mapping one character to its child node; the special
    key ``""`` (value ``1``) marks that a complete word ends at that node.
    """

    def __init__(self):
        # Root node of the trie.
        self.data = {}

    def add(self, word: str):
        """
        Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.
        The special key `""` is used to represent termination.

        This function is idempotent, adding twice the same word will leave the trie unchanged.
        Empty strings are ignored.

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.add("Hello 友達")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}

        >>> trie.add("Hello")
        >>> trie.data
        {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
        ```
        """
        if not word:
            # Prevent empty string
            return
        node = self.data
        for char in word:
            # Reuse the existing child node, or create an empty one.
            node = node.setdefault(char, {})
        node[""] = 1

    def split(self, text: str) -> List[str]:
        """
        Split ``text`` along the boundaries of the words stored in the trie.

        The matched words become their own pieces and the text between them is kept
        as-is. The trie matches the longest possible word first.

        Example:

        ```python
        >>> trie = Trie()
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS] This is a extra_id_100"]
        >>> trie.add("[CLS]")
        >>> trie.add("extra_id_1")
        >>> trie.add("extra_id_100")
        >>> trie.split("[CLS] This is a extra_id_100")
        ["[CLS]", " This is a ", "extra_id_100"]
        ```
        """
        # start index -> trie node reached so far, for every partial match in flight.
        states = OrderedDict()
        # Cut points, always in pairs (match start, match end) after the leading 0.
        offsets = [0]
        # Characters before `skip` were consumed by a match found via lookahead.
        skip = 0
        for current, current_char in enumerate(text):
            if skip and current < skip:
                continue
            to_remove = set()
            reset = False
            # In this case, we already have partial matches (But unfinished)
            for start, trie_pointer in states.items():
                if "" in trie_pointer:
                    # A word ends here; look ahead among matches that started no later
                    # than this one to find the longest match before committing.
                    for lookstart, looktrie_pointer in states.items():
                        if lookstart > start:
                            # This partial match is later, we can stop looking
                            break
                        elif lookstart < start:
                            lookahead_index = current + 1
                            end = current + 1
                        else:
                            lookahead_index = current
                            end = current
                        next_char = text[lookahead_index] if lookahead_index < len(text) else None
                        if "" in looktrie_pointer:
                            start = lookstart
                            end = lookahead_index
                            skip = lookahead_index

                        while next_char in looktrie_pointer:
                            looktrie_pointer = looktrie_pointer[next_char]
                            lookahead_index += 1
                            if "" in looktrie_pointer:
                                start = lookstart
                                end = lookahead_index
                                skip = lookahead_index

                            if lookahead_index == len(text):
                                # End of string
                                break
                            next_char = text[lookahead_index]
                        # End lookahead

                    # Storing and resetting
                    offsets.append(start)
                    offsets.append(end)
                    reset = True
                    break
                elif current_char in trie_pointer:
                    # Partial match continues: advance it by one character.
                    trie_pointer = trie_pointer[current_char]
                    states[start] = trie_pointer
                else:
                    # Partial match is dead.
                    to_remove.add(start)

            if reset:
                states = {}
            else:
                for start in to_remove:
                    del states[start]

            # If this character is a starting character within the trie
            # start keeping track of this partial match.
            if current >= skip and current_char in self.data:
                states[current] = self.data[current_char]

        # We have a cut at the end with states.
        for start, trie_pointer in states.items():
            if "" in trie_pointer:
                end = len(text)
                offsets.append(start)
                offsets.append(end)
                break

        return self.cut_text(text, offsets)

    def cut_text(self, text, offsets):
        """
        Cut ``text`` at the given offsets and return the non-empty pieces.

        ``offsets`` is not modified; the end of the text is appended internally.
        Zero-width cuts are skipped, and decreasing offsets are logged and skipped.
        """
        tokens = []
        start = 0
        for end in [*offsets, len(text)]:
            if start > end:
                logging.error(
                    "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it"
                    " anyway."
                )
                continue
            elif start == end:
                # This might happen if there's a match at index 0
                # we're also preventing zero-width cuts in case of two
                # consecutive matches
                continue
            tokens.append(text[start:end])
            start = end

        return tokens

class XLMTokenizer(PreTrainedTokenizer):
    """
    Tokenizer used for XLM text processing.

    Loads a JSON vocabulary and byte-pair-encoding (BPE) merge ranks, and provides
    Moses-style normalisation helpers backed by ``sacremoses``.

    Args:
        vocab_file (str): Path to a JSON file mapping tokens to ids.
        merges_file (str): Path to the BPE merges file. The first two whitespace-separated
            fields of each line form a merge pair, ranked by line order; the text after
            the final newline is ignored.
        unk_token (str): Unknown token. Default: "<unk>".
        bos_token (str): Beginning-of-sequence token. Default: "<s>".
        sep_token (str): Separator token. Default: "</s>".
        pad_token (str): Padding token. Default: "<pad>".
        cls_token (str): Classifier token. Default: "</s>".
        mask_token (str): Mask token. Default: "<special1>".
        additional_special_tokens (Sequence[str], optional): Extra special tokens.
            Default: "<special0>" ... "<special9>".
        lang2id (dict, optional): Language-code to language-id mapping.
        id2lang (dict, optional): Language-id to language-code mapping; must have the
            same length as ``lang2id`` when both are given.
        do_lowercase_and_remove_accent (bool): Whether to lowercase and strip accents. Default: True.

    Raises:
        ImportError: If ``sacremoses`` is not installed.
        ValueError: If ``truncation_side`` is neither "right" nor "left".

    Examples:
        >>> from mindnlp.transforms import XLMTokenizer
        >>> text = "Believing that faith can triumph over everything is in itself the greatest belief"
        >>> tokenizer = XLMTokenizer.from_pretrained('xlm-clm-ende-1024')
        >>> tokens = tokenizer.encode(text)
    """

    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_vocab_map = PRETRAINED_VOCAB_MAP
    model_input_names: List[str] = ["input_ids", "token_type_ids", "attention_mask"]
    padding_side: str = "right"
    truncation_side: str = "right"
    slow_tokenizer_class = None

    def __init__(
        self,
        vocab_file,
        merges_file,
        unk_token="<unk>",
        bos_token="<s>",
        sep_token="</s>",
        pad_token="<pad>",
        cls_token="</s>",
        mask_token="<special1>",
        additional_special_tokens=(
            "<special0>",
            "<special1>",
            "<special2>",
            "<special3>",
            "<special4>",
            "<special5>",
            "<special6>",
            "<special7>",
            "<special8>",
            "<special9>",
        ),
        lang2id=None,
        id2lang=None,
        do_lowercase_and_remove_accent=True,
        **kwargs,
    ):
        # The default is an immutable tuple; hand the base class a fresh list so no
        # instance can mutate a default shared by every tokenizer.
        if additional_special_tokens is not None:
            additional_special_tokens = list(additional_special_tokens)
        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            additional_special_tokens=additional_special_tokens,
            lang2id=lang2id,
            id2lang=id2lang,
            do_lowercase_and_remove_accent=do_lowercase_and_remove_accent,
            **kwargs,
        )
        # Start with empty added-token maps, no-split list and trie.
        self.added_tokens_encoder: Dict[str, int] = {}
        self.added_tokens_decoder: Dict[int, str] = {}
        self.unique_no_split_tokens: List[str] = []
        self.tokens_trie = Trie()

        # NOTE(review): these options were already forwarded to the base class through
        # `**kwargs` above; they are read again here and stored on this instance.
        self.padding_side = kwargs.pop("padding_side", self.padding_side)
        model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None))
        self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER
        self.truncation_side = kwargs.pop("truncation_side", self.truncation_side)
        if self.truncation_side not in ["right", "left"]:
            raise ValueError(
                "Truncation side should be selected between 'right' and 'left', "
                f"current value: {self.truncation_side}"
            )
        self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
        self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", True)
        # Keys of one-off warnings that have already been emitted.
        self.deprecation_warnings = {}
        self._in_target_context_manager = False

        try:
            import sacremoses
        except ImportError as e:
            raise ImportError(
                "You need to install sacremoses to use XLMTokenizer. "
                "See https://pypi.org/project/sacremoses/ for installation."
            ) from e

        self.sm = sacremoses

        # cache of sm.MosesPunctNormalizer instance
        self.cache_moses_punct_normalizer = {}
        # cache of sm.MosesTokenizer instance
        self.cache_moses_tokenizer = {}
        self.lang_with_custom_tokenizer = {"zh", "th", "ja"}
        # True for current supported model (v1.2.0), False for XLM-17 & 100
        self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
        self.lang2id = lang2id
        self.id2lang = id2lang
        if lang2id is not None and id2lang is not None:
            assert len(lang2id) == len(id2lang)

        self.ja_word_tokenizer = None
        self.zh_word_tokenizer = None

        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        with open(merges_file, encoding="utf-8") as merges_handle:
            # Drop the element after the final newline.
            merges = merges_handle.read().split("\n")[:-1]
        merges = [tuple(merge.split()[:2]) for merge in merges]
        # Lower rank = earlier line in the merges file.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an id to its token using the vocabulary; unknown ids map to ``unk_token``."""
        return self.decoder.get(index, self.unk_token)

    @property
    def do_lower_case(self):
        "Alias for ``do_lowercase_and_remove_accent``."
        return self.do_lowercase_and_remove_accent
[docs] def moses_punct_norm(self, text, lang): "moses_punct_norm" if lang not in self.cache_moses_punct_normalizer: punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang) self.cache_moses_punct_normalizer[lang] = punct_normalizer else: punct_normalizer = self.cache_moses_punct_normalizer[lang] return punct_normalizer.normalize(text)
[docs] def moses_tokenize(self, text, lang): "moses_tokenize" if lang not in self.cache_moses_tokenizer: moses_tokenizer = self.sm.MosesTokenizer(lang=lang) self.cache_moses_tokenizer[lang] = moses_tokenizer else: moses_tokenizer = self.cache_moses_tokenizer[lang] return moses_tokenizer.tokenize(text, return_str=False, escape=False)
[docs] def moses_pipeline(self, text, lang): "moses_pipeline" text = replace_unicode_punct(text) text = self.moses_punct_norm(text, lang) text = remove_non_printing_char(text) return text
def __call__( self, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, text_pair_target: Optional[ Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] ] = None, add_special_tokens: bool = True, padding = False, truncation = None, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, **kwargs, ): # To avoid duplicating all_kwargs = { "add_special_tokens": add_special_tokens, "padding": padding, "truncation": truncation, "max_length": max_length, "stride": stride, "is_split_into_words": is_split_into_words, "pad_to_multiple_of": pad_to_multiple_of, "return_tensors": return_tensors, "return_token_type_ids": return_token_type_ids, "return_attention_mask": return_attention_mask, "return_overflowing_tokens": return_overflowing_tokens, "return_special_tokens_mask": return_special_tokens_mask, "return_offsets_mapping": return_offsets_mapping, "return_length": return_length, "verbose": verbose, } all_kwargs.update(kwargs) if text is None and text_target is None: raise ValueError("You need to specify either `text` or `text_target`.") if text is not None: encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs) if text_target is not None: target_encodings = self._call_one(text=text_target, text_pair=text_pair_target, **all_kwargs) if text_target is None: return encodings elif text is None: return target_encodings else: encodings["labels"] 
= target_encodings["input_ids"] return encodings def _call_one( self, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, add_special_tokens: bool = True, padding = False, truncation = None, max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_length: bool = False, verbose: bool = True, **kwargs, ): # Input type checking for clearer error def _is_valid_text_input(t): if isinstance(t, str): # Strings are fine return True elif isinstance(t, (list, tuple)): # List are fine as long as they are... if len(t) == 0: # ... empty return True elif isinstance(t[0], str): # ... list of strings return True elif isinstance(t[0], (list, tuple)): # ... list with an empty list or with a list of strings return len(t[0]) == 0 or isinstance(t[0][0], str) else: return False else: return False if not _is_valid_text_input(text): raise ValueError( "text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) " "or `List[List[str]]` (batch of pretokenized examples)." ) if text_pair is not None and not _is_valid_text_input(text_pair): raise ValueError( "text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) " "or `List[List[str]]` (batch of pretokenized examples)." 
) if is_split_into_words: is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)) else: is_batched = isinstance(text, (list, tuple)) if not is_batched: return self.encode_plus( text=text, text_pair=text_pair, add_special_tokens=add_special_tokens, padding=padding, truncation=truncation, max_length=max_length, stride=stride, pad_to_multiple_of=pad_to_multiple_of, return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_length=return_length, verbose=verbose, **kwargs, )
[docs] def encode_plus( self, text: Union[TextInput, PreTokenizedInput, EncodedInput], text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, add_special_tokens: bool = True, padding = False, truncation = None, max_length: Optional[int] = None, stride: int = 0, pad_to_multiple_of: Optional[int] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_length: bool = False, verbose: bool = True, **kwargs, ): "# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'" padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( padding=padding, truncation=truncation, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of, verbose=verbose, **kwargs, ) return self._encode_plus( text=text, text_pair=text_pair, add_special_tokens=add_special_tokens, padding_strategy=padding_strategy, truncation_strategy=truncation_strategy, max_length=max_length, stride=stride, pad_to_multiple_of=pad_to_multiple_of, return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_length=return_length, verbose=verbose, **kwargs, )
def _encode_plus( self, text: Union[TextInput, PreTokenizedInput, EncodedInput], text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, add_special_tokens: bool = True, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, max_length: Optional[int] = None, stride: int = 0, pad_to_multiple_of: Optional[int] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_length: bool = False, verbose: bool = True, **kwargs ): def get_input_ids(text): tokens = self.tokenize_(text, **kwargs) return self.convert_tokens_to_ids(tokens) first_ids = get_input_ids(text) second_ids = get_input_ids(text_pair) if text_pair is not None else None return self.prepare_for_model( first_ids, pair_ids=second_ids, add_special_tokens=add_special_tokens, padding=padding_strategy.value, truncation=truncation_strategy.value, max_length=max_length, stride=stride, pad_to_multiple_of=pad_to_multiple_of, return_attention_mask=return_attention_mask, return_token_type_ids=return_token_type_ids, return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_length=return_length, verbose=verbose, ) def _get_padding_truncation_strategies( self, padding=False, truncation=None, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs ): """ Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy and pad_to_max_length) and behaviors. 
""" old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate") old_pad_to_max_length = kwargs.pop("pad_to_max_length", False) # Backward compatibility for previous behavior, maybe we should deprecate it: # If you only set max_length, it activates truncation for max_length if max_length is not None and padding is False and truncation is None: if verbose: if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False): logging.warning( "Truncation was not explicitly activated but `max_length` is provided a specific value, please" " use `truncation=True` to explicitly truncate examples to max length. Defaulting to" " 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the" " tokenizer you can select this strategy more precisely by providing a specific strategy to" " `truncation`." ) self.deprecation_warnings["Truncation-not-explicitly-activated"] = True truncation = "longest_first" # Get padding strategy if padding is False and old_pad_to_max_length: if verbose: logging.warn( "The `pad_to_max_length` argument is deprecated and will be removed in a future version, " "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or " "use `padding='max_length'` to pad to a max length. In this case, you can give a specific " "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the " "maximal input size of the model (e.g. 512 for Bert)." ) if max_length is None: padding_strategy = PaddingStrategy.LONGEST else: padding_strategy = PaddingStrategy.MAX_LENGTH elif padding is not False: if padding is True: if verbose: if max_length is not None and ( truncation is None or truncation is False or truncation == "do_not_truncate" ): logging.warn( "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. " "To pad to max length, use `padding='max_length'`." 
) if old_pad_to_max_length is not False: logging.warn("Though `pad_to_max_length` = `True`, it is ignored because `padding`=`True`.") padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch elif not isinstance(padding, PaddingStrategy): padding_strategy = PaddingStrategy(padding) elif isinstance(padding, PaddingStrategy): padding_strategy = padding else: padding_strategy = PaddingStrategy.DO_NOT_PAD # Get truncation strategy if truncation is None and old_truncation_strategy != "do_not_truncate": if verbose: logging.warn( "The `truncation_strategy` argument is deprecated and will be removed in a future version, use" " `truncation=True` to truncate examples to a max length. You can give a specific length with" " `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the maximal input" " size of the model (e.g. 512 for Bert). If you have pairs of inputs, you can give a specific" " truncation strategy selected among `truncation='only_first'` (will only truncate the first" " sentence in the pairs) `truncation='only_second'` (will only truncate the second sentence in the" " pairs) or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence" " in the pairs)." 
) truncation_strategy = TruncationStrategy(old_truncation_strategy) elif truncation is not False and truncation is not None: if truncation is True: truncation_strategy = ( TruncationStrategy.LONGEST_FIRST ) # Default to truncate the longest sequences in pairs of inputs elif not isinstance(truncation, TruncationStrategy): truncation_strategy = TruncationStrategy(truncation) elif isinstance(truncation, TruncationStrategy): truncation_strategy = truncation else: truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE # Set max length if needed if max_length is None: if padding_strategy == PaddingStrategy.MAX_LENGTH: LARGE_INTEGER = int(1e20) if self.model_max_length > LARGE_INTEGER: if verbose: if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False): logging.warning( "Asking to pad to max_length but no maximum length is provided and the model has no" " predefined maximum length. Default to no padding." ) self.deprecation_warnings["Asking-to-pad-to-max_length"] = True padding_strategy = PaddingStrategy.DO_NOT_PAD else: max_length = self.model_max_length if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE: if self.model_max_length > LARGE_INTEGER: if verbose: if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False): logging.warning( "Asking to truncate to max_length but no maximum length is provided and the model has" " no predefined maximum length. Default to no truncation." ) self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE else: max_length = self.model_max_length # Test if we have a padding token if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0): raise ValueError( "Asking to pad but the tokenizer does not have a padding token. 
" "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` " "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`." ) # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided if ( truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and padding_strategy != PaddingStrategy.DO_NOT_PAD and pad_to_multiple_of is not None and max_length is not None and (max_length % pad_to_multiple_of != 0) ): raise ValueError( "Truncation and padding are both activated but " f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})." ) return padding_strategy, truncation_strategy, max_length, kwargs
[docs] def truncate_sequences( self, ids: List[int], pair_ids: Optional[List[int]] = None, num_tokens_to_remove: int = 0, truncation_strategy: Union[str, TruncationStrategy] = "longest_first", stride: int = 0, ) -> Tuple[List[int], List[int], List[int]]: "truncate_sequences" if num_tokens_to_remove <= 0: return ids, pair_ids, [] if not isinstance(truncation_strategy, TruncationStrategy): truncation_strategy = TruncationStrategy(truncation_strategy) overflowing_tokens = [] if truncation_strategy == TruncationStrategy.ONLY_FIRST or ( truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None ): if len(ids) > num_tokens_to_remove: window_len = min(len(ids), stride + num_tokens_to_remove) if self.truncation_side == "left": overflowing_tokens = ids[:window_len] ids = ids[num_tokens_to_remove:] elif self.truncation_side == "right": overflowing_tokens = ids[-window_len:] ids = ids[:-num_tokens_to_remove] else: raise ValueError(f"invalid truncation strategy: {self.truncation_side}, use 'left' or 'right'.") else: error_msg = ( f"We need to remove {num_tokens_to_remove} to truncate the input " f"but the first sequence has a length {len(ids)}. " ) if truncation_strategy == TruncationStrategy.ONLY_FIRST: error_msg = ( error_msg + "Please select another truncation strategy than " f"{truncation_strategy}, for instance 'longest_first' or 'only_second'." ) logging.error(error_msg) elif truncation_strategy == TruncationStrategy.LONGEST_FIRST: logging.warning( "Be aware, overflowing tokens are not returned for the setting you have chosen," " i.e. sequence pairs with the" "truncation strategy. So the returned list will always be empty even if some " "tokens have been removed." 
) for _ in range(num_tokens_to_remove): if pair_ids is None or len(ids) > len(pair_ids): if self.truncation_side == "right": ids = ids[:-1] elif self.truncation_side == "left": ids = ids[1:] else: raise ValueError("invalid truncation strategy:" + str(self.truncation_side)) else: if self.truncation_side == "right": pair_ids = pair_ids[:-1] elif self.truncation_side == "left": pair_ids = pair_ids[1:] else: raise ValueError("invalid truncation strategy:" + str(self.truncation_side)) elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None: if len(pair_ids) > num_tokens_to_remove: window_len = min(len(pair_ids), stride + num_tokens_to_remove) if self.truncation_side == "right": overflowing_tokens = pair_ids[-window_len:] pair_ids = pair_ids[:-num_tokens_to_remove] elif self.truncation_side == "left": overflowing_tokens = pair_ids[:window_len] pair_ids = pair_ids[num_tokens_to_remove:] else: raise ValueError("invalid truncation strategy:" + str(self.truncation_side)) else: logging.error( "We need to remove {num_tokens_to_remove} to truncate the input " "but the second sequence has a length {len(pair_ids)}. " "Please select another truncation strategy than {truncation_strategy}, " "for instance 'longest_first' or 'only_first'." ) return (ids, pair_ids, overflowing_tokens)
[docs] def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: "build_inputs_with_special_tokens" bos = [self.bos_token_id] sep = [self.sep_token_id] if token_ids_1 is None: return bos + token_ids_0 + sep return bos + token_ids_0 + sep + token_ids_1 + sep
[docs] def create_token_type_ids_from_sequences( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: """ Create a mask from the two sequences passed to be used in a sequence-pair classification task. An XLM sequence pair mask has the following format: ``` 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 | first sequence | second sequence | ``` If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). Args: token_ids_0 (`List[int]`): List of IDs. token_ids_1 (`List[int]`, *optional*): Optional second list of IDs for sequence pairs. Returns: `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). """ sep = [self.sep_token_id] cls = [self.cls_token_id] if token_ids_1 is None: return len(cls + token_ids_0 + sep) * [0] return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
[docs] def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False ) -> List[int]: "get_special_tokens_mask" assert already_has_special_tokens and token_ids_1 is None, ( "You cannot use ``already_has_special_tokens=False`` with this tokenizer. " "Please use a slow (full python) tokenizer to activate this argument. " "Or set `return_special_tokens_mask=True` when calling the encoding method " "to get the special tokens mask in any tokenizer. " ) all_special_ids = self.all_special_ids # cache the property special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0] return special_tokens_mask
def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool):
    """Warn once if a sequence is longer than ``self.model_max_length``.

    The warning is only emitted when no explicit ``max_length`` was requested,
    ``verbose`` is true, and it has not been emitted before. The flag is
    recorded in ``self.deprecation_warnings``.

    Args:
        ids (`List[int]`): The ids produced by the tokenization.
        max_length (`int`, *optional*): The requested max length; when set, no
            warning is triggered.
        verbose (`bool`): Whether or not to print warnings.
    """
    if max_length is None and len(ids) > self.model_max_length and verbose:
        if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
            # Lazy %-args: the old message lacked an f-prefix and logged the
            # literal text "{len(ids)} > {self.model_max_length}".
            logging.warning(
                "Token indices sequence length is longer than the specified maximum sequence length "
                "for this model (%s > %s). Running this sequence through the model "
                "will result in indexing errors",
                len(ids),
                self.model_max_length,
            )
        self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
def prepare_for_model(
    self,
    ids: List[int],
    pair_ids: Optional[List[int]] = None,
    add_special_tokens: bool = True,
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Union[bool, str, TruncationStrategy] = None,
    max_length: Optional[int] = None,
    stride: int = 0,
    pad_to_multiple_of: Optional[int] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs,
):
    """Truncate, add special tokens to, and optionally pad a sequence (or pair) of token ids.

    Padding/truncation options are resolved through
    ``self._get_padding_truncation_strategies``, which also receives any extra
    ``**kwargs`` (e.g. legacy ``truncation_strategy`` / ``pad_to_max_length``).

    Args:
        ids: Token ids of the first sequence.
        pair_ids: Optional token ids of the second sequence.
        add_special_tokens: Wrap the ids via ``build_inputs_with_special_tokens``.
        padding, truncation, max_length, stride, pad_to_multiple_of:
            Padding and truncation settings.
        return_token_type_ids, return_attention_mask: When ``None``, default to
            whether the key appears in ``self.model_input_names``.
        return_overflowing_tokens: Also return the tokens removed by truncation.
        return_special_tokens_mask: Also return a special-tokens mask.
        return_length: Also return the length of ``input_ids``.
        verbose: Allow warnings.

    Returns:
        A dict with ``"input_ids"`` plus the optional keys requested above.
        NOTE(review): when padding is requested or an attention mask is
        returned, the result is whatever ``self.pad`` returns (in this file
        that is the padded input ids for non-empty input, not the dict), and
        ``"length"`` is only added while the result is still a dict.

    Raises:
        ValueError: If token type ids are requested without special tokens,
            or overflowing tokens are requested for a pair with the
            ``longest_first`` strategy.
    """
    padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        verbose=verbose,
        **kwargs,
    )

    pair = pair_ids is not None
    len_ids = len(ids)
    len_pair_ids = len(pair_ids) if pair else 0

    if return_token_type_ids and not add_special_tokens:
        raise ValueError(
            "Asking to return token_type_ids while setting add_special_tokens to False "
            "results in an undefined behavior. Please set add_special_tokens to True or "
            "set return_token_type_ids to None."
        )

    if (
        return_overflowing_tokens
        and truncation_strategy == TruncationStrategy.LONGEST_FIRST
        and pair_ids is not None
    ):
        raise ValueError(
            "Not possible to return overflowing tokens for pair of sequences with the "
            "`longest_first`. Please select another truncation strategy than `longest_first`, "
            "for instance `only_second` or `only_first`."
        )

    # Load from model defaults
    if return_token_type_ids is None:
        return_token_type_ids = "token_type_ids" in self.model_input_names
    if return_attention_mask is None:
        return_attention_mask = "attention_mask" in self.model_input_names

    encoded_inputs = {}

    # Total size once special tokens are added.
    total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)

    # Truncation: handle max sequence length
    overflowing_tokens = []
    if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
        ids, pair_ids, overflowing_tokens = self.truncate_sequences(
            ids,
            pair_ids=pair_ids,
            num_tokens_to_remove=total_len - max_length,
            truncation_strategy=truncation_strategy,
            stride=stride,
        )

    if return_overflowing_tokens:
        encoded_inputs["overflowing_tokens"] = overflowing_tokens
        encoded_inputs["num_truncated_tokens"] = total_len - max_length

    # Add special tokens
    if add_special_tokens:
        sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
        token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
    else:
        sequence = ids + pair_ids if pair else ids
        token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])

    # Build output dictionary
    encoded_inputs["input_ids"] = sequence
    if return_token_type_ids:
        encoded_inputs["token_type_ids"] = token_type_ids
    if return_special_tokens_mask:
        if add_special_tokens:
            encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
        else:
            encoded_inputs["special_tokens_mask"] = [0] * len(sequence)

    # Warn (once) when the unpadded sequence exceeds model_max_length and no
    # explicit max_length was given; this helper was previously never called here.
    self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)

    # Padding
    if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
        encoded_inputs = self.pad(
            encoded_inputs,
            max_length=max_length,
            padding=padding_strategy.value,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )

    # `pad` may return just the id list; indexing that with "length" used to
    # raise TypeError, so only annotate when we still hold a dict.
    if return_length and isinstance(encoded_inputs, dict):
        encoded_inputs["length"] = len(encoded_inputs["input_ids"])

    return encoded_inputs
def bpe(self, token):
    """Apply the BPE merges ranked in ``self.bpe_ranks`` to a single word.

    The last character of ``token`` gets the ``</w>`` end-of-word suffix. The
    adjacent symbol pair with the lowest rank is then merged repeatedly until
    no remaining pair is present in ``self.bpe_ranks``. Results are memoised in
    ``self.cache``.

    Args:
        token (`str`): A single, non-empty word.

    Returns:
        `str`: The resulting subword symbols joined by single spaces.
    """
    # Look up the cache before doing any work on the token.
    if token in self.cache:
        return self.cache[token]

    word = tuple(token[:-1]) + (token[-1] + "</w>",)
    if len(word) < 2:
        # A one-symbol word has no adjacent pairs, so nothing can be merged.
        result = token + "</w>"
        self.cache[token] = result
        return result

    ranks = self.bpe_ranks
    pairs = get_pairs(word)
    while True:
        bigram = min(pairs, key=lambda pair: ranks.get(pair, float("inf")))
        if bigram not in ranks:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
            except ValueError:
                new_word.extend(word[i:])
                break
            new_word.extend(word[i:j])
            i = j
            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
        if len(word) == 1:
            break
        pairs = get_pairs(word)

    result = " ".join(word)
    if result == "\n </w>":
        result = "\n</w>"
    self.cache[token] = result
    return result
def tokenize_(self, text: str, **kwargs) -> List[str]:
    """Split ``text`` into tokens while keeping no-split (special/added) tokens intact.

    The text is split with ``self.tokens_trie`` so that entries of
    ``self.unique_no_split_tokens`` become separate pieces. Whitespace next to
    those pieces is stripped: per the token's ``rstrip``/``lstrip`` flags when
    it is an ``AddedToken`` special token, on both sides otherwise. All other
    non-empty pieces are passed to ``self._tokenize``. When ``do_lower_case``
    is set, non-special text is lowercased first.

    Args:
        text (`str`): The text to tokenize.
        **kwargs: Unused; any keyword argument triggers a warning.

    Returns:
        `List[str]`: The tokens.
    """
    # Map str(token) -> AddedToken for special tokens with specific strip behaviour.
    all_special_tokens_extended = {
        str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
    }

    if kwargs:
        # The old message lacked an f-prefix and printed "{kwargs}" literally.
        logging.warning("Keyword arguments %s not recognized.", kwargs)

    if getattr(self, "do_lower_case", False):
        # Lowercase everything except the special / no-split tokens.
        escaped_special_toks = [
            re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_special_tokens)
        ]
        pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
        text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)

    no_split_token = set(self.unique_no_split_tokens)
    tokens = self.tokens_trie.split(text)
    # ["This is something", "<special_token_1>", " else"]
    for i, token in enumerate(tokens):
        if token not in no_split_token:
            continue
        tok_extended = all_special_tokens_extended.get(token, None)
        left = tokens[i - 1] if i > 0 else None
        right = tokens[i + 1] if i < len(tokens) - 1 else None
        if isinstance(tok_extended, AddedToken):
            # `rstrip` means the special token swallows whitespace on its right,
            # so the *left* side of the following piece is stripped.
            if tok_extended.rstrip and right:
                tokens[i + 1] = right.lstrip()
            if tok_extended.lstrip and left:
                tokens[i - 1] = left.rstrip()
        else:
            # Strip both neighbours by default.
            if right:
                tokens[i + 1] = right.lstrip()
            if left:
                tokens[i - 1] = left.rstrip()
    # ["This is something", "<special_token_1>", "else"]

    tokenized_text = []
    for token in tokens:
        # Skip pieces that became empty after stripping.
        if not token:
            continue
        if token in no_split_token:
            tokenized_text.append(token)
        else:
            tokenized_text.extend(self._tokenize(token))
    return tokenized_text
def _tokenize(self, text, lang="en", bypass_tokenizer=False):
    """Word-tokenize ``text`` for ``lang`` and split every word with BPE.

    Pipeline selection:
      * ``bypass_tokenizer``: plain whitespace split.
      * languages outside ``self.lang_with_custom_tokenizer``: Moses pipeline
        (plus Romanian preprocessing for ``"ro"``) followed by Moses tokenization.
      * ``"zh"``: jieba segmentation, then the Moses pipeline and a whitespace split.
      * ``"ja"``: Moses pipeline and a whitespace split.
    Any other language raises ``ValueError``. An unknown ``lang`` (not in a
    non-empty ``self.lang2id``) is only logged, not rejected.

    Returns:
        `List[str]`: The BPE sub-tokens of all non-empty words.
    """
    if lang and self.lang2id and lang not in self.lang2id:
        logging.error(
            "Supplied language code not found in lang2id mapping. Please check that your language is supported by"
            " the loaded pretrained model."
        )

    if bypass_tokenizer:
        words = text.split()
    elif lang not in self.lang_with_custom_tokenizer:
        normalized = self.moses_pipeline(text, lang=lang)
        if lang == "ro":
            normalized = romanian_preprocessing(normalized)
        words = self.moses_tokenize(normalized, lang=lang)
    elif lang == "zh":
        try:
            if "jieba" in sys.modules:
                jieba = sys.modules["jieba"]
            else:
                import jieba
        except (AttributeError, ImportError):
            logging.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
            logging.error("1. pip install jieba")
            raise
        segmented = " ".join(jieba.cut(text))
        words = self.moses_pipeline(segmented, lang=lang).split()
    elif lang == "ja":
        words = self.moses_pipeline(text, lang=lang).split()
    else:
        raise ValueError("It should not reach here")

    if self.do_lowercase_and_remove_accent and not bypass_tokenizer:
        words = lowercase_and_remove_accent(words)

    return [piece for word in words if word for piece in self.bpe(word).split(" ")]
def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
    """Map a token string, or an iterable of token strings, to vocabulary id(s).

    Added tokens take precedence over the base vocabulary (see
    ``_convert_token_to_id_with_added_voc``).

    Args:
        tokens (`str` or `List[str]`): The token(s) to look up.

    Returns:
        `int` or `List[int]`: A single id for a string input, a list of ids
        otherwise, or ``None`` when ``tokens`` is ``None``.
    """
    if tokens is None:
        return None
    if isinstance(tokens, str):
        return self._convert_token_to_id_with_added_voc(tokens)
    return [self._convert_token_to_id_with_added_voc(piece) for piece in tokens]
def _convert_token_to_id_with_added_voc(self, token):
    """Return the id of ``token``, preferring ``self.added_tokens_encoder`` over the vocab.

    ``None`` is passed through unchanged.
    """
    if token is None:
        return None
    if token not in self.added_tokens_encoder:
        return self._convert_token_to_id(token)
    return self.added_tokens_encoder[token]

def _convert_token_to_id(self, token):
    """Look up ``token`` in ``self.encoder``.

    Unknown tokens fall back to the id of ``self.unk_token``, or ``None`` if
    that token is missing from the encoder as well.
    """
    unknown_id = self.encoder.get(self.unk_token)
    return self.encoder.get(token, unknown_id)
def num_special_tokens_to_add(self, pair: bool = False) -> int:
    """Count the special tokens ``build_inputs_with_special_tokens`` inserts.

    Args:
        pair (`bool`): Count for a sequence pair instead of a single sequence.

    Returns:
        `int`: The number of added special tokens.
    """
    second_segment = [] if pair else None
    return len(self.build_inputs_with_special_tokens([], second_segment))
def pad(
    self,
    encoded_inputs,
    padding: Union[bool, str, PaddingStrategy] = True,
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
    return_attention_mask: Optional[bool] = None,
    verbose: bool = True,
):
    """Pad one encoding or a batch of encodings and return the padded main input ids.

    ``encoded_inputs`` may be a dict of lists (single example or batch) or a
    list/tuple of dicts, which is first converted to a dict of lists.

    Args:
        encoded_inputs: The encoding(s) to pad.
        padding: Padding strategy, resolved via
            ``self._get_padding_truncation_strategies``.
        max_length (`int`, *optional*): Target length.
        pad_to_multiple_of (`int`, *optional*): Round the target length up to
            a multiple of this value.
        return_attention_mask (`bool`, *optional*): Whether to build an
            attention mask (see ``_pad``).
        verbose (`bool`): Allow warnings.

    Returns:
        The padded ``"input_ids"``: a list for a single encoding, or a list of
        lists for a batch. If the main input is ``None`` or empty, the
        encoding dict itself is returned unchanged (apart from an empty
        attention mask when requested).

    Raises:
        ValueError: If the main model input (``self.model_input_names[0]``)
            is missing.
    """
    if self.__class__.__name__.endswith("Fast"):
        if not self.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False):
            logging.warning(
                "You're using a %s tokenizer. Please note that with a fast tokenizer,"
                " using the `__call__` method is faster than using a method to encode the text followed by"
                " a call to the `pad` method to get a padded encoding.",
                self.__class__.__name__,
            )
            self.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

    # A list of dicts (e.g. a collate_fn batch) becomes a dict of lists.
    # The emptiness check avoids an IndexError on `encoded_inputs[0]`.
    if isinstance(encoded_inputs, (list, tuple)) and encoded_inputs and isinstance(encoded_inputs[0], Mapping):
        encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}

    # The model's main input name, usually `input_ids`, must be present.
    main_input_name = self.model_input_names[0]
    if main_input_name not in encoded_inputs:
        provided = (
            list(encoded_inputs.keys()) if isinstance(encoded_inputs, Mapping) else type(encoded_inputs).__name__
        )
        raise ValueError(
            "You should supply an encoding or a list of encodings to this method "
            f"that includes {main_input_name}, but you provided {provided}"
        )

    required_input = encoded_inputs[main_input_name]
    if required_input is None or (isinstance(required_input, Sized) and len(required_input) == 0):
        if return_attention_mask:
            encoded_inputs["attention_mask"] = []
        return encoded_inputs

    # Convert padding into a PaddingStrategy.
    padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
        padding=padding, max_length=max_length, verbose=verbose
    )

    # Single (un-batched) encoding.
    if required_input and not isinstance(required_input[0], (list, tuple)):
        encoded_inputs = self._pad(
            encoded_inputs,
            max_length=max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )
        return encoded_inputs["input_ids"]

    # Batched encoding.
    batch_size = len(required_input)
    assert all(
        len(v) == batch_size for v in encoded_inputs.values()
    ), "Some items in the output dictionary have a different batch size than others."

    if padding_strategy == PaddingStrategy.LONGEST:
        max_length = max(len(inputs) for inputs in required_input)
        padding_strategy = PaddingStrategy.MAX_LENGTH

    batch_outputs = {}
    for i in range(batch_size):
        inputs = {k: v[i] for k, v in encoded_inputs.items()}
        outputs = self._pad(
            inputs,
            max_length=max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )
        for key, value in outputs.items():
            batch_outputs.setdefault(key, []).append(value)

    # Return the *padded* ids; previously the unpadded input was returned and
    # `batch_outputs` was discarded.
    return batch_outputs["input_ids"]
def _pad(
    self,
    encoded_inputs,
    max_length: Optional[int] = None,
    padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
    pad_to_multiple_of: Optional[int] = None,
    return_attention_mask: Optional[bool] = None,
) -> dict:
    """Pad a single (un-batched) encoding dict in place and return it.

    The main input (``self.model_input_names[0]``) is extended with
    ``self.pad_token_id`` on ``self.padding_side``. Existing
    ``attention_mask`` / ``token_type_ids`` / ``special_tokens_mask`` entries
    are extended with ``0`` / ``self.pad_token_type_id`` / ``1`` respectively.

    Args:
        encoded_inputs: Dict of lists for one example.
        max_length (`int`, *optional*): Target length. Replaced by the current
            length under ``PaddingStrategy.LONGEST``, and rounded up to
            ``pad_to_multiple_of`` when given.
        padding_strategy: The padding strategy; ``DO_NOT_PAD`` disables padding.
        pad_to_multiple_of (`int`, *optional*): Round ``max_length`` up to a
            multiple of this value.
        return_attention_mask (`bool`, *optional*): Defaults to whether
            ``"attention_mask"`` is in ``self.model_input_names``; when true, an
            all-ones mask is created if missing.

    Returns:
        `dict`: The same (mutated) ``encoded_inputs``.

    Raises:
        ValueError: If padding is needed and ``self.padding_side`` is neither
            ``"right"`` nor ``"left"``.
    """
    if return_attention_mask is None:
        return_attention_mask = "attention_mask" in self.model_input_names

    main_ids = encoded_inputs[self.model_input_names[0]]
    current_len = len(main_ids)

    if padding_strategy == PaddingStrategy.LONGEST:
        max_length = current_len
    if max_length is not None and pad_to_multiple_of is not None and max_length % pad_to_multiple_of != 0:
        max_length = (max_length // pad_to_multiple_of + 1) * pad_to_multiple_of

    # Initialize the attention mask if it is requested but absent.
    if return_attention_mask and "attention_mask" not in encoded_inputs:
        encoded_inputs["attention_mask"] = [1] * current_len

    if padding_strategy == PaddingStrategy.DO_NOT_PAD or current_len == max_length:
        return encoded_inputs

    missing = max_length - current_len
    side = self.padding_side
    if side == "right":
        before, after = 0, missing
    elif side == "left":
        before, after = missing, 0
    else:
        raise ValueError("Invalid padding strategy:" + str(side))

    def surround(values, fill):
        # Prepend/append `fill` so the padding lands on the configured side.
        return [fill] * before + values + [fill] * after

    if return_attention_mask:
        encoded_inputs["attention_mask"] = surround(encoded_inputs["attention_mask"], 0)
    if "token_type_ids" in encoded_inputs:
        encoded_inputs["token_type_ids"] = surround(encoded_inputs["token_type_ids"], self.pad_token_type_id)
    if "special_tokens_mask" in encoded_inputs:
        encoded_inputs["special_tokens_mask"] = surround(encoded_inputs["special_tokens_mask"], 1)
    encoded_inputs[self.model_input_names[0]] = surround(main_ids, self.pad_token_id)
    return encoded_inputs