# Source code for mindnlp.abc.configs.pretrained_config

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# pylint: disable=C0103
"""
Pretrained config.
"""

import copy
import json
import os
from typing import Optional, Tuple, Dict
from mindspore import log as logger

from mindnlp.configs import HF_CONFIG_URL_BASE, DEFAULT_ROOT
from mindnlp.utils.download import cached_path

class PreTrainedConfig:
    """
    Abstract class for Pretrained models config.

    Holds the options shared by model configurations and provides helpers to
    build a configuration from a JSON file, a Python dictionary, or a
    pretrained model name / path, and to serialize it back to a dictionary,
    a JSON string or a ``config.json`` file.
    """

    # Maps shortcut model names to the location of their configuration file.
    # Empty here; `get_config_dict` falls back to this class attribute when no
    # explicit map is passed.
    pretrained_config_archive_map: Dict[str, str] = {}

    def __init__(self, **kwargs):
        """
        Initialize the common configuration attributes from keyword arguments.

        Every recognized key is popped from ``kwargs`` and stored as an
        attribute; a missing key falls back to the default shown below. Keys
        that are not recognized here are left in ``kwargs`` and are not stored
        on the instance.

        Raises:
            ValueError: If ``problem_type`` is given and is not one of
                ``"regression"``, ``"single_label_classification"`` or
                ``"multi_label_classification"``.
        """
        self.finetuning_task = kwargs.pop('finetuning_task', None)
        self.num_labels = kwargs.pop('num_labels', 2)
        self.output_attentions = kwargs.pop('output_attentions', False)
        self.output_hidden_states = kwargs.pop('output_hidden_states', False)

        self.is_decoder = kwargs.pop("is_decoder", False)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
        self.add_cross_attention = kwargs.pop("add_cross_attention", False)
        # Whether input and output word embeddings should be tied for all
        # MLM, LM and Seq2Seq models.
        self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True)
        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
        self.return_dict = kwargs.pop("return_dict", False)
        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
        self.pruned_heads = kwargs.pop("pruned_heads", {})

        self.problem_type = kwargs.pop("problem_type", None)
        allowed_problem_types = (
            "regression",
            "single_label_classification",
            "multi_label_classification",
        )
        if self.problem_type is not None and self.problem_type not in allowed_problem_types:
            raise ValueError(
                f"The config parameter `problem_type` was not understood: received {self.problem_type} "
                "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
            )

    @classmethod
    def from_json(cls, file_path):
        """
        Load a config from a JSON file by setting every key as an attribute.

        Unlike `from_json_file`, the instance is first created with default
        values and then *every* key found in the file is assigned with
        ``setattr``, including keys that ``__init__`` does not know about.
        """
        with open(file_path, "r", encoding="utf-8") as file:
            text = file.read()
        config_map = json.loads(text)
        config = cls()
        for key, value in config_map.items():
            setattr(config, key, value)
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """
        Construct a config from a JSON file of parameters.

        The parsed JSON object is passed to the constructor as keyword
        arguments (``cls(**dict_obj)``).
        """
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        dict_obj = json.loads(text)
        return cls(**dict_obj)

    @classmethod
    def load(cls, pretrained_model_name_or_path):
        """Load a config; shorthand for `from_pretrained` without extra options."""
        return cls.from_pretrained(pretrained_model_name_or_path)

    @property
    def use_return_dict(self) -> bool:
        """
        `bool`: Whether or not return [`~utils.ModelOutput`] instead of tuples.

        This is simply the value of ``self.return_dict``.
        """
        return self.return_dict

    @classmethod
    def from_dict(cls, config_dict: Dict, **kwargs) -> "PreTrainedConfig":
        """
        Construct a config from a Python dictionary of parameters.

        Args:
            config_dict (:obj:`Dict[str, any]`):
                Dictionary used to instantiate the configuration object, e.g.
                the one returned by :meth:`get_config_dict`.
            kwargs (:obj:`Dict[str, any]`):
                Additional parameters. Each key that matches an existing
                attribute of the new config overrides that attribute. The
                special key ``return_unused_kwargs`` (default ``False``)
                controls the return value.

        Returns:
            The configuration object, or a ``(config, unused_kwargs)`` tuple
            when ``return_unused_kwargs`` is true, where ``unused_kwargs``
            holds the keyword arguments that did not match any attribute.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        config = cls(**config_dict)

        if hasattr(config, "pruned_heads"):
            # JSON object keys are always strings; restore integer keys.
            config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}

        # Update config with kwargs if needed. Matching keys are collected
        # first so that `kwargs` is not modified while it is being iterated.
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info("Model config %s", str(config))
        if return_unused_kwargs:
            return config, kwargs
        return config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PreTrainedConfig":
        """
        Build a config from a pretrained model name or path.

        The parameters are resolved with `get_config_dict` and the result,
        together with the keyword arguments it did not consume, is passed to
        `from_dict`.
        """
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def get_config_dict(
        cls,
        pretrained_model_name_or_path: str,
        pretrained_config_archive_map: Optional[Dict] = None,
        **kwargs
    ) -> Tuple[Dict, Dict]:
        """
        From a `pretrained_model_name_or_path`, resolve to a dictionary of
        parameters, to be used for instantiating a Config using `from_dict`.

        The location of the config file is chosen by the first rule that
        matches:

        1. the name is a key of the archive map -> the mapped location;
        2. the name is an existing directory -> ``config.json`` in it;
        3. the name is an existing file -> that file;
        4. ``from_pt`` is true -> ``HF_CONFIG_URL_BASE`` formatted with the name.

        Parameters:
            pretrained_model_name_or_path (:obj:`string`):
                The identifier of the pre-trained checkpoint from which we
                want the dictionary of parameters.
            pretrained_config_archive_map (:obj:`Dict[str, str]`, `optional`):
                A map of `shortcut names` to `url`. By default, will use the
                current class attribute.
            kwargs:
                ``cache_dir``, ``proxies`` and ``from_pt`` are read here.
                ``force_download``, ``resume_download`` and
                ``local_files_only`` are popped but currently ignored.

        Returns:
            :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to
            instantiate the configuration object, and the keyword arguments
            that were not consumed here.

        Raises:
            ValueError: If none of the four rules above matches.
            EnvironmentError: If the config file cannot be resolved or read,
                or if it is not valid JSON.
        """
        cache_dir = kwargs.pop("cache_dir", os.path.join(DEFAULT_ROOT, 'models'))
        # Accepted for interface compatibility; the values are discarded.
        _ = kwargs.pop("force_download", False)
        _ = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        _ = kwargs.pop("local_files_only", False)
        from_pt = kwargs.pop("from_pt", False)

        if pretrained_config_archive_map is None:
            pretrained_config_archive_map = cls.pretrained_config_archive_map

        if pretrained_model_name_or_path in pretrained_config_archive_map:
            config_file = pretrained_config_archive_map[pretrained_model_name_or_path]
            cache_dir = os.path.join(cache_dir, pretrained_model_name_or_path)
        elif os.path.isdir(pretrained_model_name_or_path):
            config_file = "config.json"
            cache_dir = pretrained_model_name_or_path
        elif os.path.isfile(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
            cache_dir = None
        elif from_pt:
            config_file = HF_CONFIG_URL_BASE.format(pretrained_model_name_or_path)
            cache_dir = os.path.join(cache_dir, pretrained_model_name_or_path)
        else:
            raise ValueError(f'not found config of {pretrained_model_name_or_path}')

        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                proxies=proxies)
            # Load config dict
            if resolved_config_file is None:
                raise EnvironmentError
            config_dict = cls._dict_from_json_file(resolved_config_file)

        except EnvironmentError as exc:
            if pretrained_model_name_or_path in pretrained_config_archive_map:
                msg = f"Couldn't reach server at '{config_file}' to download pretrained model configuration file."
            else:
                msg = (
                    f"Can't load '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' "
                    f"is a correct model identifier listed on 'https://download.mindspore.cn/toolkits/mindnlp/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' "
                    f"is the correct path to a directory containing a config.json file\n\n"
                )
            raise EnvironmentError(msg) from exc

        except json.JSONDecodeError as exc:
            # Only `_dict_from_json_file` can raise this, so
            # `resolved_config_file` is always bound at this point.
            msg = (
                f"Couldn't reach server at '{config_file}' to download configuration file or "
                f"configuration file is not a valid JSON file. "
                f"Please check network or file content here: {resolved_config_file}."
            )
            raise EnvironmentError(msg) from exc

        if resolved_config_file == config_file:
            logger.info("loading configuration file %s", config_file)
        else:
            logger.info("loading configuration file %s from cache at %s", config_file, resolved_config_file)

        return config_dict, kwargs

    @classmethod
    def _dict_from_json_file(cls, json_file: str):
        """Read `json_file` (UTF-8) and return its parsed JSON content."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def to_dict(self):
        """Serializes this instance to a Python dictionary (a deep copy of its attributes)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string (sorted keys, 2-space indent, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_file(self, save_path):
        """
        Serializes this instance to a JSON file.

        Args:
            save_path: Existing directory in which ``config.json`` is written
                (an existing file of that name is overwritten).
        """
        output_dict = self.to_dict()
        # The file must be opened for writing: `open` defaults to read mode,
        # in which `json.dump` cannot write.
        with open(os.path.join(save_path, 'config.json'), 'w', encoding='utf-8') as f:
            json.dump(output_dict, f, sort_keys=True, indent=2)