# Skip to content
# Snippets Groups Projects
# Forked from Swain Lab / aliby / aliby-mirror
# 55 commits ahead of the upstream repository.
# abc.py 7.44 KiB
import logging
import typing as t
from abc import ABC, abstractmethod
from collections.abc import Iterable
from copy import copy
from pathlib import Path
from time import perf_counter
from typing import Union

from flatten_dict import flatten
from yaml import dump, safe_load

from agora.logging import timer

# Scalar ("atomic") value types, as opposed to collections; used in the
# annotation of add_to_collection to describe a single element.
atomic = t.Union[int, float, str, bool]


class ParametersABC(ABC):
    """
    Define parameters as attributes and allow parameters to
    be converted to either a dictionary or to yaml.

    No attribute should be called "parameters"! ProcessABC copies every
    parameter onto itself with setattr and exposes a read-only
    ``parameters`` property, so a parameter with that name would collide.
    """

    def __init__(self, **kwargs):
        """
        Define parameters as attributes.

        Parameters
        ----------
        **kwargs
            Each keyword becomes an instance attribute of the same name.

        Raises
        ------
        AssertionError
            If a keyword is called "parameters".
        """
        assert (
            "parameters" not in kwargs
        ), "No attribute should be named parameters"
        for k, v in kwargs.items():
            setattr(self, k, v)

    def to_dict(self, iterable="null") -> t.Dict:
        """
        Return a nested dictionary of the attributes of the class instance.

        Uses recursion.

        Parameters
        ----------
        iterable : optional
            Object to convert. The default, the string "null", means
            "convert this instance's own attributes". A dictionary is
            converted recursively; anything else is returned unchanged.

        Returns
        -------
        The converted dictionary. A dictionary that holds neither iterable
        values nor values with a ``to_dict`` method is returned as the same
        object rather than a copy (for a flat instance that is its own
        ``__dict__``).
        """
        # Resolve the "null" sentinel only here, at the entry point, so that
        # an attribute whose *value* is the string "null" is treated as
        # ordinary data instead of restarting the conversion endlessly.
        if isinstance(iterable, str) and iterable == "null":
            # use instance's built-in __dict__ dictionary of attributes
            iterable = self.__dict__
        return self._nested_to_dict(iterable)

    def _nested_to_dict(self, value):
        """Convert nested dictionaries and ``to_dict``-capable values."""
        if not isinstance(value, dict):
            return value
        needs_conversion = any(
            isinstance(x, Iterable) or hasattr(x, "to_dict")
            for x in value.values()
        )
        if not needs_conversion:
            return value
        return {
            k: v.to_dict() if hasattr(v, "to_dict") else self._nested_to_dict(v)
            for k, v in value.items()
        }

    def to_yaml(self, path: Union[Path, str] = None):
        """
        Return a yaml stream of the attributes of the class instance.

        If path is provided, the yaml stream is saved there.

        Parameters
        ----------
        path : Union[Path, str]
            Output path.

        Returns
        -------
        The yaml text produced by ``dump`` for ``self.to_dict()``.
        """
        # Serialise once and reuse the text for both the file and the
        # return value.
        stream = dump(self.to_dict())
        if path:
            with open(Path(path), "w") as f:
                f.write(stream)
        return stream

    @classmethod
    def from_dict(cls, d: dict):
        """Return an instance whose attributes are the items of ``d``."""
        return cls(**d)

    @classmethod
    def from_yaml(cls, source: Union[Path, str]):
        """
        Return instance from a yaml filename or a string of yaml text.

        Parameters
        ----------
        source : Union[Path, str]
            If ``source`` names an existing path, that file is parsed;
            otherwise ``source`` itself is parsed as yaml.

        Raises
        ------
        AssertionError
            If checking the path fails and ``source`` is not a string.
        """
        is_buffer = True
        try:
            if Path(source).exists():
                is_buffer = False
        except Exception:
            # The path check itself failed, so the source can only be
            # usable as yaml text, which requires a string.
            assert isinstance(source, str), "Invalid source type."

        if is_buffer:
            params = safe_load(source)
        else:
            with open(source) as f:
                params = safe_load(f)
        return cls(**params)

    @classmethod
    def default(cls, **kwargs):
        """
        Return an instance built from ``cls._defaults`` with overrides.

        ``_defaults`` is not defined in this class; presumably subclasses
        provide it as a dictionary. It is copied shallowly, so the class
        attribute itself is not modified by the overrides.

        Parameters
        ----------
        **kwargs
            Values that replace or extend the defaults.
        """
        overriden_defaults = copy(cls._defaults)
        for k, v in kwargs.items():
            overriden_defaults[k] = v
        return cls.from_dict(overriden_defaults)

    def update(self, name: str, new_value):
        """
        Update the attribute ``name`` with ``new_value``.

        - If the attribute does not exist yet, it is created.
        - If it is a list or a set, the new element(s) are added in place.
        - If it is a dictionary, ``new_value`` is merged into it with
          ``dict.update`` (existing keys replaced, new keys added; nested
          dictionaries are replaced, not merged).
        - Otherwise the value is replaced.

        A warning is printed when ``check_type_recursive`` reports that the
        new value does not have the same types as the existing one.

        Parameters
        ----------
        name : str
            Name of the attribute to update.
        new_value
            New value, element(s) to add, or dictionary to merge.

        Raises
        ------
        AssertionError
            If ``name`` is "parameters" or "params".
        """

        assert name not in (
            "parameters",
            "params",
        ), "Attribute can't be named params or parameters"

        if name not in self.__dict__:
            setattr(self, name, new_value)
            return

        current = getattr(self, name)

        if isinstance(current, (list, set)):
            # Growing a collection does not change the attribute's type,
            # so no type warning is needed here.
            add_to_collection(current, new_value)
            return

        # Warn when the types differ, not when they match.
        if not check_type_recursive(current, new_value):
            print("Warnings:Type changes are risky")

        if isinstance(current, dict):
            current.update(new_value)
        else:
            setattr(self, name, new_value)


