# Source code for mindnlp.dataset.machine_translation.iwslt2016

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
IWSLT2016 load function
"""
# pylint: disable=C0103

import os
from typing import Union, Tuple
from mindspore.dataset import IWSLT2016Dataset
from mindnlp.utils.download import cache_file
from mindnlp.dataset.register import load_dataset
from mindnlp.configs import DEFAULT_ROOT
from mindnlp.utils import untar

# Download location of the IWSLT2016 archive and the MD5 digest that is
# passed to ``cache_file`` together with it.
URL = "https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8&confirm=t"
MD5 = "c393ed3fc2a1b0f004b3331043f615ae"

# "valid_test": names accepted for the ``valid_set`` / ``test_set`` arguments.
# "language_pair": source language -> list of target languages on offer.
#     English pairs with every other language; every other language pairs
#     with English only.
SUPPORTED_DATASETS = {
    "valid_test": ["dev2010", "tst2010", "tst2011", "tst2012", "tst2013", "tst2014"],
    "language_pair": {
        "en": ["ar", "de", "fr", "cs"],
        **{source: ["en"] for source in ("ar", "fr", "de", "cs")},
    },
    "year": 16,
}

# (source, target) -> evaluation sets that are NOT available for that pair.
# Only the pairs involving Czech lack a set ("tst2014"); all others are complete.
SET_NOT_EXISTS = {
    (source, target): ["tst2014"] if "cs" in (source, target) else []
    for source, targets in SUPPORTED_DATASETS["language_pair"].items()
    for target in targets
}


@load_dataset.register
def IWSLT2016(
    root: str = DEFAULT_ROOT,
    split: Union[Tuple[str], str] = ("train", "valid", "test"),
    language_pair=("de", "en"),
    valid_set="tst2013",
    test_set="tst2014",
    proxies=None,
):
    r"""
    Load the IWSLT2016 dataset

    The available datasets include following:

    **Language pairs**:

    +-----+-----+-----+-----+-----+-----+
    |     |"en" |"fr" |"de" |"cs" |"ar" |
    +-----+-----+-----+-----+-----+-----+
    |"en" |     |  x  |  x  |  x  |  x  |
    +-----+-----+-----+-----+-----+-----+
    |"fr" |  x  |     |     |     |     |
    +-----+-----+-----+-----+-----+-----+
    |"de" |  x  |     |     |     |     |
    +-----+-----+-----+-----+-----+-----+
    |"cs" |  x  |     |     |     |     |
    +-----+-----+-----+-----+-----+-----+
    |"ar" |  x  |     |     |     |     |
    +-----+-----+-----+-----+-----+-----+

    **valid/test sets**: ["dev2010", "tst2010", "tst2011", "tst2012", "tst2013", "tst2014"]

    Args:
        root (str): Directory where the datasets are saved. When it equals
            ``DEFAULT_ROOT`` the files go to ``<root>/datasets/IWSLT2016``,
            otherwise ``root`` itself is used as the cache directory.
            Default: ``DEFAULT_ROOT``.
        split (str|Tuple[str]): Split or splits to be returned. A string is
            split on whitespace, so ``"train valid"`` selects two splits.
            Default: ('train', 'valid', 'test').
        language_pair (Tuple[str]|List[str]): Two-element sequence containing
            src and tgt language. Default: ('de', 'en').
        valid_set (str): a string to identify validation set. Default: "tst2013".
        test_set (str): a string to identify test set. Default: "tst2014".
        proxies (dict): a dict to identify proxies, forwarded to the download
            helper, for example: {"https": "https://127.0.0.1:7890"}.

    Returns:
        - **datasets_list** (list) - A list of loaded datasets.
          If only one type of dataset is specified, such as 'train',
          this dataset is returned instead of a list of datasets.

    Raises:
        ValueError: If `language_pair` is not a list or tuple.
        AssertionError: If the length of `language_pair` is not 2.
        ValueError: If the source language is not supported.
        ValueError: If the target language is not supported for the source language.
        ValueError: If `valid_set` is not available for the language pair.
        ValueError: If `test_set` is not available for the language pair.

    Errors raised by the download/extraction helpers or by
    ``IWSLT2016Dataset`` itself (e.g. for an unknown split name) are not
    handled here and presumably propagate to the caller.

    Examples:
        >>> root = "~/.mindnlp"
        >>> split = ('train', 'valid', 'test')
        >>> language_pair = ('de', 'en')
        >>> valid_set="tst2013"
        >>> test_set="tst2014"
        >>> dataset_train, dataset_valid, dataset_test = IWSLT2016(root, split, \
        ...     language_pair, valid_set, test_set)
        >>> train_iter = dataset_train.create_dict_iterator()
        >>> print(next(train_iter))
        {'text': Tensor(shape=[], dtype=String, value= \
        'David Gallo: Das ist Bill Lange. Ich bin Dave Gallo.'),
        'translation': Tensor(shape=[], dtype=String, value= \
        "David Gallo: This is Bill Lange. I'm Dave Gallo.")}

    """
    # ---- argument validation -------------------------------------------
    if not isinstance(language_pair, (list, tuple)):
        raise ValueError(
            f"language_pair must be list or tuple but got {type(language_pair)} instead"
        )
    if len(language_pair) != 2:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type stays the same.
        raise AssertionError(
            "language_pair must contain only 2 elements: src and tgt language respectively"
        )

    # ``SET_NOT_EXISTS`` is keyed by tuples; a list (accepted above) would be
    # unhashable and crash the lookup, so normalise once here.
    language_pair = tuple(language_pair)
    src_language, tgt_language = language_pair

    if src_language not in SUPPORTED_DATASETS["language_pair"]:
        raise ValueError(
            f"src_language '{src_language}' is not valid. Supported source languages are "
            f"{list(SUPPORTED_DATASETS['language_pair'])}"
        )
    if tgt_language not in SUPPORTED_DATASETS["language_pair"][src_language]:
        raise ValueError(
            f"tgt_language '{tgt_language}' is not valid for give src_language "
            f"'{src_language}'. Supported target language are "
            f"{SUPPORTED_DATASETS['language_pair'][src_language]}"
        )

    # Evaluation sets that exist for this pair: every known set minus the
    # ones listed as missing for the pair.
    missing_sets = SET_NOT_EXISTS[language_pair]
    support = [s for s in SUPPORTED_DATASETS["valid_test"] if s not in missing_sets]
    if valid_set not in support:
        raise ValueError(
            f"valid_set '{valid_set}' is not valid for given language pair "
            f"{language_pair}. Supported validation sets are {support}"
        )
    if test_set not in support:
        raise ValueError(
            f"test_set '{test_set}' is not valid for give language pair "
            f"{language_pair}. Supported test sets are {support}"
        )

    if isinstance(split, str):
        split = split.split()

    # ---- download and extraction ---------------------------------------
    if root == DEFAULT_ROOT:
        cache_dir = os.path.join(root, "datasets", "IWSLT2016")
    else:
        cache_dir = root

    file_path, _ = cache_file(
        None,
        cache_dir=cache_dir,
        url=URL,
        md5sum=MD5,
        download_file_name="2016-01.tgz",
        proxies=proxies,
    )
    # The first name returned by ``untar`` is used as the dataset directory.
    dataset_dir_name = untar(file_path, os.path.dirname(file_path))[0]
    dataset_dir_path = os.path.join(cache_dir, dataset_dir_name)

    # The outer archive holds one nested archive per language pair at
    # texts/<src>/<tgt>/<src>-<tgt>.tgz; unpack the requested one in place.
    pair_archive_name = src_language + '-' + tgt_language + '.tgz'
    pair_archive_path = os.path.join(
        dataset_dir_path, 'texts', src_language, tgt_language, pair_archive_name
    )
    untar(pair_archive_path, os.path.dirname(pair_archive_path))

    # ---- dataset construction ------------------------------------------
    datasets_list = []
    for usage in split:
        # Only the evaluation splits take a set selector.
        selector = {}
        if usage == 'valid':
            selector['valid_set'] = valid_set
        elif usage == 'test':
            selector['test_set'] = test_set
        datasets_list.append(
            IWSLT2016Dataset(
                dataset_dir=dataset_dir_path,
                usage=usage,
                # Forward the requested pair; it was previously omitted, which
                # left the dataset on its own default pair.
                language_pair=language_pair,
                shuffle=False,
                **selector,
            )
        )

    if len(datasets_list) == 1:
        return datasets_list[0]
    return datasets_list