Source code for sotastream.cli

#!/usr/bin/env python3

import sys

# This might have to do with functioning on mounted Azure blobs
sys.dont_write_bytecode = True

import argparse
import logging
import json
import os
import time

from collections import defaultdict
from multiprocessing import Pipe, Process
from typing import Type

from . import __version__, Defaults
from .utils.split import split_file_into_chunks
from .pipelines import Pipeline, PIPELINES

# Package-wide logger used by the command-line interface.
# NOTE(review): the logger name is constant; the seed is not part of it, so
# concurrently running instances cannot be told apart by logger name.
logger = logging.getLogger("sotastream")

# Name of the current user, taken from the environment ('nouser' if unset)
USER = os.environ.get('USER', os.environ.get('USERNAME', 'nouser'))


def adjustSeed(seed, local_num_instances, local_instance_rank):
    """
    Derive a per-instance seed for infinibatch.

    The requested seed is hash-combined with the local process count and rank
    and with the MPI world size and rank, so that each instance ends up with
    a different seed.

    :param seed: Base seed. A value of 0 is replaced by the current time in milliseconds.
    :param local_num_instances: Number of local worker processes.
    :param local_instance_rank: Rank of this worker among the local processes.
    :return: The hash of (seed, local count, local rank, MPI size, MPI rank).
    """
    if seed == 0:
        # no explicit seed requested: fall back to the current time in milliseconds
        seed = round(time.time() * 1000)

    # mpirun sets these variables automatically inside an MPI world; outside
    # of one, behave like a world with a single instance of rank 0
    if "OMPI_COMM_WORLD_SIZE" in os.environ:
        mpi_num_instances = int(os.environ["OMPI_COMM_WORLD_SIZE"])
        mpi_instance_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    else:
        mpi_num_instances, mpi_instance_rank = 1, 0

    # hash-combine the seed with the local and MPI coordinates
    components = (seed, local_num_instances, local_instance_rank, mpi_num_instances, mpi_instance_rank)
    hashed_seed = hash(components)

    logger.info(
        f"Computed seed {hashed_seed} from original seed {seed} and instance info: ({local_num_instances}, {local_instance_rank}, {mpi_num_instances}, {mpi_instance_rank})"
    )
    return hashed_seed
def run_pipeline_process(conn, args, seed, worker_id, num_workers):
    """
    Run a pipeline in a single subprocess and stream its output through a pipe.

    Every line produced by the pipeline is converted to a string and collected
    into a batch. A batch is sent over ``conn`` as soon as it holds
    ``min(args.queue_buffer_size, args.buffer_size)`` lines; if the pipeline
    is exhausted, the remaining partial batch is sent as well.

    :param conn: Connection end that batches (lists of strings) are sent to.
        It is always closed before this function returns.
    :param args: Parsed CLI arguments. All attributes except ``pipeline`` and
        ``seed`` are forwarded to the pipeline as keyword arguments.
    :param seed: Seed handed to the pipeline.
    :param worker_id: Index of this worker, exported as ``SOTASTREAM_WORKER_ID``.
    :param num_workers: Total number of workers, exported as ``SOTASTREAM_WORKER_COUNT``.
    """
    kwargs = {k: v for k, v in vars(args).items() if k not in ("pipeline", "seed")}

    # These environment variables are used in the subprocesses to determine which worker they are
    os.environ["SOTASTREAM_WORKER_ID"] = str(worker_id)
    os.environ["SOTASTREAM_WORKER_COUNT"] = str(num_workers)

    # The batch threshold does not change while streaming, so compute it once
    batch_size = min(args.queue_buffer_size, args.buffer_size)

    try:
        # Created inside the try block so that the connection is closed even
        # when constructing the pipeline fails
        pipeline = Pipeline.create(args.pipeline, seed=seed, **kwargs)

        lines = []
        for line in pipeline:
            lines.append(str(line))
            if len(lines) >= batch_size:
                conn.send(lines)
                lines = []

        # flush whatever is left once the pipeline is exhausted
        if lines:
            conn.send(lines)
    finally:
        conn.close()
