Skip to content
Snippets Groups Projects
Commit d7d53c6a authored by alan's avatar alan
Browse files

[WIP] feat(abc): smarter parameter updating

parent 32bc8ffa
No related branches found
No related tags found
No related merge requests found
......@@ -6,6 +6,9 @@ from pathlib import Path, PosixPath
from typing import Union
from yaml import dump, safe_load
from flatten_dict import flatten
atomic = t.Union[int, float, str, bool]
class ParametersABC(ABC):
......@@ -100,10 +103,13 @@ class ParametersABC(ABC):
def update(self, name: str, new_value):
"""
Update values recursively
if name is a dictionary, inject data where found.
It forbids type changes.
if name is a dictionary, replace data where existing found or add if not.
It warns against type changes.
If the existing structure under name is a dictionary,
it looks for the first occurrence and modifies it accordingly.
If a leaf node that is to be changed is a collection, it adds the new elements.
"""
assert name not in (
......@@ -112,19 +118,59 @@ class ParametersABC(ABC):
), "Attribute can't be named params or parameters"
if name in self.__dict__:
assert check_type_recursive(
getattr(self, name), new_value
), "Type changes are not valid"
if check_type_recursive(getattr(self, name), new_value):
print("Warnings:Type changes are risky")
if isinstance(getattr(self, name), dict):
new_d = getattr(self, name)
new_d.update(new_value)
setattr(self, name, new_d)
flattened = flatten(self.to_dict())
names_found = [k for k in flattened.keys() if name in k]
found_idx = [keys.index(name) for keys in names_found]
assert len(names_found), f"{name} not found as key."
keys = None
if len(names_found) > 1:
for level in zip(found_idx, names_found):
if level == min(found_idx):
keys = level
print(
f"Warning: {name} was found in multiple keys. Selected {keys}"
)
break
else:
keys = names_found.pop()
if keys:
current_val = flattened.get(keys, None)
# if isinstance(current_val, t.Collection):
elif isinstance(getattr(self, name), t.Collection):
add_to_collection(getattr(self, name), new_value)
elif isinstance(getattr(self, name), set):
pass # TODO implement
new_d = getattr(self, name)
new_d.update(new_value)
setattr(self, name, new_d)
else:
setattr(self, name, new_value)
def add_to_collection(
collection: t.Collection, value: t.Union[atomic, t.Collection]
):
# Adds element(s) in place.
if not isinstance(value, t.Collection):
value = [value]
if isinstance(collection, list):
collection += value
elif isinstance(collection, set):
collection.update(value)
class ProcessABC(ABC):
"""
Base class for processes.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment