@cached_property
def mothers_daughters(self) -> np.ndarray:
    """
    Return mother-daughter assignments as one 2-D array with three
    columns: trap, mother label and daughter label.

    Cell labels are 1-based (``enumerate(..., 1)``); a mother id of 0
    in ``self.mothers`` means "no mother assigned" and is skipped.

    Returns
    -------
    np.ndarray
        (N, 3) uint16 array of (trap, mother, daughter) rows; an
        empty (0, 3) array when no mothers were assigned.
    """
    nested_massign = self.mothers
    if sum(x for y in nested_massign for x in y):
        mothers_daughters = np.array(
            [
                (tid, m, d)
                for tid, trapcells in enumerate(nested_massign)
                for d, m in enumerate(trapcells, 1)
                if m
            ],
            dtype=np.uint16,
        )
    else:
        # Keep a consistent (0, 3) shape so downstream column indexing
        # (e.g. validate_association's association[:, [0, 1]]) also
        # works on the empty case; shape (0,) would raise there.
        mothers_daughters = np.zeros((0, 3), dtype=np.uint16)
    return mothers_daughters
@property
def datasets(self):
    """
    Print the signals available in the h5 file, caching the scan.

    Walks the file once with ``store_signal_url``, which appends the
    path of every leaf extraction/postprocessing group to
    ``self._available``.
    """
    if not hasattr(self, "_available"):
        self._available = []
        with h5py.File(self.filename, "r") as f:
            f.visititems(self.store_signal_url)
    # Fix: the scan stores into _available; self.siglist no longer
    # exists after the siglist -> available rename in this module.
    for sig in self._available:
        print(sig)
def validate_association(
    association: np.ndarray,
    indices: np.ndarray,
    match_column: t.Optional[int] = None,
) -> t.Tuple[np.ndarray, np.ndarray]:
    """Select association rows whose members are present in `indices`.

    Uses a broadcast comparison for fast multi-indexing, generalising
    from merge events to lineage (mother-daughter) relationships.

    Parameters
    ----------
    association : np.ndarray
        2-D array whose columns are (trap, mother, daughter), or 3-D
        array of shape (N, 2, 2) whose rows are the pairs
        ((trap, mother), (trap, daughter)).
    indices : np.ndarray
        2-D array where each column is a different index level
        (trap, cell label).
    match_column : int, optional
        When given, only that pair of the association must match
        (0 = target/mother, 1 = source/daughter). When None, both
        pairs must be present in `indices` for a row to be valid.

    Returns
    -------
    np.ndarray
        1-D boolean array flagging valid association rows.
    np.ndarray
        1-D boolean array flagging `indices` rows that take part in a
        valid association.

    Examples
    --------
    >>> import numpy as np
    >>> from agora.utils.association import validate_association
    >>> merges = np.arange(12).reshape(3, 2, 2)
    >>> indices = np.arange(6).reshape(3, 2)
    >>> valid_assoc, valid_idx = validate_association(merges, indices)
    >>> print(valid_assoc, valid_idx)
    [ True False False] [ True  True False]
    """
    if association.ndim < 3:
        # Reshape (N, 3) rows (trap, mother, daughter) into (N, 2, 2)
        # pairs ((trap, mother), (trap, daughter)).
        # Fix: the pair axis must be axis=1 so that the LAST axis lines
        # up with the columns of `indices` in the broadcast below;
        # axis=2 produced (row, component, pair) and compared trap
        # values against cell labels.
        association = np.stack(
            (association[:, [0, 1]], association[:, [0, 2]]), axis=1
        )

    # (N, 2, 2, 1) == (1, 1, 2, L) -> (row, pair, column, index-row)
    valid_ndassociation = association[..., None] == indices.T[None, ...]

    # Both trap and cell label must match for a pair/index combination
    valid_cell_ids = valid_ndassociation.all(axis=2)

    if match_column is None:
        # A row is valid only when each of its two pairs (target and
        # source, or mother and daughter) matches some index
        valid_association = valid_cell_ids.any(axis=2).all(axis=1)

        # An index is valid when it appears in any valid row
        valid_indices = (
            valid_ndassociation[valid_association].all(axis=2).any(axis=(0, 1))
        )
    else:
        # Only the requested pair (match_column) needs to be present
        valid_indices = valid_cell_ids[:, match_column].any(axis=0)
        # A row is valid when its match_column pair hits a valid index
        valid_association = (
            valid_cell_ids[:, match_column] & valid_indices
        ).any(axis=1)

    return valid_association, valid_indices
- np.ndarray - 1-D boolean array indicating indices involved in merging. - - Examples - -------- - FIXME: Add docs. - - """ - if merges.ndim < 3: - # Reshape into 3-D array for broadcasting if neded - merges = np.stack((merges[:, [0, 1]], merges[:, [0, 2]]), axis=1) - - # Compare existing merges with available indices - # Swap trap and label axes for the merges array to correctly cast - # valid_ndmerges = merges.swapaxes(1, 2)[..., None] == indices.T[:, None, :] - valid_ndmerges = merges[..., None] == indices.T[None, ...] - - # Broadcasting is confusing (but efficient): - # First we check the dimension across trap and cell id, to ensure both match - valid_cell_ids = valid_ndmerges.all(axis=2) - - # Then we check the merge tuples to check which cases have both target and source - valid_merges = valid_cell_ids.any(axis=2).all(axis=1) - - # Finalle we check the dimension that crosses all indices, to ensure the pair - # is present in a valid merge event. - valid_indices = valid_ndmerges[valid_merges].all(axis=2).any(axis=(0, 1)) - - return valid_merges, valid_indices - - def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray: """ Join two tracks and return the new value of the target. 
def pick_by_lineage(self, signal, how: str):
    """
    Filter lineage rows and signal indices according to `how`.

    Parameters
    ----------
    signal : pd.DataFrame
        Multi-indexed (trap, cell label) dataframe.
    how : str
        One of "mothers", "daughters" or "families"; any other value
        keeps all rows and indices.

    Returns
    -------
    tuple
        (mothers_daughters, idx): the valid (trap, mother, daughter)
        rows and the (trap, cell) index rows that satisfy `how`.
    """
    self.orig_signals = signal

    idx = np.array(signal.index.to_list())
    mothers_daughters = self.cells.mothers_daughters
    valid_indices = valid_lineage = slice(None)

    if how == "mothers":
        # Match on the (trap, mother) pair of each lineage row
        valid_lineage, valid_indices = validate_association(
            mothers_daughters, idx, match_column=0
        )
    elif how == "daughters":
        # Fix: daughters are pair 1, not pair 0 (which is mothers)
        valid_lineage, valid_indices = validate_association(
            mothers_daughters, idx, match_column=1
        )
    elif how == "families":
        # Fix: families require BOTH mother and daughter present,
        # i.e. no match_column restriction
        valid_lineage, valid_indices = validate_association(
            mothers_daughters, idx
        )

    idx = idx[valid_indices]
    mothers_daughters = mothers_daughters[valid_lineage]

    return mothers_daughters, idx
@property
def available(self) -> t.Collection[str]:
    """Return the signals available in the first (sample) signal file."""
    # Fix: was annotated -> None although a value is returned.
    return self.fsignal.available
@property
def available(self) -> None:
    """Print, for every grouper, its available signals and counts."""
    for gpr in self.groupers:
        # Accessing the property triggers its printing side effect;
        # wrapping it in print() would also emit a spurious "None"
        # line per grouper, since the property returns None.
        gpr.available_grouped