# Source code for sotastream.pipelines.mtdata_pipeline

import logging
from typing import Tuple, Iterator, Union, List, Optional
import random

from sotastream.data import Line
from sotastream.augmentors import Mixer
from sotastream.filters import BitextFilter
from sotastream.pipelines import Pipeline, pipeline

# Shared package-level logger (named "sotastream", not this module's __name__).
logger = logging.getLogger("sotastream")


@pipeline("mtdata")
class MTDataPipeline(Pipeline):
    """Pipeline to mix datasets from mtdata.

    To install mtdata, run `pip install mtdata`, or visit https://github.com/thammegowda/mtdata
    To see the list of available datasets, run `mtdata list -id -l <src>-<tgt>`
    where <src>-<tgt> are language pairs.

    Example #1:
        sotastream mtdata -lp en-de Statmt-news_commentary-16-deu-eng Statmt-europarl-10-deu-eng

    Example #2:
        sotastream mtdata -lp en-de Statmt-news_commentary-16-deu-eng Statmt-europarl-10-deu-eng --mix-weights 1 2

    Example #3:
        sotastream mtdata -lp en-de Statmt-news_commentary-16-deu-eng,Statmt-europarl-10-deu-eng

    Example #1 mixes two datasets with equal weights (i.e., 1:1).
    Example #2 mixes two datasets with 1:2 ratio respectively.
    Example #3 simply concatenates both datasets separated by comma into a single dataset.
    Therefore, the resulting mixture weights are proportional to the number of segments in each dataset.

    The `--langs|-lp <src>-<tgt>` argument is used to enforce compatibility between the specified
    datasets and ensure correct ordering of source and target languages.
    """

    def __init__(
        self,
        data_ids: List[str],
        mix_weights: Optional[List[float]] = None,
        langs: Optional[Tuple[str, str]] = None,
        **kwargs,
    ):
        """Initialize mtdata pipeline.

        :param data_ids: List of mtdata IDs. Each entry may itself be a comma-separated list
            of IDs; those are concatenated into a single (unweighted) source.
        :param mix_weights: Mixture weights, one per entry of ``data_ids``;
            defaults to None (i.e., equal weights)
        :param langs: Tuple of source and target language codes to enforce compatibility with
            specified dataset ids, defaults to None (not enforced)
        """
        if not mix_weights:
            mix_weights = [1.0] * len(data_ids)
        # data_ids are passed to the base class as data_sources explicitly; drop any duplicate
        # value from kwargs (NOTE(review): presumably set by generic CLI handling — confirm).
        kwargs.pop('data_sources', None)
        super().__init__(mix_weights=mix_weights, data_sources=data_ids, **kwargs)
        # One weight is required per data_id entry (not per comma-separated sub-ID).
        assert len(data_ids) == len(self.mix_weights), (
            f'Expected {len(data_ids)} weights, got {len(self.mix_weights)}. See --mix-weights argument'
        )

        random.seed(self.seed)
        if self.num_workers > 1:
            logger.warning('num_workers > 1 is not supported for MTData pipeline.')

        # Comma-separated IDs within one entry are read back-to-back by a single source.
        data_sources = [MTDataSource(data_id.split(','), langs=langs) for data_id in data_ids]
        if len(data_sources) > 1:
            stream = Mixer(data_sources, self.mix_weights)
        else:
            stream = data_sources[0]
        self.stream = BitextFilter(stream)  # removes all but fields 0 and 1

    @classmethod
    def get_data_sources_for_argparse(cls):
        """Describe the positional ``data_ids`` CLI argument (name, help, nargs)."""
        help_msg = '''MTData dataset IDs which are of format Group-name-version-lang1-lang2
        E.g. "Statmt-news_commentary-16-deu-eng"
        Run "mtdata list -id -l <src>-<tgt>" to list all available dataset IDs for any <src>-<tgt> language pair.
        '''
        return [('data_ids', help_msg, '+')]

    @classmethod
    def get_data_sources_default_weights(cls):
        """Return the default-weights spec; the number of sources is only known after CLI parsing."""
        # we don't know how many sources will be provided until runtime CLI parsing
        return ['+']

    @classmethod
    def add_cli_args(cls, parser):
        """Register the base pipeline CLI arguments plus the mandatory ``--langs/-lp`` option.

        :param parser: argparse parser to extend
        """
        import argparse  # needed only to report malformed --langs values cleanly

        super().add_cli_args(parser)

        def LangPair(txt) -> Tuple[str, str]:
            """Parse a ``src-tgt`` language pair from a CLI argument.

            Raises ``argparse.ArgumentTypeError`` (which argparse turns into a usage error)
            rather than AssertionError (which argparse does not catch, yielding a traceback).
            """
            pair = txt.split('-')
            if len(pair) != 2 or not all(pair):
                raise argparse.ArgumentTypeError(f'Expected 2 languages src-tgt, got {txt!r}')
            return tuple(pair)

        parser.add_argument(
            '--langs',
            '-lp',
            required=True,
            metavar='SRC-TGT',
            type=LangPair,
            help='''Source and target language order, e.g. "deu-eng". Ensures the correct order of the fields in the output.
            As per mtdata, language code 'mul' is special and meant for multilingual datasets.
            E.g. "mul-en" is compatible for x->en datasets, whereas "en-mul" is for en->x for any x.''',
        )
