from abc import ABC, abstractmethod
from contextlib import ExitStack
from shutil import copyfileobj
import os
from tempfile import mkstemp
from types import TracebackType
from typing import Any, BinaryIO, Callable, cast, Dict, Generic, Iterable, List, Optional, Tuple, Type, TypeVar, Union
from .params import ParamSet, Resolvable, resolve
from .io import ensure_writeable, PathType, PrimitivePathType, ReadOpenable, resolve_writable, Writeable, WriteOpenableWrapBinaryIO
from .utils import consume, listify
# Module-level type variables used to parameterize the collector classes.
_T = TypeVar('_T')
_S = TypeVar('_S')
# Element type accepted by a Collector. It is declared contravariant; the
# reasoning is given in the comment directly below.
_T_contra = TypeVar('_T_contra', contravariant=True)
# TODO: Double-check the variance chosen for _T_contra.
# Does that even matter? No information about individual results is
# extractable from Collector instances and Collector instances are not
# commonly passed as arguments to other functions.
# In that sense, a Collector is a write-only container. Write-only containers
# should be contravariant: a function accepting a Collector for type _T
# elements will happily work with a collector for any super-type of _T.
# The elements written by the function are instances of both _T and any
# super-type of _T.
class Collector(Generic[_T_contra], ABC):
    """Abstract sink that receives and processes per-parameter-set results.

    Every collector can be used as a context manager, but not as a reentrant
    one: it must not be entered again while it is already open. Long-lived
    resources, such as an aggregation file, are acquired only when the
    context is entered. They are released again when it is exited, for
    example by closing any open files.

    This base implementation acquires nothing. Entering yields the collector
    itself, and exiting never suppresses an in-flight exception.
    """

    def __enter__(self) -> 'Collector[_T_contra]':
        # Nothing to acquire at this level; subclasses override as needed.
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_val: Optional[BaseException] = None,
        exc_tb: Optional[TracebackType] = None,
    ) -> Optional[bool]:
        # Returning ``None`` lets the ``with`` statement re-raise any
        # exception that occurred inside the block.
        return None

    @abstractmethod
    def collect(self, results: Iterable[Tuple[ParamSet, _T_contra]]) -> None:
        """Consume and process every ``(params, result)`` pair in ``results``.

        Only call this while the collector's context is open. It is valid to
        never call it, or to call it any number of times.

        Args:
            results: Iterable of pairs, each holding a parameter set and the
                result that was produced for it.
        """
        raise NotImplementedError()
class Noop(Collector[_T_contra], Generic[_T_contra]):
    """Collector that keeps and processes nothing.

    ``collect`` still hands the incoming iterable to ``consume``, which
    presumably exhausts it (verify against ``.utils.consume``). Lazily
    produced results are therefore still evaluated, even though none are
    kept.
    """

    def collect(self, results: Iterable[Tuple[ParamSet, _T_contra]]) -> None:
        """Pass ``results`` to ``consume`` and discard the elements."""
        consume(results)
class Multi(Collector[_T_contra], Generic[_T_contra]):
    """Forwards every result to several collectors.

    Entering a ``Multi`` enters all wrapped collectors in the order they were
    given. Exiting it exits them in reverse order through an ``ExitStack``.

    Args:
        *collectors: The collectors every result is forwarded to.
    """

    def __init__(self, *collectors: Collector[_T_contra]) -> None:
        self._collectors: Tuple[Collector[_T_contra], ...] = collectors
        self._stack: ExitStack = ExitStack()

    @property
    def collectors(self) -> Tuple[Collector[_T_contra], ...]:
        """The wrapped collectors, in the order they were passed."""
        return self._collectors

    def __enter__(self) -> 'Multi[_T_contra]':
        # Enter the collectors on a temporary stack first. If entering any of
        # them raises, the ``with`` block exits the ones that were already
        # entered. This matters because ``__exit__`` is never called when
        # ``__enter__`` fails. Only once all of them succeeded is ownership
        # of their exit callbacks transferred to ``self._stack``.
        with ExitStack() as stack:
            for collector in self._collectors:
                stack.enter_context(collector)
            self._stack = stack.pop_all()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_val: Optional[BaseException] = None,
        exc_tb: Optional[TracebackType] = None,
    ) -> Optional[bool]:
        # Exits every wrapped collector (LIFO). A truthy return means one of
        # them suppressed the exception.
        return self._stack.__exit__(exc_type, exc_val, exc_tb)

    def collect(self, results: Iterable[Tuple[ParamSet, _T_contra]]) -> None:
        """Forward each ``(params, result)`` pair to every wrapped collector.

        Results are forwarded one at a time, so ``results`` is iterated only
        once, even when several collectors are attached.
        """
        for params, result in results:
            for collector in self._collectors:
                collector.collect(((params, result),))