def add_global_args(parser: argparse.ArgumentParser):
    """
    Register the global command-line options on ``parser``.

    These options appear before the pipeline argument and are available to
    all pipelines.

    :param parser: The parser to add the options to.
    """
    add = parser.add_argument

    # Sample logging
    add("--log-rate", "-lr", metavar="N", type=int, default=0, help="Log every Nth instance (0=off)")
    add(
        "--log-first",
        "-lf",
        metavar="N",
        type=int,
        default=5,
        help="Log first N instances (default: %(default)s)",
    )
    add("--sample-file", type=argparse.FileType("tw"), help="Where to log samples")

    # Buffering, seeding, and parallelism
    add(
        '--buffer-size',
        '-b',
        type=int,
        default=Defaults.BUFFER_SIZE,
        help='Number of lines infinibatch will load into memory',
    )
    add(
        '--queue-buffer-size',
        '-q',
        type=int,
        default=Defaults.QUEUE_BUFFER_SIZE,
        help='Queue buffer size',
    )
    add(
        '--seed',
        '-s',
        type=int,
        default=Defaults.SEED,
        help='Random seed (default 0 uses time for initialization)',
    )
    add(
        '--num-processes',
        '-n',
        type=int,
        default=Defaults.NUM_PROCESSES,
        help='Number of processes to use for better throughput',
    )

    # Miscellaneous
    add('--version', '-V', action='version', version=f'sotastream {__version__}')
    add(
        "--split-tmpdir",
        default=f"/tmp/sotastream-{USER}",
        help="Base temporary directory to use when splitting data files",
    )
    add("--quiet", action="store_true", help="Suppress logging output")
def maybe_split_files(args):
    """Split compressed data files into chunks inside a temporary directory.

    ``args`` is modified in place: every data-source attribute whose value is
    a ``.gz`` file path is replaced with the directory returned by
    ``split_file_into_chunks``. An additional ``data_sources`` attribute with
    the list of data-source paths is set as well.

    Args:
        args: CLI args object from argparse
    """
    # Find the pipeline class and ask it which arguments name data sources
    pipeline_cls: Type['Pipeline'] = PIPELINES[args.pipeline]
    namespace = vars(args)
    source_specs = pipeline_cls.get_data_sources_for_argparse()

    # Pair each data-source argument name with the path given at runtime
    named_paths = [(spec[0], namespace[spec[0]]) for spec in source_specs]

    for name, path in named_paths:
        if not isinstance(path, str):
            logger.warning(f"Skipping {name}={path} because it is {type(path)}, but str expected")
            continue

        # Only .gz files are split; directories (pre-split data) and any
        # other paths are left untouched
        if os.path.isdir(path) or not path.endswith(".gz"):
            continue

        chunk_dir = split_file_into_chunks(path, tmpdir=args.split_tmpdir, split_size=args.buffer_size)
        setattr(args, name, chunk_dir)

    # Inject a keyword argument 'data_sources' that contains all data sources.
    # NOTE(review): these are the paths as collected before the loop, i.e. the
    # original .gz paths rather than the split directories; confirm intended.
    args.data_sources = [path for _, path in named_paths]
