Source code for mindnlp.dataset.machine_translation.iwslt2017

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
IWSLT2017 load function
"""
# pylint: disable=C0103

import os
from typing import Union, Tuple
from mindspore.dataset import IWSLT2017Dataset
from mindnlp.transforms import BasicTokenizer
from mindnlp.utils.download import cache_file
from mindnlp.dataset.process import common_process
from mindnlp.dataset.register import load_dataset, process
from mindnlp.configs import DEFAULT_ROOT
from mindnlp.utils import untar

# Google Drive location of the IWSLT2017 archive fetched by IWSLT2017().
URL = "https://drive.google.com/u/0/uc?id=12ycYSzLIG253AFN35Y6qoyf9wtkOjakp&confirm=t"

# Checksum handed to cache_file() as `md5sum` together with URL.
MD5 = "aca701032b1c4411afc4d9fa367796ba"


# Source language -> target languages accepted by IWSLT2017().
# The list order is kept stable because the lists are echoed in error messages.
_LANGUAGE_PAIRS = {
    "en": ["nl", "de", "it", "ro"],
    "ro": ["de", "en", "nl", "it"],
    "de": ["ro", "en", "nl", "it"],
    "it": ["en", "nl", "de", "ro"],
    "nl": ["de", "en", "it", "ro"],
}

# Description of what this loader supports. Only the "language_pair" entry is
# read by the code in this module; "valid_test" and "year" are kept as metadata.
SUPPORTED_DATASETS = {
    "valid_test": ["dev2010", "tst2010"],
    "language_pair": _LANGUAGE_PAIRS,
    "year": 17,
}


@load_dataset.register
def IWSLT2017(root: str = DEFAULT_ROOT,
              split: Union[Tuple[str], str] = ("train", "valid", "test"),
              language_pair=("de", "en"), proxies=None):
    r"""
    Load the IWSLT2017 dataset

    The available datasets include following:

    **Language pairs**:

    +-----+-----+-----+-----+-----+-----+
    |     |"en" |"nl" |"de" |"it" |"ro" |
    +-----+-----+-----+-----+-----+-----+
    |"en" |     |   x |  x  |  x  |  x  |
    +-----+-----+-----+-----+-----+-----+
    |"nl" |  x  |     |  x  |  x  |  x  |
    +-----+-----+-----+-----+-----+-----+
    |"de" |  x  |   x |     |  x  |  x  |
    +-----+-----+-----+-----+-----+-----+
    |"it" |  x  |   x |  x  |     |  x  |
    +-----+-----+-----+-----+-----+-----+
    |"ro" |  x  |   x |  x  |  x  |     |
    +-----+-----+-----+-----+-----+-----+

    Args:
        root (str): Directory where the datasets are saved.
            Default: "~/.mindnlp". When left at the default, files are cached
            under ``<root>/datasets/IWSLT2017``; any other value is used as
            the cache directory directly.
        split (str|Tuple[str]): Split or splits to be returned. A string is
            split on whitespace. Default: ('train', 'valid', 'test').
        language_pair (Tuple[str]): Tuple (or list) containing src and tgt language.
            Default: ('de', 'en').
        proxies (dict): a dict to identify proxies, for example:
            {"https": "https://127.0.0.1:7890"}.

    Returns:
        - **datasets_list** (list) - A list of loaded datasets.
          If only one split is requested, such as 'train',
          that dataset is returned instead of a list of datasets.

    Raises:
        ValueError: If `language_pair` is not a list or tuple.
        AssertionError: If the length of `language_pair` is not 2.
        ValueError: If the source language is not supported, or the target
            language is not supported for the given source language.

    Examples:
        >>> root = "~/.mindnlp"
        >>> split = ('train', 'valid', 'test')
        >>> language_pair = ('de', 'en')
        >>> dataset_train, dataset_valid, dataset_test = IWSLT2017(root, split, language_pair)
        >>> train_iter = dataset_train.create_dict_iterator()
        >>> print(next(train_iter))
        {'text': Tensor(shape=[], dtype=String, value= 'Vielen Dank, Chris.'),
        'translation': Tensor(shape=[], dtype=String, value= 'Thank you so much, Chris.')}

    """
    if not isinstance(language_pair, (list, tuple)):
        raise ValueError(
            f"language_pair must be list or tuple but got {type(language_pair)} instead")
    if len(language_pair) != 2:
        # Raised explicitly (instead of via `assert`) so the check is not
        # stripped under `python -O`; the exception type is kept for callers.
        raise AssertionError(
            "language_pair must contain only 2 elements: src and tgt language respectively")

    src_language, tgt_language = language_pair
    supported_pairs = SUPPORTED_DATASETS["language_pair"]
    if src_language not in supported_pairs:
        raise ValueError(
            f"src_language '{src_language}' is not valid. "
            f"Supported source languages are {list(supported_pairs)}"
        )
    if tgt_language not in supported_pairs[src_language]:
        raise ValueError(
            f"tgt_language '{tgt_language}' is not valid for given src_language "
            f"'{src_language}'. Supported target languages are "
            f"{supported_pairs[src_language]}"
        )

    if isinstance(split, str):
        split = split.split()

    # Only the default root gets the dataset-specific sub-directory appended.
    if root == DEFAULT_ROOT:
        cache_dir = os.path.join(root, "datasets", "IWSLT2017")
    else:
        cache_dir = root

    file_path, _ = cache_file(None, cache_dir=cache_dir, url=URL, md5sum=MD5,
                              download_file_name="2017-01-trnmted.tgz", proxies=proxies)
    # The first entry returned by untar() is used as the extracted directory
    # name (presumably the archive's top-level folder -- verify against untar).
    dataset_dir_name = untar(file_path, os.path.dirname(file_path))[0]
    dataset_dir_path = os.path.join(cache_dir, dataset_dir_name)

    # The outer archive contains a nested archive that must be unpacked in place.
    inner_archive_name = "DeEnItNlRo-DeEnItNlRo.tgz"
    inner_archive_path = os.path.join(dataset_dir_path, 'texts', 'DeEnItNlRo',
                                      'DeEnItNlRo', inner_archive_name)
    untar(inner_archive_path, os.path.dirname(inner_archive_path))

    # Forward the validated language pair; without it the dataset falls back to
    # its own default pair regardless of what the caller asked for.
    datasets_list = [
        IWSLT2017Dataset(dataset_dir=dataset_dir_path, usage=usage,
                         language_pair=list(language_pair), shuffle=False)
        for usage in split
    ]
    if len(datasets_list) == 1:
        return datasets_list[0]
    return datasets_list


@process.register
def IWSLT2017_Process(dataset, column='translation', tokenizer=None, vocab=None):
    '''
    a function transforms specific language column in IWSLT2017 dataset into tensors

    This is a thin wrapper that forwards its arguments to `common_process`.

    Args:
        dataset (GeneratorDataset, ZipDataset): IWSLT2017 dataset
        column (str): The language column name in IWSLT2017
        tokenizer (TextTensorOperation): Tokenizer you want to use.
            Defaults to None, in which case a new `BasicTokenizer()` is created.
        vocab (Vocab): The vocab you use, defaults to None. If None, a new vocab will be created.

    Returns:
        - MapDataset, dataset after process.
        - Vocab, new vocab created from dataset if 'vocab' is None.

    Examples:
        >>> from mindnlp.transforms import BasicTokenizer
        >>> from mindnlp.dataset import IWSLT2017, IWSLT2017_Process
        >>> test_dataset = IWSLT2017(
        ...     root='./dataset',
        ...     split="test",
        ...     language_pair=("de", "en")
        ... )
        >>> test_dataset, vocab = IWSLT2017_Process(test_dataset, "translation",
        ...     BasicTokenizer())
        >>> for i in test_dataset.create_tuple_iterator():
        ...     print(i)
        ...     break

    '''
    # Build the default tokenizer per call instead of once at import time, so
    # no tokenizer instance is shared between unrelated calls.
    if tokenizer is None:
        tokenizer = BasicTokenizer()
    return common_process(dataset, column, tokenizer, vocab)