From f64117ce357bfec51666e2b738fb9dc26b3ea8c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk> Date: Wed, 3 Aug 2022 17:46:02 +0100 Subject: [PATCH] feat(abc): add update and typecheck methods --- abc.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 5 deletions(-) diff --git a/abc.py b/abc.py index ef807be9..339378dd 100644 --- a/abc.py +++ b/abc.py @@ -1,3 +1,4 @@ +import typing as t from abc import ABC, abstractmethod from collections.abc import Iterable from copy import copy @@ -25,8 +26,6 @@ class ParametersABC(ABC): for k, v in kwargs.items(): setattr(self, k, v) - ### - def to_dict(self, iterable="null"): """ Recursive function to return a nested dictionary of the @@ -69,8 +68,6 @@ class ParametersABC(ABC): dump(self.to_dict(), f) return dump(self.to_dict()) - ### - @classmethod def from_dict(cls, d: dict): return cls(**d) @@ -100,8 +97,32 @@ class ParametersABC(ABC): overriden_defaults[k] = v return cls.from_dict(overriden_defaults) + def update(self, name: str, new_value): + """ + Update values recursively + if name is a dictionary, inject data where found. + It forbids type changes. + + + """ + + assert name not in ( + "parameters", + "params", + ), "Attribute can't be named params or parameters" + + if name in self.__dict__: + assert check_type_recursive( + getattr(self, name), new_value + ), "Type changes are not valid" + + if isinstance(getattr(self, name), dict): + new_d = getattr(self, name) + new_d.update(new_value) + setattr(self, name, new_d) -### + else: + setattr(self, name, new_value) class ProcessABC(ABC): @@ -129,3 +150,23 @@ class ProcessABC(ABC): @abstractmethod def run(self): pass + + +def check_type_recursive(val1, val2): + same_types = True + if not isinstance(val1, type(val2)) and not all( + type(x) in (PosixPath, str) for x in (val1, val2) # Ignore str->path + ): + return False + if not isinstance(val1, t.Iterable) and not isinstance(val2, t.Iterable): + return isinstance(val1, type(val2)) + elif isinstance(val1, (tuple, list)) and isinstance(val2, (tuple, list)): + return bool( + sum([check_type_recursive(v1, v2) for v1, v2 in zip(val1, val2)]) + ) + elif isinstance(val1, dict) and isinstance(val2, dict): + if not len(val1) or not len(val2): + return False + for k in val2.keys(): + same_types = same_types and check_type_recursive(val1[k], val2[k]) + return same_types -- GitLab