def main():
    """
    Command-line entry point.

    Parses the global and per-pipeline arguments, splits data files if needed,
    starts ``--num-processes`` worker processes, and prints the lines they
    produce to standard output in round-robin order. Selected lines are
    additionally written to ``--sample-file`` or logged. When the loop ends
    (for example because the consumer of standard output went away), the
    workers are terminated and a JSON summary is logged.
    """
    stats = defaultdict(int)
    stats['start_time'] = time.time()

    # Get the list of available pipelines
    parser = argparse.ArgumentParser(
        prog='sotastream',
        description='Command line wrapper for augmentation pipelines',
        formatter_class=argparse.RawTextHelpFormatter,
        epilog='''\n\nTo load additional pipelines create (or symlink) *_pipeline.py files from current directory.''',
    )

    add_global_args(parser)

    # Each pipeline is a different subcommand with its own arguments.
    sub_parsers = parser.add_subparsers(
        dest='pipeline',
        required=True,
        metavar="pipeline",
        help="The pipeline to run. Available pipelines:\n- " + "\n- ".join(sorted(PIPELINES.keys())),
    )
    for pipeline_name, pipeline_class in PIPELINES.items():
        # Create a sub-parser and add the pipeline's arguments to it.
        sub_parser = sub_parsers.add_parser(
            pipeline_name, description=pipeline_class.__doc__, formatter_class=argparse.RawTextHelpFormatter
        )
        pipeline_class.add_cli_args(sub_parser)

    args = parser.parse_args()

    logLevel = logging.CRITICAL if args.quiet else logging.INFO
    logging.basicConfig(level=logLevel)

    # Without at least one worker there are no pipes to read from, and the
    # round-robin loop below would spin forever without producing anything.
    if args.num_processes < 1:
        parser.error("--num-processes must be at least 1")

    maybe_split_files(args)

    N = args.num_processes
    pipes = [Pipe() for _ in range(N)]
    processes = [
        Process(target=run_pipeline_process, args=(pipes[i][1], args, adjustSeed(args.seed, N, i), i, N))
        for i in range(N)
    ]
    for p in processes:
        p.start()

    # Everything up to this point counts as start-up overhead
    overhead_time = time.time()

    lineno = 0
    num_fields = defaultdict(int)
    try:
        # round-robin across the pipes forever
        # NOTE(review): this process keeps its own copy of each worker-side
        # connection open, so recv() presumably never reports end-of-stream
        # if a worker finishes; fine for endless pipelines, confirm otherwise.
        while True:
            for pipe in pipes:
                # To avoid pickling (and the associated timing costs), lines
                # are transmitted as strings, not Line objects.
                lines = pipe[0].recv()
                for line in lines:
                    fields = line.split("\t")
                    num_fields[len(fields)] += 1

                    print(line)
                    lineno += 1

                    if (args.log_rate > 0 and lineno % args.log_rate == 0) or lineno <= args.log_first:
                        if args.sample_file:
                            print(line, file=args.sample_file)
                        else:
                            logger.info(f"SAMPLE {lineno}: {line}")

    except BrokenPipeError:
        # this is not really an error, just means that the receiving process has ended
        # Python flushes standard streams on exit; redirect remaining output
        # to devnull to avoid another BrokenPipeError at shutdown
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())

    finally:
        # Looks like the process that we are piping to is done, let's wrap things up
        for p in processes:
            p.terminate()

        stats['end_time'] = time.time()
        stats['lines_produced'] = f'{lineno:,}'
        stats['num_fields'] = num_fields

        total_time = stats['end_time'] - stats['start_time']
        streaming_time = stats['end_time'] - overhead_time

        # Guard against a zero-length interval: a ZeroDivisionError raised in
        # this cleanup block would hide the exception that ended the loop.
        yield_rate = lineno / total_time if total_time > 0 else 0.0
        yield_rate_sans_overhead = lineno / streaming_time if streaming_time > 0 else 0.0

        stats['overhead_time'] = overhead_time - stats['start_time']
        stats['total_time'] = f"{total_time:,.3f} sec"
        stats['yield_rate'] = f"{yield_rate:,.2f} lines/sec"
        stats['yield_rate_sans_overhead'] = f"{yield_rate_sans_overhead:,.2f} lines/sec"

        logger.info('Summary: ' + json.dumps(stats, indent=2))
# Run the command-line interface when this module is executed as a script.
if __name__ == "__main__":
    main()