class Demux(Collector[_T_contra], Generic[_T_contra]):
    """Demultiplexes results by distributing them to different collectors.

    Results are grouped by the values their parameter set holds for ``keys``.
    The first time a combination of values is seen, ``factory`` is called to
    create a collector for it. That collector is entered immediately and
    stays open until the ``Demux`` itself is exited.

    Args:
        keys: Keys in the parameter set to consider during demuxing. A
            separate collector is maintained for each distinct combination
            of values of these keys.
        factory: Function called to create a new collector. Its only argument
            is a reduced parameter set containing just the parameters named
            in ``keys``.
    """

    def __init__(
        self,
        keys: Iterable[str],
        factory: 'Callable[[ParamSet], Collector[_T_contra]]',
    ):
        super(Demux, self).__init__()
        self._keys: Tuple[str, ...] = tuple(keys)
        self._factory: Callable[[ParamSet], Collector[_T_contra]] = factory
        self._stack: ExitStack = ExitStack()
        # Maps a tuple of key values (ordered like ``self._keys``) to the
        # open collector responsible for that combination.
        self._collectors: Dict[Tuple[Any, ...], Collector[_T_contra]] = {}

    @property
    def keys(self) -> Tuple[str, ...]:
        """The parameter names that results are demultiplexed on."""
        return self._keys

    def __enter__(self) -> 'Demux[_T_contra]':
        # Sub-collectors are created lazily in ``_route_result``, so nothing
        # is entered up front.
        self._stack.__enter__()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_val: Optional[BaseException] = None,
        exc_tb: Optional[TracebackType] = None,
    ) -> Optional[bool]:
        try:
            return self._stack.__exit__(exc_type, exc_val, exc_tb)
        finally:
            # Every sub-collector has been exited at this point. Forget them,
            # so that a later use of this Demux creates (and enters) fresh
            # ones instead of routing results into closed collectors.
            self._collectors.clear()

    def collect(self, results: Iterable[Tuple[ParamSet, _T_contra]]) -> None:
        """Route each ``(params, result)`` pair to the collector for its keys."""
        for params, result in results:
            self._route_result(params, result)

    def _route_result(self, params: ParamSet, result: _T) -> None:
        """Forward one result to the collector responsible for its key values.

        If this combination of key values has not been seen yet, a new
        collector is created via ``factory``, entered, and only then cached.

        Note:
            ``result`` is annotated with the free type variable ``_T`` and
            cast to ``_T_contra``. The cast only affects static typing; the
            value is passed through unchanged.
        """
        typed_result = cast(_T_contra, result)
        values = tuple(params[key] for key in self._keys)
        collector = self._collectors.get(values)
        if collector is None:
            collector = self._factory({key: params[key] for key in self._keys})
            # Enter before caching. If entering raises, no half-initialised
            # collector is left behind to receive later results.
            self._stack.enter_context(collector)
            self._collectors[values] = collector
        collector.collect(((params, typed_result),))
class Convert(Collector[_T_contra], Generic[_T_contra, _S]):
    """Applies a conversion function to each result before passing it on.

    Args:
        f: Function mapping each incoming result to the value handed on to
            ``collector``. The parameter set is forwarded unchanged.
        collector: Downstream collector that receives the converted results.
            Its context is entered and exited together with this one.
    """

    def __init__(self, f: Callable[[_T_contra], _S], collector: Collector[_S]) -> None:
        self._f: Callable[[_T_contra], _S] = f
        self._collector: Collector[_S] = collector

    @property
    def collector(self) -> Collector[_S]:
        """The wrapped downstream collector."""
        return self._collector

    def __enter__(self) -> 'Convert[_T_contra, _S]':
        downstream = self._collector
        downstream.__enter__()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_val: Optional[BaseException] = None,
        exc_tb: Optional[TracebackType] = None,
    ) -> Optional[bool]:
        downstream = self._collector
        return downstream.__exit__(exc_type, exc_val, exc_tb)

    def collect(self, results: Iterable[Tuple[ParamSet, _T_contra]]) -> None:
        """Lazily convert each result and feed the pairs to the downstream collector."""
        converted = ((pair[0], self._f(pair[1])) for pair in results)
        self._collector.collect(converted)