def add_to_collection(
    collection: t.Collection, value: "t.Union[atomic, t.Collection]"
):
    """
    Add element(s) to a list or a set in place.

    Parameters
    ----------
    collection : t.Collection
        Target to extend. Only lists and sets are modified; any other
        type is left untouched.
    value : atomic or collection
        A single element, or a collection whose elements are all added.
        A string counts as a single element, not as a collection of
        characters.
    """
    # str satisfies t.Collection, so it must be singled out explicitly or
    # it would be added character by character.
    if isinstance(value, str) or not isinstance(value, t.Collection):
        value = [value]
    if isinstance(collection, list):
        collection.extend(value)
    elif isinstance(collection, set):
        collection.update(value)


class ProcessABC(ABC):
    """
    Abstract base for processes configured by a parameters object.

    Each entry of the parameters' dictionary is exposed as an instance
    attribute, and subclasses have to implement ``run``.
    """

    def __init__(self, parameters):
        """
        Store the parameters and mirror them as attributes.

        Arguments
        ---------
        parameters: instance of ParametersABC
        """
        self._parameters = parameters
        # expose every top-level parameter as an attribute of the process
        as_dict = parameters.to_dict()
        for attr_name, attr_value in as_dict.items():
            setattr(self, attr_name, attr_value)

    @property
    def parameters(self):
        """Parameters object this process was created with (read-only)."""
        return self._parameters

    @abstractmethod
    def run(self):
        """Execute the process; to be provided by subclasses."""
        pass

    def _log(self, message: str, level: str = "warning"):
        """Send ``message``, prefixed by the class name, to the "aliby" logger.

        ``level`` is the name of the logger method to call, e.g. "warning".
        """
        emit = getattr(logging.getLogger("aliby"), level)
        emit(f"{type(self).__name__}: {message}")


def check_type_recursive(val1, val2):
    """
    Check whether ``val2`` has the same types as ``val1``, recursively.

    Parameters
    ----------
    val1
        Reference (existing) value.
    val2
        Candidate (new) value.

    Returns
    -------
    bool
        True if ``val1`` is an instance of the type of ``val2`` and, for
        lists/tuples and dictionaries, all compared elements match too.
        Strings and paths are treated as interchangeable. Sequences are
        compared pairwise up to the shorter length; dictionaries are
        compared on the keys of ``val2`` that also exist in ``val1``.
        An empty sequence or dictionary on either side gives False, as
        there is nothing to confirm the match with.
    """
    # Ignore str->path: concrete paths are PosixPath/WindowsPath, so this
    # needs isinstance rather than an exact type comparison.
    both_path_like = isinstance(val1, (Path, str)) and isinstance(
        val2, (Path, str)
    )
    if not isinstance(val1, type(val2)) and not both_path_like:
        return False

    if isinstance(val1, (tuple, list)) and isinstance(val2, (tuple, list)):
        pairs = list(zip(val1, val2))
        # every pair has to match, not just one of them
        return bool(pairs) and all(
            check_type_recursive(v1, v2) for v1, v2 in pairs
        )

    if isinstance(val1, dict) and isinstance(val2, dict):
        if not val1 or not val2:
            return False
        # keys missing from val1 have no existing type to compare against
        return all(
            check_type_recursive(val1[k], v)
            for k, v in val2.items()
            if k in val1
        )

    return True


class StepABC(ProcessABC):
    """
    Base class that expands on ProcessABC to include tools used by Aliby steps.

    Subclasses implement ``_run_tp``; callers use ``run_tp``, which wraps
    it with the ``timer`` decorator. ``run`` is deliberately disabled.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ProcessABC."""
        super().__init__(*args, **kwargs)

    @abstractmethod
    def _run_tp(self, tp: int, **kwargs):
        """
        Do the work of the step; to be implemented by subclasses.

        The signature mirrors how ``run_tp`` invokes this method.

        Parameters
        ----------
        tp : int
            Presumably the index of the time point to process
            (NOTE(review): confirm against the concrete steps).
        **kwargs
            Extra arguments passed through from ``run_tp``.
        """
        pass

    @timer
    def run_tp(self, tp: int, **kwargs):
        """Time and log the timing of a step."""
        return self._run_tp(tp, **kwargs)

    def run(self):
        """
        Always raise, because steps are run through ``run_tp``.

        Raises
        ------
        Warning
            On every call.
        """
        # Replace run with run_tp
        raise Warning("Steps use run_tp instead of run.")