From f64117ce357bfec51666e2b738fb9dc26b3ea8c5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk>
Date: Wed, 3 Aug 2022 17:46:02 +0100
Subject: [PATCH] feat(abc): add update and typecheck methods

---
 abc.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 46 insertions(+), 5 deletions(-)

diff --git a/abc.py b/abc.py
index ef807be9..339378dd 100644
--- a/abc.py
+++ b/abc.py
@@ -1,3 +1,4 @@
+import typing as t
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from copy import copy
@@ -25,8 +26,6 @@ class ParametersABC(ABC):
         for k, v in kwargs.items():
             setattr(self, k, v)
 
-    ###
-
     def to_dict(self, iterable="null"):
         """
         Recursive function to return a nested dictionary of the
@@ -69,8 +68,6 @@ class ParametersABC(ABC):
                 dump(self.to_dict(), f)
         return dump(self.to_dict())
 
-    ###
-
     @classmethod
     def from_dict(cls, d: dict):
         return cls(**d)
@@ -100,8 +97,32 @@ class ParametersABC(ABC):
             overriden_defaults[k] = v
         return cls.from_dict(overriden_defaults)
 
+    def update(self, name: str, new_value):
+        """
+        Update values recursively
+        if name is a dictionary, inject data where found.
+        It forbids type changes.
+
+
+        """
+
+        assert name not in (
+            "parameters",
+            "params",
+        ), "Attribute can't be named params or parameters"
+
+        if name in self.__dict__:
+            assert check_type_recursive(
+                getattr(self, name), new_value
+            ), "Type changes are not valid"
+
+            if isinstance(getattr(self, name), dict):
+                new_d = getattr(self, name)
+                new_d.update(new_value)
+                setattr(self, name, new_d)
 
-###
+        else:
+            setattr(self, name, new_value)
 
 
 class ProcessABC(ABC):
@@ -129,3 +150,23 @@ class ProcessABC(ABC):
     @abstractmethod
     def run(self):
         pass
+
+
+def check_type_recursive(val1, val2):
+    same_types = True
+    if not isinstance(val1, type(val2)) and not all(
+        type(x) in (PosixPath, str) for x in (val1, val2)  # Ignore str->path
+    ):
+        return False
+    if not isinstance(val1, t.Iterable) and not isinstance(val2, t.Iterable):
+        return isinstance(val1, type(val2))
+    elif isinstance(val1, (tuple, list)) and isinstance(val2, (tuple, list)):
+        return bool(
+            sum([check_type_recursive(v1, v2) for v1, v2 in zip(val1, val2)])
+        )
+    elif isinstance(val1, dict) and isinstance(val2, dict):
+        if not len(val1) or not len(val2):
+            return False
+        for k in val2.keys():
+            same_types = same_types and check_type_recursive(val1[k], val2[k])
+    return same_types
-- 
GitLab