class Concatenate(Collector[ReadOpenable]):
    """Concatenates file-like openables into a new file.

    The aggregate file is opened in ``'wb'`` mode, which truncates existing
    content, when the context is entered. It is closed when the context is
    exited.

    Args:
        path: The file path to be opened as the aggregate file. If ``None`` a
            temporary file is created (which is not deleted).
        before_all: Optional content written once, right after the aggregate
            file is opened.
        after_all: Optional content written once, right before the aggregate
            file is closed.
        before: Optional content resolved against each result's parameter set
            and written before that result's bytes.
        after: Optional content resolved against each result's parameter set
            and written after that result's bytes.
    """

    def __init__(
        self,
        path: Optional[PathType] = None,
        before_all: Optional[Union[Writeable, str, bytes]] = None,
        after_all: Optional[Union[Writeable, str, bytes]] = None,
        before: Optional[Resolvable[Union[Writeable, str, bytes]]] = None,
        after: Optional[Resolvable[Union[Writeable, str, bytes]]] = None,
    ):
        super(Concatenate, self).__init__()
        self._before_all: Optional[Union[Writeable, str, bytes]] = before_all
        self._after_all: Optional[Union[Writeable, str, bytes]] = after_all
        self._before: Optional[Resolvable[Union[Writeable, str, bytes]]] = before
        self._after: Optional[Resolvable[Union[Writeable, str, bytes]]] = after
        self._aggregate_path: PrimitivePathType
        if path is not None:
            self._aggregate_path = os.fspath(path)
        else:
            # Only the unique file name is needed here; the file is reopened
            # (and truncated) in ``__enter__``.
            fd, self._aggregate_path = mkstemp()
            os.close(fd)
        # Set only while the context is open.
        self._aggregate_file: Optional[BinaryIO] = None

    @property
    def path(self) -> Union[str, bytes]:
        """Path of the aggregate file."""
        return self._aggregate_path

    def __enter__(self) -> 'Concatenate':
        if self._aggregate_file is not None:
            raise Exception('Wrong usage. _aggregate_file is set but is expected to not be.')
        aggregate_file = open(self._aggregate_path, 'wb')
        try:
            if self._before_all is not None:
                openable = WriteOpenableWrapBinaryIO(aggregate_file)
                ensure_writeable(self._before_all).write_to(openable)
        except BaseException:
            # ``__exit__`` is not called when ``__enter__`` raises. Close the
            # file here so it does not leak and the instance stays usable.
            aggregate_file.close()
            raise
        self._aggregate_file = aggregate_file
        return self

    def __exit__(self, *args: Any) -> Optional[bool]:
        aggregate_file = self._aggregate_file
        if aggregate_file is None:
            raise Exception('Wrong usage. _aggregate_file is not set but is expected to be.')
        try:
            if self._after_all is not None:
                openable = WriteOpenableWrapBinaryIO(aggregate_file)
                ensure_writeable(self._after_all).write_to(openable)
        finally:
            # Always release the file, even if writing ``after_all`` failed.
            self._aggregate_file = None
            aggregate_file.close()
        return None

    def collect(self, results: Iterable[Tuple[ParamSet, ReadOpenable]]) -> None:
        """Append each result's bytes to the aggregate file.

        Each result is wrapped in the optional per-result ``before`` and
        ``after`` content, which is resolved against that result's parameter
        set.

        Raises:
            Exception: If called while the context is not open.
        """
        aggregate_file = self._aggregate_file
        if aggregate_file is None:
            raise Exception('Wrong usage. _aggregate_file is not set but is expected to be.')
        for params, result in results:
            if self._before is not None:
                openable = WriteOpenableWrapBinaryIO(aggregate_file)
                resolve_writable(self._before, params).write_to(openable)
            with result.open_bytes() as result_file:
                copyfileobj(result_file, aggregate_file)
            if self._after is not None:
                openable = WriteOpenableWrapBinaryIO(aggregate_file)
                resolve_writable(self._after, params).write_to(openable)