def _rec_to_fields(rec, is_swap: bool, delim: str = '\t') -> List[str]:
    """Normalize one mtdata record (sequence of columns or delimited string) into [src, tgt]."""
    cols = list(rec) if isinstance(rec, (list, tuple)) else rec.split(delim)
    # A delimiter or newline inside a field would corrupt the tab-separated, line-based output.
    fields = [col.replace(delim, ' ').replace('\n', ' ').strip() for col in cols]
    assert len(fields) >= 2, f'Expected at least 2 fields, got {len(fields)}'
    fields = fields[:2]
    if is_swap:
        fields.reverse()
    return fields


def MTDataSource(
    dids: Union[str, List[str]],
    langs=None,
    progress_bar=False,
) -> Iterator[Line]:
    """MTData dataset iterator.

    Reads the given datasets one after another and, once all are exhausted, starts again from
    the first one, so the stream does not end on its own.

    :param dids: either a single dataset ID or a list of dataset IDs.
        IDs are of form Group-name-version-lang1-lang2 e.g. "Statmt-news_commentary-16-deu-eng"
    :param langs: source-target language order, either a (src, tgt) pair such as ("deu", "eng")
        or a "src-tgt" string such as "deu-eng"; defaults to None (no compatibility check)
    :param progress_bar: whether to show mtdata's progress bar
    :return: Line objects with exactly two fields (source, target)
    :raises ValueError: if a dataset is not compatible with ``langs``, or if the datasets
        contain no segments at all (which would otherwise make this generator spin forever)
    """
    from mtdata.data import INDEX, Cache, Parser, DatasetId
    from mtdata import cache_dir as CACHE_DIR, pbar_man
    from mtdata.iso.bcp47 import bcp47, BCP47Tag

    pbar_man.enabled = bool(progress_bar)

    if isinstance(dids, str):
        # A bare ID string would otherwise be iterated character by character.
        dids = [dids]

    if langs:  # check compatibility
        if isinstance(langs, str):
            langs = tuple(langs.split('-'))
        assert len(langs) == 2, f'Expected 2 languages, got {langs}'
        langs = (bcp47(langs[0]), bcp47(langs[1]))

    # Resolve every ID up front (this may download into the mtdata cache) so that bad IDs or
    # incompatible language pairs fail before any segment is yielded.
    data_spec = []
    for did in dids:
        did = DatasetId.parse(did)
        assert did in INDEX, f'Unknown dataset ID: {did}'
        is_swap = False
        if langs:
            is_compat, is_swap = BCP47Tag.check_compat_swap(langs, did.langs)
            if not is_compat:
                langs_txt = '-'.join(map(str, langs))
                raise ValueError(f'{did} is not compatible with {langs_txt}.')
        entry = INDEX[did]
        path = Cache(CACHE_DIR).get_entry(entry)
        parser = Parser(path, ext=entry.in_ext or None, ent=entry)
        data_spec.append((parser, is_swap))

    delim = '\t'
    while True:
        count = 0  # segments yielded during this pass over all datasets
        for parser, is_swap in data_spec:
            for rec in parser.read_segs():
                yield Line(fields=_rec_to_fields(rec, is_swap, delim))
                count += 1
        if count == 0:
            raise ValueError(f'No segments found in mtdata datasets {list(dids)}')