From e5f0a3026270994e6788794ac2b46248fc907387 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <amuoz@ed.ac.uk> Date: Wed, 13 Jul 2022 15:53:20 +0100 Subject: [PATCH] refactor(all): black and flake8 compliant --- abc.py | 25 +++++++++++++------------ io/bridge.py | 25 ++++++++----------------- io/cells.py | 30 ++++++++++++------------------ io/reader.py | 15 ++++++++++----- io/signal.py | 14 +++----------- io/utils.py | 17 ++++++++--------- io/writer.py | 25 ++++++++++++++++--------- utils/example.py | 3 ++- 8 files changed, 72 insertions(+), 82 deletions(-) diff --git a/abc.py b/abc.py index 3c20f39c..03290383 100644 --- a/abc.py +++ b/abc.py @@ -15,10 +15,12 @@ class ParametersABC(ABC): """ def __init__(self, **kwargs): - ''' + """ Defines parameters as attributes - ''' - assert "parameters" not in kwargs, "No attribute should be named parameters" + """ + assert ( + "parameters" not in kwargs + ), "No attribute should be named parameters" for k, v in kwargs.items(): setattr(self, k, v) @@ -38,7 +40,9 @@ class ParametersABC(ABC): ] ): return { - k: v.to_dict() if hasattr(v, "to_dict") else self.to_dict(v) + k: v.to_dict() + if hasattr(v, "to_dict") + else self.to_dict(v) for k, v in iterable.items() } else: @@ -76,16 +80,12 @@ class ParametersABC(ABC): Returns instance from a yaml filename or stdin """ is_buffer = True - try: - if Path(source).exists(): - is_buffer = False - except: - pass + if Path(source).exists(): + is_buffer = False if is_buffer: params = safe_load(source) - else: - with open(source) as f: - params = safe_load(f) + with open(source) as f: + params = safe_load(f) return cls(**params) @classmethod @@ -95,6 +95,7 @@ class ParametersABC(ABC): overriden_defaults[k] = v return cls.from_dict(overriden_defaults) + ### diff --git a/io/bridge.py b/io/bridge.py index c8b0f530..bf1c626a 100644 --- a/io/bridge.py +++ b/io/bridge.py @@ -29,24 +29,10 @@ class BridgeH5: def close(self): self._hdf.close() - def max_ncellpairs(self, nstepsback): - """ - Get maximum number of cell pairs to be calculated - """ - - dset = self._hdf["cell_info"][()] - # attrs = self._hdf[dataset].attrs - pass - @property def cell_tree(self): return self.get_info_tree() - def get_n_cellpairs(self, nstepsback=2): - cell_tree = self.cell_tree - # get pair of consecutive trap-time points - pass - @staticmethod def get_consecutives(tree, nstepsback): # Receives a sorted tree and returns the keys of consecutive elements @@ -83,7 +69,8 @@ class BridgeH5: npairs = [] for t in self._hdf["cell_info"]["processed_timepoints"][()]: tmp_tree = { - k: {k2: v2 for k2, v2 in v.items() if k2 <= t} for k, v in tree.items() + k: {k2: v2 for k2, v2 in v.items() if k2 <= t} + for k, v in tree.items() } npairs.append(self.get_npairs(tree=tmp_tree)) @@ -122,14 +109,18 @@ def groupsort(iterable: Union[tuple, list]): # Sorts iterable and returns a dictionary where the values are grouped by the first element. iterable = sorted(iterable, key=lambda x: x[0]) - grouped = {k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0])} + grouped = { + k: [x[1:] for x in v] for k, v in groupby(iterable, lambda x: x[0]) + } return grouped def recursive_groupsort(iterable): # Recursive extension of groupsort if len(iterable[0]) > 1: - return {k: recursive_groupsort(v) for k, v in groupsort(iterable).items()} + return { + k: recursive_groupsort(v) for k, v in groupsort(iterable).items() + } else: # Only two elements in list return [x[0] for x in iterable] diff --git a/io/cells.py b/io/cells.py index aee65562..45ce385f 100644 --- a/io/cells.py +++ b/io/cells.py @@ -69,14 +69,18 @@ class CellsHDF(Cells): @property def max_labels(self) -> t.List[int]: - with h5py.File(self.filename, mode="r") as f: - return [max(self.labels_in_trap(i)) for i in range(self.ntraps)] + return [max(self.labels_in_trap(i)) for i in range(self.ntraps)] @property def ntraps(self) -> int: with h5py.File(self.filename, mode="r") as f: return len(f["/trap_info/trap_locations"][()]) + @property + def tinterval(self): + with h5py.File(self.filename, mode="r") as f: + return f.attrs["time_settings/timeinterval"] + @property def traps(self) -> t.List[int]: return list(set(self["trap"])) @@ -306,28 +310,18 @@ class CellsLinear(CellsHDF): """ Return nested list with final prediction of mother id for each cell """ - with h5py.File(self.filename, "r") as f: - - return self.mother_assign_from_dynamic( - self["mother_assign_dynamic"], - self["cell_label"], - self["trap"], - self.ntraps, - ) + return self.mother_assign_from_dynamic( + self["mother_assign_dynamic"], + self["cell_label"], + self["trap"], + self.ntraps, + ) @property def mothers_daughters(self): nested_massign = self.mothers if sum([x for y in nested_massign for x in y]): - - idx = set( - [ - (tid, i + 1) - for tid, x in enumerate(nested_massign) - for i in range(len(x)) - ] - ) mothers, daughters = zip( *[ ((tid, m), (tid, d)) diff --git a/io/reader.py b/io/reader.py index bbd7631d..a8d5d369 100644 --- a/io/reader.py +++ b/io/reader.py @@ -67,13 +67,13 @@ class StateReader(DynamicReader): def read_all(self): self.raw_data = { - key: self.read_raw(key, dtype) for key, (_, dtype) in self.datatypes.items() + key: self.read_raw(key, dtype) + for key, (_, dtype) in self.datatypes.items() } return self.raw_data def reconstruct_states(self, data: dict): - ntraps = len(data["max_lbl"]) ntps_back = max(data["tp_back"]) + 1 from copy import copy @@ -88,14 +88,19 @@ class StateReader(DynamicReader): states[k][val_name] = [[] for _ in range(ntps_back)] else: states[k][val_name] = [ - np.zeros((0, data[val_name].shape[1]), dtype=np.float64) + np.zeros( + (0, data[val_name].shape[1]), dtype=np.float64 + ) for _ in range(ntps_back) ] - data[val_name] = list(zip(trap_as_idx, tpback_as_idx, data[val_name])) + data[val_name] = list( + zip(trap_as_idx, tpback_as_idx, data[val_name]) + ) for k, v in groupsort(data[val_name]).items(): states[k][val_name] = [ - np.array([w[0] for w in val]) for val in groupsort(v).values() + np.array([w[0] for w in val]) + for val in groupsort(v).values() ] for val_name in ("lifetime", "p_was_bud", "p_is_mother"): diff --git a/io/signal.py b/io/signal.py index b486737d..77b27e5e 100644 --- a/io/signal.py +++ b/io/signal.py @@ -34,11 +34,9 @@ class Signal(BridgeH5): assert sum(is_bgd) == 0 or sum(is_bgd) == len( dsets ), "Trap data and cell data can't be mixed" - with h5py.File(self.filename, "r") as f: - return [ - self.add_name(self.apply_prepost(dset), dset) - for dset in dsets - ] + return [ + self.add_name(self.apply_prepost(dset), dset) for dset in dsets + ] # return self.cols_in_mins(self.add_name(df, dsets)) return self.add_name(df, dsets) @@ -183,12 +181,6 @@ class Signal(BridgeH5): def n_merges(self): print("{} merge events".format(len(self.merges))) - @property - def merges(self): - with h5py.File(self.filename, "r") as f: - dsets = f.visititems(self._if_merges) - return dsets - @property def picks(self): with h5py.File(self.filename, "r") as f: diff --git a/io/utils.py b/io/utils.py index 25413613..83a70c6d 100644 --- a/io/utils.py +++ b/io/utils.py @@ -4,11 +4,14 @@ Utility functions and classes import itertools import logging import operator +from functools import partial, wraps from pathlib import Path +from time import perf_counter from typing import Callable +import typing as t -import h5py import cv2 +import h5py import numpy as np @@ -89,9 +92,10 @@ class Cache: self._queue.clear() -def accumulate(l: list): - l = sorted(l) - it = itertools.groupby(l, operator.itemgetter(0)) +def accumulate(list_: list) -> t.Generator: + """Accumulate list based on the first value""" + list_ = sorted(list_) + it = itertools.groupby(list_, operator.itemgetter(0)) for key, sub_iter in it: yield key, [x[1] for x in sub_iter] @@ -125,11 +129,6 @@ def parametrized(dec): return layer -from functools import wraps, partial -from time import perf_counter -import logging - - @parametrized def timed(f, name=None): @wraps(f) diff --git a/io/writer.py b/io/writer.py index 19cf7e62..c53ab8a0 100644 --- a/io/writer.py +++ b/io/writer.py @@ -37,8 +37,10 @@ class DynamicWriter: """Append data to existing dataset.""" try: n = len(data) - except: - # Attributes have no length + except Exception as e: + logging.debug( + "DynamicWriter:Attributes have no length: {}".format(e) + ) n = 1 if key not in hgroup: # TODO Include sparsity check @@ -62,9 +64,11 @@ class DynamicWriter: dset = hgroup[key] dset.resize(dset.shape[0] + n, axis=0) dset[-n:] = data - except: + except Exception as e: logging.debug( - "DynamicWriter:Inconsistency between dataset shape and new empty data" + "DynamicWriter:Inconsistency between dataset shape and new empty data: {}".format( + e + ) ) return @@ -228,7 +232,7 @@ class BabyWriter(DynamicWriter): save_complex(current_indices, ix_dset) def __append_edgemasks(self, hgroup, edgemasks, current_indices): - key = "edgemasks" + # key = "edgemasks" val_dset = hgroup["values"] ix_dset = hgroup["indices"] existing_indices = load_complex(ix_dset) @@ -252,7 +256,11 @@ class BabyWriter(DynamicWriter): try: val_dset[ix, n_tps - 1] = mask except Exception as e: - logging.debug(f"{ix}, {n_tps}, {val_dset.shape}") + logging.debug( + "Exception: {}:{}, {}, {}".format( + e, ix, n_tps, val_dset.shape + ) + ) # Save the index values save_complex(missing, ix_dset) @@ -262,7 +270,7 @@ class BabyWriter(DynamicWriter): # DATA is TRAP_IDS, CELL_LABELS, EDGEMASKS in a structured array key = "edgemasks" val_key = "values" - idx_key = "indices" + # idx_key = "indices" # Length of edgemasks traps, cell_labels, edgemasks = data n_cells = len(cell_labels) @@ -503,7 +511,6 @@ class Writer(BridgeH5): data : Iterable, default = None meta : Dict, default = {} - """ self.id_cache = {} with h5py.File(self.filename, "a") as f: @@ -644,7 +651,7 @@ class Writer(BridgeH5): if ( not hasattr(self, "id_cache") - or not df.index.nlevels in self.id_cache + or df.index.nlevels not in self.id_cache ): # Use cache dict to store previously-obtained indices self.id_cache[df.index.nlevels] = {} existing_ids = self.get_existing_ids( diff --git a/utils/example.py b/utils/example.py index 1089ae40..e3ff571a 100644 --- a/utils/example.py +++ b/utils/example.py @@ -48,5 +48,6 @@ def example_function(parameter: Union[int, str]): return ExampleClass(int(parameter)) except ValueError as e: raise ValueError( - f"The parameter {parameter} could not be turned " f"into an integer." + f"The parameter {parameter} could not be turned " + f"into an integer." ) from e -- GitLab