Source code for mindnlp.workflow.work

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
The meta classs of work in Workflow.
"""

# pylint: disable=no-member

import abc
import math
import os
from abc import abstractmethod

from mindnlp.configs import DEFAULT_ROOT
from mindnlp.utils import cache_file
from mindnlp.workflow.utils import cut_chinese_sent


[docs]class Work(metaclass=abc.ABCMeta): """ The meta classs of work in Workflow. The meta class has the five abstract function, the subclass need to inherit from the meta class. Args: work(string): The name of work. model(string): The model name in the work. kwargs (dict, optional): Additional keyword arguments passed along to the specific work. """ def __init__(self, model, work, **kwargs): self.model = model self.work = work self.kwargs = kwargs self._usage = "" self._model = None self._init_class = None # The root directory for storing Workflow related files, default to ~/.mindnlp. self._home_path = ( self.kwargs["home_path"] if "home_path" in self.kwargs else DEFAULT_ROOT ) self._work_flag = ( self.kwargs["work_flag"] if "work_flag" in self.kwargs else self.model ) self.from_hf_hub = kwargs.pop("from_hf_hub", False) if "work_path" in self.kwargs: self._work_path = self.kwargs["work_path"] self._custom_model = True else: self._work_path = os.path.join( self._home_path, "workflow", self.work, self.model ) if not self.from_hf_hub: pass @abstractmethod def _construct_model(self, model): """ Construct the inference model for the predictor. """ @abstractmethod def _construct_tokenizer(self, model): """ Construct the tokenizer for the predictor. """ @abstractmethod def _preprocess(self, inputs, padding=True, add_special_tokens=True): """ Transform the raw text to the model inputs, two steps involved: 1) Transform the raw text to token ids. 2) Generate the other model inputs from the raw text and token ids. """ @abstractmethod def _run_model(self, inputs): """ Run the work model from the outputs of the `_tokenize` function. """ @abstractmethod def _postprocess(self, inputs): """ The model output is the logits and pros, this function will convert the model output to raw text. """ def _get_graph_model_name(self): """ Get the graph model name. """ return "" def _check_work_files(self): """ Check files required by the work. """ for file_id, file_name in self.resource_files_names.items(): path = os.path.join(self._work_path, file_id, file_name) cache_dir = os.path.join(self._work_path, file_id) url = self.resource_files_urls[self.model][file_id][0] md5 = self.resource_files_urls[self.model][file_id][1] downloaded = True if not os.path.exists(path=path): downloaded = False if not downloaded: cache_file(filename=file_name, cache_dir=cache_dir, url=url, md5sum=md5) def _check_predictor_type(self): """ Check the predictor type. """ def _prepare_graph_mode(self): """ Construct the input data and predictor in the MindSpore graph mode. """ def _prepare_onnx_mode(self): """ Prepare the onnx model. """ def _get_inference_model(self): """ Return the inference program, inputs and outputs in graph mode. """ def _check_input_text(self, inputs): """ Check whether the input text meet the requirement. """ inputs = inputs[0] if isinstance(inputs, str): if len(inputs) == 0: raise ValueError( "Invalid inputs, input text should not be empty text, \ please check your input." ) inputs = [inputs] elif isinstance(inputs, list): if not (isinstance(inputs[0], str) and len(inputs[0].strip()) > 0): raise TypeError( "Invalid inputs, input text should be list of str, \ and first element of list should not be empty text." ) else: raise TypeError( f"Invalid inputs, input text should be str or list of str, \ but type of {type(inputs)} found!" ) return inputs def _auto_splitter(self, input_texts, max_text_len, split_sentence=False): """ Split the raw texts automatically for model inference. Args: input_texts (List[str]): input raw texts. max_text_len (int): cutting length. split_sentence (bool): If True, sentence-level split will be performed. `split_sentence` will be set to False if bbox_list is not None since sentence-level split is not support for document. Return: short_input_texts (List[str]): the short input texts for model inference. input_mapping (dict): mapping between raw text and short input texts. """ input_mapping = {} short_input_texts = [] cnt_org = 0 cnt_short = 0 for _, text in enumerate(input_texts): if not split_sentence: sens = [text] else: sens = cut_chinese_sent(text) for sen in sens: lens = len(sen) if lens <= max_text_len: short_input_texts.append(sen) input_mapping.setdefault(cnt_org, []).append(cnt_short) cnt_short += 1 else: temp_text_list = [ sen[i : i + max_text_len] for i in range(0, lens, max_text_len) ] short_input_texts.extend(temp_text_list) short_idx = cnt_short cnt_short += math.ceil(lens / max_text_len) temp_text_id = [short_idx + i for i in range(cnt_short - short_idx)] input_mapping.setdefault(cnt_org, []).extend(temp_text_id) cnt_org += 1 return short_input_texts, input_mapping def _auto_joiner(self, short_results, short_inputs, input_mapping, is_dict=False): """ Join the short results automatically and generate the final results to match with the user inputs. """
[docs] def help(self): """ Return the usage message of the current work. """ print(f"Examples:\n{self._usage}")
def __call__(self, *args): inputs = self._preprocess(*args) outputs = self._run_model(inputs) results = self._postprocess(outputs) return results