Skip to content
Snippets Groups Projects
Commit f64117ce authored by Alán Muñoz's avatar Alán Muñoz
Browse files

feat(abc): add update and typecheck methods

parent 8e5c205f
No related branches found
No related tags found
No related merge requests found
import typing as t
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from copy import copy from copy import copy
...@@ -25,8 +26,6 @@ class ParametersABC(ABC): ...@@ -25,8 +26,6 @@ class ParametersABC(ABC):
for k, v in kwargs.items(): for k, v in kwargs.items():
setattr(self, k, v) setattr(self, k, v)
###
def to_dict(self, iterable="null"): def to_dict(self, iterable="null"):
""" """
Recursive function to return a nested dictionary of the Recursive function to return a nested dictionary of the
...@@ -69,8 +68,6 @@ class ParametersABC(ABC): ...@@ -69,8 +68,6 @@ class ParametersABC(ABC):
dump(self.to_dict(), f) dump(self.to_dict(), f)
return dump(self.to_dict()) return dump(self.to_dict())
###
@classmethod @classmethod
def from_dict(cls, d: dict): def from_dict(cls, d: dict):
return cls(**d) return cls(**d)
...@@ -100,8 +97,32 @@ class ParametersABC(ABC): ...@@ -100,8 +97,32 @@ class ParametersABC(ABC):
overriden_defaults[k] = v overriden_defaults[k] = v
return cls.from_dict(overriden_defaults) return cls.from_dict(overriden_defaults)
def update(self, name: str, new_value):
"""
Update values recursively
if name is a dictionary, inject data where found.
It forbids type changes.
"""
assert name not in (
"parameters",
"params",
), "Attribute can't be named params or parameters"
if name in self.__dict__:
assert check_type_recursive(
getattr(self, name), new_value
), "Type changes are not valid"
if isinstance(getattr(self, name), dict):
new_d = getattr(self, name)
new_d.update(new_value)
setattr(self, name, new_d)
### else:
setattr(self, name, new_value)
class ProcessABC(ABC): class ProcessABC(ABC):
...@@ -129,3 +150,23 @@ class ProcessABC(ABC): ...@@ -129,3 +150,23 @@ class ProcessABC(ABC):
@abstractmethod @abstractmethod
def run(self): def run(self):
pass pass
def check_type_recursive(val1, val2):
same_types = True
if not isinstance(val1, type(val2)) and not all(
type(x) in (PosixPath, str) for x in (val1, val2) # Ignore str->path
):
return False
if not isinstance(val1, t.Iterable) and not isinstance(val2, t.Iterable):
return isinstance(val1, type(val2))
elif isinstance(val1, (tuple, list)) and isinstance(val2, (tuple, list)):
return bool(
sum([check_type_recursive(v1, v2) for v1, v2 in zip(val1, val2)])
)
elif isinstance(val1, dict) and isinstance(val2, dict):
if not len(val1) or not len(val2):
return False
for k in val2.keys():
same_types = same_types and check_type_recursive(val1[k], val2[k])
return same_types
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment