diff --git a/README.md b/README.md index 7ac9f31d83592d7af0e4aa667ca75547ca4a09ce..5b5557465d4372bfaacb245645a6aff2543deb96 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ See our [installation instructions]( https://aliby.readthedocs.io/en/latest/INST ### Raw data access +ALIBY's tooling can also be used as an interface to OMERO servers, taking care of fetching data when needed. ```python from aliby.io.dataset import Dataset from aliby.io.image import Image @@ -96,15 +97,6 @@ trap_tps = riv.get_trap_timepoints(trap_id, trange, ncols) This can take several seconds at the moment. For a speed-up: take fewer z-positions if you can. -If you're not sure what indices to use: -```python -seg_expt.channels # Get a list of channels -channel = 'Brightfield' -ch_id = seg_expt.get_channel_index(channel) - -n_traps = seg_expt.n_traps # Get the number of traps -``` - #### Get the traps for a given time point Alternatively, if you want to get all the traps at a given timepoint: @@ -116,4 +108,4 @@ seg_expt.get_tiles_timepoints(timepoint, tile_size=96, channels=None, ### Contributing -See [CONTRIBUTING.md](./CONTRIBUTING.md) for installation instructions. +See [CONTRIBUTING](https://aliby.readthedocs.io/en/latest/INSTALL.html) on how to help out or get involved. diff --git a/docs/source/index.rst b/docs/source/index.rst index 0f8de1dc587afd80aba1db9a6d6ea538d07533bf..1ce4533b4eae35887939cf7387715fdf4791bc48 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -13,11 +13,12 @@ Reference <api.rst> .. Contributing <CONTRIBUTING.md> - ALIBY reference <_autosummary/aliby> - extraction reference <_autosummary/extraction> - agora reference <_autosummary/agora> - postprocessor reference <_autosummary/postprocessor> - logfile_parser reference <_autosummary/logfile_parser> + .. + ALIBY reference <_autosummary/aliby> + extraction reference <_autosummary/extraction> + agora reference <_autosummary/agora> + postprocessor reference <_autosummary/postprocessor> + logfile_parser reference <_autosummary/logfile_parser> .. 
include:: ../../README.md :parser: myst_parser.sphinx_ diff --git a/poetry.lock b/poetry.lock index aae36e42d9826f372fe644922c32a61eb2f0bb3f..93299975b973235d4feb467f04410868f3dd6c7e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -27,7 +27,7 @@ python-versions = "*" [[package]] name = "aliby-baby" -version = "0.1.14" +version = "0.1.15" description = "Birth Annotator for Budding Yeast" category = "main" optional = false @@ -2815,7 +2815,7 @@ omero = ["omero-py"] [metadata] lock-version = "1.1" python-versions = ">=3.8, <3.11" -content-hash = "1a3d00dd7aa638b2c4da93e14df10aa17610dd4016a458a0cd2218b0d68814cc" +content-hash = "8ad7c71775623d47d8924e97925402bac8aa03eb9e21cd3abb624ee7da29bc3c" [metadata.files] absl-py = [ @@ -2831,8 +2831,8 @@ alabaster = [ {file = "alabaster-0.7.12.tar.gz", hash = "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02"}, ] aliby-baby = [ - {file = "aliby-baby-0.1.14.tar.gz", hash = "sha256:836bd9af27d5d750f440238e17b157a2b0d4586387864bece925d391320005e8"}, - {file = "aliby_baby-0.1.14-py3-none-any.whl", hash = "sha256:79e838e80c429af1cb9577df18ece8145500a07e2416bb3bbe7afd11dd7bc0b1"}, + {file = "aliby-baby-0.1.15.tar.gz", hash = "sha256:3563f1a740e4a33dcb35be9242a3eca4ee4ca8f373590a6b74b5cc4d1ec29f37"}, + {file = "aliby_baby-0.1.15-py3-none-any.whl", hash = "sha256:b2da8553ab6a59c72db846e6d40b73b03248b5af04b931e86026987ea23f7eba"}, ] appdirs = [ {file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"}, diff --git a/pyproject.toml b/pyproject.toml index f4a0759b1839fd63ef030ef46a1a972491280b42..edde345dea2ca79d89c6dadea3495f47c6894564 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aliby" -version = "0.1.43" +version = "0.1.46" description = "Process and analyse live-cell imaging data" authors = ["Alan Munoz <alan.munoz@ed.ac.uk>"] packages = [ @@ -27,7 +27,7 @@ scikit-learn = ">=1.0.2" # Used for an extraction metric scipy = ">=1.7.3" # [tool.poetry.group.pipeline.dependencies] -aliby-baby = "^0.1.14" +aliby-baby = "^0.1.15" dask = "^2021.12.0" imageio = "2.8.0" # For image-visualisation utilities requests-toolbelt = "^0.9.1" diff --git a/src/agora/io/cells.py b/src/agora/io/cells.py index a8ca61b0e13c28b266d5bb2b9bd8fd02384106e3..095b59c4d15448ca30c5a9786fdddcfb9733f95b 100644 --- a/src/agora/io/cells.py +++ b/src/agora/io/cells.py @@ -160,8 +160,11 @@ class Cells: def group_by_traps( self, traps: t.Collection, cell_labels: t.Collection ) -> t.Dict[int, t.List[int]]: - # returns a dict with traps as keys and list of labels as value - # Data is a + """ + Returns a dict with traps as keys and list of labels as value. + Note that the total number of traps are calculated from Cells.traps. 
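A minimal standalone sketch of the grouping idiom that `group_by_traps` applies in the body below, with invented trap ids and labels (it assumes the (trap, label) pairs are already sorted by trap, since `itertools.groupby` only merges consecutive keys):

```python
from itertools import groupby

traps = [0, 0, 1, 3, 3, 3]        # invented, already sorted by trap
cell_labels = [1, 2, 1, 1, 2, 3]
all_traps = range(4)              # stand-in for Cells.traps

iterator = groupby(zip(traps, cell_labels), lambda x: x[0])
d = {key: [x[1] for x in group] for key, group in iterator}
d = {i: d.get(i, []) for i in all_traps}  # traps with no cells get an empty list
print(d)  # {0: [1, 2], 1: [1], 2: [], 3: [1, 2, 3]}
```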
+ + """ iterator = groupby(zip(traps, cell_labels), lambda x: x[0]) d = {key: [x[1] for x in group] for key, group in iterator} d = {i: d.get(i, []) for i in self.traps} diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py index 673b13047b9d864294ae36fd9eebc96b365d2004..7f8af04818c5b8dd29647f806ddadf0935e5c46d 100644 --- a/src/agora/io/signal.py +++ b/src/agora/io/signal.py @@ -2,12 +2,14 @@ import typing as t from copy import copy from pathlib import PosixPath +import bottleneck as bn import h5py import numpy as np import pandas as pd from utils_find_1st import cmp_larger, find_1st from agora.io.bridge import BridgeH5 +from agora.io.decorators import _first_arg_str_to_df class Signal(BridgeH5): @@ -22,7 +24,13 @@ class Signal(BridgeH5): def __init__(self, file: t.Union[str, PosixPath]): super().__init__(file, flag=None) - self.names = ["experiment", "position", "trap"] + self.index_names = ( + "experiment", + "position", + "trap", + "cell_label", + "mother_label", + ) def __getitem__(self, dsets: t.Union[str, t.Collection]): @@ -79,7 +87,7 @@ class Signal(BridgeH5): @staticmethod def get_retained(df, cutoff): - return df.loc[df.notna().sum(axis=1) > df.shape[1] * cutoff] + return df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff] def retained(self, signal, cutoff=0.8): @@ -90,38 +98,68 @@ class Signal(BridgeH5): elif isinstance(df, list): return [self.get_retained(d, cutoff=cutoff) for d in df] - def apply_prepost(self, dataset: str, skip_pick: t.Optional[bool] = None): + def lineage( + self, lineage_location: t.Optional[str] = None, merged: bool = False + ) -> np.ndarray: + """ + Return lineage data from a given location as a matrix where + the first column is the trap id, + the second column is the mother label and + the third column is the daughter label. + """ + if lineage_location is None: + lineage_location = "postprocessing/lineage" + if merged: + lineage_location += "_merged" + + with h5py.File(self.filename, "r") as f: + trap_mo_da = f[lineage_location] + lineage = np.array( + ( + trap_mo_da["trap"], + trap_mo_da["mother_label"], + trap_mo_da["daughter_label"], + ) + ).T + return lineage + + @_first_arg_str_to_df + def apply_prepost( + self, + data: t.Union[str, pd.DataFrame], + merges: np.ndarray = None, + picks: t.Optional[bool] = None, + ): """ Apply modifier operations (picker, merger) to a given dataframe. 
""" - merges = self.get_merges() - df = self.get_raw(dataset) - merged = copy(df) + if merges is None: + merges = self.get_merges() + merged = copy(data) + if merges.any(): # Split in two dfs, one with rows relevant for merging and one # without them - valid_merges = merges[ - ( - merges[:, :, :, None] - == np.array(list(df.index)).T[:, None, :] - ) - .all(axis=(1, 2)) - .any(axis=1) - ] # Casting allows fast multiindexing + valid_merges = validate_merges(merges, np.array(list(data.index))) + # TODO use the same info from validate_merges to select both + valid_indices = [ + tuple(x) + for x in (np.unique(valid_merges.reshape(-1, 2), axis=0)) + ] merged = self.apply_merge( - df.loc[map(tuple, valid_merges.reshape(-1, 2))], + data.loc[valid_indices], valid_merges, ) - nonmergeable_ids = df.index.difference(valid_merges.reshape(-1, 2)) + nonmergeable_ids = data.index.difference(valid_indices) merged = pd.concat( - (merged, df.loc[nonmergeable_ids]), names=df.index.names + (merged, data.loc[nonmergeable_ids]), names=data.index.names ) with h5py.File(self.filename, "r") as f: - if "modifiers/picks" in f and not skip_pick: + if "modifiers/picks" in f and not picks: picks = self.get_picks(names=merged.index.names) # missing_cells = [i for i in picks if tuple(i) not in # set(merged.index)] @@ -132,7 +170,7 @@ class Signal(BridgeH5): [tuple(x) for x in merged.index] ) ] - return merged.loc[picks] + else: if isinstance(merged.index, pd.MultiIndex): empty_lvls = [[] for i in merged.index.names] @@ -187,7 +225,7 @@ class Signal(BridgeH5): @property def n_merges(self): - print("{} merge events".format(len(self.merges))) + return len(self.merges) @property def picks(self): @@ -197,7 +235,6 @@ class Signal(BridgeH5): def apply_merge(self, df, changes): if len(changes): - for target, source in changes: df.loc[tuple(target)] = self.join_tracks_pair( df.loc[tuple(target)], df.loc[tuple(source)] @@ -206,11 +243,11 @@ class Signal(BridgeH5): return df - def get_raw(self, dataset, in_minutes=True): + def get_raw(self, dataset: str, in_minutes: bool = True): try: if isinstance(dataset, str): with h5py.File(self.filename, "r") as f: - df = self.dset_to_df(f, dataset) + df = self.dataset_to_df(f, dataset) if in_minutes: df = self.cols_in_mins(df) return df @@ -238,44 +275,52 @@ class Signal(BridgeH5): else: return None - def dset_to_df(self, f, dataset): - dset = f[dataset] - names = copy(self.names) - if not dataset.endswith("imBackground"): - names.append("cell_label") - lbls = {lbl: dset[lbl][()] for lbl in names if lbl in dset.keys()} - index = pd.MultiIndex.from_arrays( - list(lbls.values()), names=names[-len(lbls) :] - ) + def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame: + """ + Fetch DataFrame from results storage file. 
+ """ + + assert path in f, f"{path} not in {f}" - columns = ( - dset["timepoint"][()] - if "timepoint" in dset - else dset.attrs["columns"] + dset = f[path] + index_names = copy(self.index_names) + + valid_names = [lbl for lbl in index_names if lbl in dset.keys()] + index = pd.MultiIndex.from_arrays( + [dset[lbl] for lbl in valid_names], names=valid_names ) - df = pd.DataFrame(dset[("values")][()], index=index, columns=columns) + columns = dset.attrs.get("columns", None) # dset.attrs["columns"] + if "timepoint" in dset: + columns = f[path + "/timepoint"][()] - return df + return pd.DataFrame( + f[path + "/values"][()], + index=index, + columns=columns, + ) @property def stem(self): return self.filename.stem - @staticmethod - def dataset_to_df(f: h5py.File, path: str): + # def dataset_to_df(self, f: h5py.File, path: str): - all_indices = ["experiment", "position", "trap", "cell_label"] - indices = { - k: f[path][k][()] for k in all_indices if k in f[path].keys() - } - return pd.DataFrame( - f[path + "/values"][()], - index=pd.MultiIndex.from_arrays( - list(indices.values()), names=indices.keys() - ), - columns=f[path + "/timepoint"][()], - ) + # all_indices = self.index_names + + # valid_indices = { + # k: f[path][k][()] for k in all_indices if k in f[path].keys() + # } + + # new_index = pd.MultiIndex.from_arrays( + # list(valid_indices.values()), names=valid_indices.keys() + # ) + + # return pd.DataFrame( + # f[path + "/values"][()], + # index=new_index, + # columns=f[path + "/timepoint"][()], + # ) def get_siglist(self, name: str, node): fullname = node.name @@ -306,7 +351,8 @@ class Signal(BridgeH5): @staticmethod def join_tracks_pair(target: pd.Series, source: pd.Series): """ - Join two tracks + Join two tracks and return the new value of the target. + TODO replace this with arrays only. """ tgt_copy = copy(target) end = find_1st(target.values[::-1], 0, cmp_larger) @@ -355,3 +401,44 @@ class Signal(BridgeH5): if end <= self.max_span ] return tuple((stage, ntps) for stage, ntps in zip(self.stages, spans)) + + +def validate_merges(merges: np.ndarray, indices: np.ndarray) -> np.ndarray: + """Select rows from the first array that are present in both. + We use casting for fast multiindexing + + + Parameters + ---------- + merges : np.ndarray + 2-D array where columns are (trap, mother, daughter) or 3-D array where + dimensions are (X, (trap,mother), (trap,daughter)) + indices : np.ndarray + 2-D array where each column is a different level. + + Returns + ------- + np.ndarray + 3-D array with elements in both arrays. + + Examples + -------- + FIXME: Add docs. + + """ + if merges.ndim < 3: + # Reshape into 3-D array for casting if neded + merges = np.stack((merges[:, [0, 1]], merges[:, [0, 2]]), axis=1) + + # Compare existing merges with available indices + # Swap trap and label axes for the merges array to correctly cast + # valid_ndmerges = merges.swapaxes(1, 2)[..., None] == indices.T[:, None, :] + valid_ndmerges = merges[..., None] == indices.T[None, ...] 
+ + # Casting is confusing (but efficient): + # - First we check the dimension across trap and cell id, to ensure both match + # - Then we check the dimension that crosses all indices, to ensure the pair is present there + # - Finally we check the merge tuples to check which cases have both target and source + valid_merges = merges[valid_ndmerges.all(axis=2).any(axis=2).all(axis=1)] + # valid_merges = merges[allnan.any(axis=1)] + return valid_merges diff --git a/src/agora/io/writer.py b/src/agora/io/writer.py index d3030bfd572e255080c8b97e9d7a983f968baada..01e3d64915e49a87a51ffad253aaaf1c748dfe5d 100644 --- a/src/agora/io/writer.py +++ b/src/agora/io/writer.py @@ -703,10 +703,18 @@ class Writer(BridgeH5): # Add found cells dset.resize(dset.shape[1] + df.shape[1], axis=1) dset[:, ntps:] = np.nan - for i, tp in enumerate(df.columns): - dset[ - self.id_cache[df.index.nlevels]["found_indices"], tp - ] = existing_values[:, i] + + # TODO refactor this indices sorting. Could be simpler + found_indices_sorted = self.id_cache[df.index.nlevels][ + "found_indices" + ] + + # Cover for case when all labels are new + if found_indices_sorted.any(): + # h5py does not allow bidimensional indexing, + # so we have to iterate over the columns + for i, tp in enumerate(df.columns): + dset[found_indices_sorted, tp] = existing_values[:, i] # Add new cells n_newcells = len( self.id_cache[df.index.nlevels]["additional_multis"] diff --git a/src/agora/utils/lineage.py b/src/agora/utils/lineage.py index b72c69cda3243893f8ae7dd31521c280e342b456..5b6686863f0262e515a6164db29c17dbecd80920 100644 --- a/src/agora/utils/lineage.py +++ b/src/agora/utils/lineage.py @@ -1,8 +1,12 @@ #!/usr/bin/env python3 +import re +import typing as t + import numpy as np import pandas as pd from agora.io.bridge import groupsort +from itertools import groupby def mb_array_to_dict(mb_array: np.ndarray): @@ -25,3 +29,51 @@ def mb_array_to_indices(mb_array: np.ndarray): return pd.MultiIndex.from_arrays(mb_array[:, :2].T).union( pd.MultiIndex.from_arrays(mb_array[:, [0, 2]].T) ) + + +def group_matrix( + matrix: np.ndarray, + n_keys: int = 2, +) -> t.Dict[t.Tuple[int], t.List[int]]: + """Group a matrix of integers by grouping the first two columns + and setting the third one in a list. + + + Parameters + ---------- + matrix : np.ndarray + id_matrix, generally its columns are three integers indicating trap, + mother and daughter. + n_keys : int + number of keys to use to determine groups. + + Returns + ------- + t.Dict[t.Tuple[int], t.Collection[int, ...]] + The column(s) not used for generaeting keys are grouped as values. + + Examples + -------- + FIXME: Add docs. 
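For instance (invented values; this assumes the patched `agora.utils.lineage` is importable), the grouping described above produces a dict keyed by (trap, mother):

```python
import numpy as np

from agora.utils.lineage import group_matrix  # added by this patch

matrix = np.array(
    [
        [0, 1, 2],  # trap 0, mother 1, daughter 2
        [0, 1, 3],
        [2, 4, 5],
    ]
)
print(group_matrix(matrix, n_keys=2))  # {(0, 1): [2, 3], (2, 4): [5]}
```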
+ + """ + lineage_dict = {} + if len(matrix): + + daughter = matrix[:, n_keys] + mother_global_id = matrix[:, :n_keys] + + iterator = groupby( + zip(mother_global_id, daughter), lambda x: str(x[0]) + ) + lineage_dict = {key: [x[1] for x in group] for key, group in iterator} + + def str_to_tuple(k: str) -> t.Tuple[int, ...]: + return tuple([int(x) for x in re.findall("[0-9]+", k)]) + + # Convert keys from str to tuple + lineage_dict = { + str_to_tuple(k): sorted(v) for k, v in lineage_dict.items() + } + + return lineage_dict diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py index 2c4f792c9e1e1df9af661873d0be06d50c0e3cf9..00f50a38df83f31b325b91a8dd5f4e24f017c77d 100644 --- a/src/aliby/pipeline.py +++ b/src/aliby/pipeline.py @@ -709,6 +709,7 @@ class Pipeline(ProcessABC): meta.add_fields( # Add non-logfile metadata { "aliby_version": version("aliby"), + "baby_version": version("aliby-baby"), "omero_id": config["general"]["id"], "image_id": image_id, "parameters": PipelineParameters.from_dict( diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py index ecc34f3f232f7e3b87c18225d56360fb6eb4a38c..e94bdc4be87c4bbcb8748bac29b84305f4f27ba1 100644 --- a/src/extraction/core/extractor.py +++ b/src/extraction/core/extractor.py @@ -546,7 +546,7 @@ class Extractor(ProcessABC): ) ) == len(chs): channels_stack = np.stack( - [self.get_imgs(ch, tiles, tree_chs) for ch in chs] + [self.get_imgs(ch, tiles, tree_chs) for ch in chs], axis=-1 ) merged = RED_FUNS[merge_fun](channels_stack, axis=-1) d[name] = self.reduce_extract( diff --git a/src/extraction/core/functions/cell.py b/src/extraction/core/functions/cell.py index d0b605c573d5862c7a4de890a4853b2769171923..cd66bcb4080c4bc4f705e58d334865a9afd73466 100644 --- a/src/extraction/core/functions/cell.py +++ b/src/extraction/core/functions/cell.py @@ -75,10 +75,10 @@ def max2p5pc(cell_mask, trap_image) -> float: """ # number of pixels in mask npixels = bn.nansum(cell_mask) - top_pixels = int(np.ceil(npixels * 0.025)) + n_top = int(np.ceil(npixels * 0.025)) # sort pixels in cell and find highest 2.5% pixels = trap_image[cell_mask] - top_values = pixels[bn.rankdata(pixels)[:top_pixels].astype(int) - 1] + top_values = bn.partition(pixels, len(pixels) - n_top)[-n_top:] # find mean of these highest pixels return bn.nanmean(top_values) @@ -96,7 +96,7 @@ def max5px(cell_mask, trap_image) -> float: """ # sort pixels in cell pixels = trap_image[cell_mask] - top_values = pixels[bn.rankdata(pixels)[:5].astype(int) - 1] + top_values = bn.partition(pixels, len(pixels) - 5)[-5:] # find mean of five brightest pixels max5px = bn.nanmean(top_values) return max5px diff --git a/src/extraction/core/functions/math_utils.py b/src/extraction/core/functions/math_utils.py index b94a0897863af936d3508d66b1af2b427ccf4696..eeae8e0c432e698f27936cec738b17d906b2f59b 100644 --- a/src/extraction/core/functions/math_utils.py +++ b/src/extraction/core/functions/math_utils.py @@ -1,7 +1,7 @@ import numpy as np -def div0(a, b, fill=0): +def div0(array, fill=0, axis=-1): """ Divide array a by array b. 
@@ -13,9 +13,19 @@
     ----------
     a: array
     b: array
+    fill: float
+    axis: int
     """
+    assert array.shape[axis] == 2, f"Array has the wrong shape in axis {axis}"
+    slices_0, slices_1 = [[slice(None)] * len(array.shape) for _ in range(2)]
+    slices_0[axis] = 0
+    slices_1[axis] = 1
+
     with np.errstate(divide="ignore", invalid="ignore"):
-        c = np.true_divide(a, b)
+        c = np.true_divide(
+            array[tuple(slices_0)],
+            array[tuple(slices_1)],
+        )
     if np.isscalar(c):
         return c if np.isfinite(c) else fill
     else:
diff --git a/src/postprocessor/core/abc.py b/src/postprocessor/core/abc.py
index b4db424c76c81b283ab1fb8365e93e6d74338b7c..a299e19f6f87a33e0d9a6c5c099b291f377811dc 100644
--- a/src/postprocessor/core/abc.py
+++ b/src/postprocessor/core/abc.py
@@ -30,11 +30,15 @@
     Dynamically import a process class from the available process locations.
     Assumes process filename and class name are the same
 
+    Processes return the same shape as their input.
+    MultiSignal processes either take or return multiple datasets (or both).
+    Reshapers return a different shape from their input: Merger and Picker belong here.
+
     suffix : str
         Name of suffix, generally "" (empty) or "Parameters".
     """
     base_location = "postprocessor.core"
-    possible_locations = ("processes", "multisignal")
+    possible_locations = ("processes", "multisignal", "reshapers")
 
     found = None
     for possible_location in possible_locations:
diff --git a/src/postprocessor/core/export_pdf.py b/src/postprocessor/core/export_pdf.py
deleted file mode 100644
index 041789dae04ac711e575e2331e253dea9a71e727..0000000000000000000000000000000000000000
--- a/src/postprocessor/core/export_pdf.py
+++ /dev/null
@@ -1,16 +0,0 @@
-#!/usr/bin/env python3
-import numpy as np
-from matplotlib import pyplot as plt
-from matplotlib.backends.backend_pdf import PdfPages
-
-
-def dummyplot():
-    plt.plot(np.random.randint(10, size=5), np.random.randint(10, size=5))
-
-
-plt1 = dummyplot()
-plt2 = dummyplot()
-pp = PdfPages("foo.pdf")
-for i in [plt1, plt2]:
-    pp.savefig(i)
-pp.close()
diff --git a/src/postprocessor/core/lineageprocess.py b/src/postprocessor/core/lineageprocess.py
index e5a0d9b4834ec8f46d6e0d1256c6dcaad2e460fe..296effce76612577ff97c8783a2d615389b7fc31 100644
--- a/src/postprocessor/core/lineageprocess.py
+++ b/src/postprocessor/core/lineageprocess.py
@@ -1 +1,85 @@
-#!/usr/bin/env python3
+import typing as t
+from abc import abstractmethod
+
+import numpy as np
+import pandas as pd
+
+from agora.abc import ParametersABC
+from postprocessor.core.abc import PostProcessABC
+
+# from agora.utils.lineage import group_matrix
+
+
+class LineageProcessParameters(ParametersABC):
+    """
+    Parameters
+    """
+
+    _defaults = {}
+
+
+class LineageProcess(PostProcessABC):
+    """
+    Lineage process that must be passed an (N, 3) lineage matrix (where the columns are trap, mother, daughter respectively)
+    """
+
+    def __init__(self, parameters: LineageProcessParameters):
+        super().__init__(parameters)
+
+    def filter_signal_cells(
+        self, signal: pd.DataFrame, lineage: np.ndarray = None
+    ):
+        """
+        Use casting to filter cell ids in signal and lineage
+        """
+        if lineage is None:
+            lineage = self.lineage
+
+        sig_ind = np.array(list(signal.index)).T[:, None, :]
+        mo_av = (
+            (lineage[:, :2].T[:, :, None] == sig_ind).all(axis=0).any(axis=1)
+        )
+        da_av = (
+            (lineage[:, [0, 2]].T[:, :, None] == sig_ind)
+            .all(axis=0)
+            .any(axis=1)
+        )
+
+        return lineage[mo_av & da_av]
+
+    @abstractmethod
+    def run(
+        self,
+        signal: pd.DataFrame,
+        lineage: np.ndarray,
+        *args,
+    ):
+        pass
+
+    @classmethod
+    def as_function(
+        cls,
+        data: pd.DataFrame,
+        lineage: t.Dict[t.Tuple[int], t.List[int]],
+        *extra_data,
+        **kwargs,
+    ):
+        """
+        Overrides PostProcessABC.as_function classmethod.
+        Lineage functions require lineage information to be passed if run as function.
+        """
+        # if isinstance(lineage, np.ndarray):
+        #     lineage = group_matrix(lineage, n_keys=2)
+
+        parameters = cls.default_parameters(**kwargs)
+        return cls(parameters=parameters).run(
+            data, lineage=lineage, *extra_data
+        )
+        # super().as_function(data, *extra_data, lineage=lineage, **kwargs)
+
+    def load_lineage(self, lineage):
+        """
+        Reshape the lineage information if needed
+        """
+
+        self.lineage = lineage
diff --git a/src/postprocessor/core/processes/__init__.py b/src/postprocessor/core/processes/__init__.py
index e5a0d9b4834ec8f46d6e0d1256c6dcaad2e460fe..bb39f5bd9541f0bfc5c9d57382274e7585061f88 100644
--- a/src/postprocessor/core/processes/__init__.py
+++ b/src/postprocessor/core/processes/__init__.py
@@ -1 +1,5 @@
 #!/usr/bin/env python3
+
+"""
+All Processes in this folder must return the same shape they take as input.
+"""
diff --git a/src/postprocessor/core/processes/lineageprocess.py b/src/postprocessor/core/processes/lineageprocess.py
deleted file mode 100644
index 2c0cc6a0e6fa9687f530d96b6ca2f995de16bd6f..0000000000000000000000000000000000000000
--- a/src/postprocessor/core/processes/lineageprocess.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import numpy as np
-import pandas as pd
-from agora.abc import ParametersABC
-
-from postprocessor.core.abc import PostProcessABC
-
-
-class LineageProcessParameters(ParametersABC):
-    """
-    Parameters
-    """
-
-    _defaults = {}
-
-
-class LineageProcess(PostProcessABC):
-    """
-    Lineage process that must be passed a (N,3) lineage matrix (where the coliumns are trap, mother, daughter respectively)
-    """
-
-    def __init__(self, parameters: LineageProcessParameters):
-        super().__init__(parameters)
-
-    def run(
-        self,
-    ):
-        pass
-
-    def filter_signal_cells(self, signal: pd.DataFrame):
-        """
-        Use casting to filter cell ids in signal and lineage
-        """
-
-        sig_ind = np.array(list(signal.index)).T[:, None, :]
-        mo_av = (
-            (self.lineage[:, :2].T[:, :, None] == sig_ind)
-            .all(axis=0)
-            .any(axis=1)
-        )
-        da_av = (
-            (self.lineage[:, [0, 2]].T[:, :, None] == sig_ind)
-            .all(axis=0)
-            .any(axis=1)
-        )
-
-        return self.lineage[mo_av & da_av]
diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py
index 5f28dbba3bdb1a36e3ba01de38c41b2a3c06ae67..7cc465c988595226e89bbfcbf8c0c31a6367b4d5 100644
--- a/src/postprocessor/core/processor.py
+++ b/src/postprocessor/core/processor.py
@@ -11,11 +11,11 @@ from agora.io.writer import Writer
 from tqdm import tqdm
 
 from postprocessor.core.abc import get_parameters, get_process
-from postprocessor.core.processes.lineageprocess import (
+from postprocessor.core.lineageprocess import (
     LineageProcessParameters,
 )
-from postprocessor.core.processes.merger import merger, mergerParameters
-from postprocessor.core.processes.picker import picker, pickerParameters
+from postprocessor.core.reshapers.merger import merger, mergerParameters
+from postprocessor.core.reshapers.picker import picker, pickerParameters
 
 
 class PostProcessorParameters(ParametersABC):
@@ -298,7 +298,9 @@
 
         self.run_prepost()
 
-        for process, datasets in tqdm(self.targets["processes"]):
+        for i, (process, datasets) in tqdm(
+            enumerate(self.targets["processes"])
+        ):
             if process in
self.parameters["param_sets"].get( "processes", {} ): # If we assigned parameters @@ -310,23 +314,15 @@ class PostProcessor(ProcessABC): parameters = self.parameters_classfun[process].default() if isinstance(parameters, LineageProcessParameters): - with h5py.File(self._filename, "r") as f: - trap_mo_da = f[parameters.lineage_location] - lineage = np.array( - ( - trap_mo_da["trap"], - trap_mo_da["mother_label"], - trap_mo_da["daughter_label"], - ) - ).T + lineage = self._signal.lineage( + # self.parameters.lineage_location + ) loaded_process = self.classfun[process](parameters) loaded_process.load_lineage(lineage) else: loaded_process = self.classfun[process](parameters) for dataset in datasets: - # print("Processing", process, "for", dataset) - if isinstance(dataset, list): # multisignal process signal = [self._signal[d] for d in dataset] elif isinstance(dataset, str): diff --git a/src/postprocessor/core/multisignal/aggregate.py b/src/postprocessor/core/reshapers/aggregate.py similarity index 100% rename from src/postprocessor/core/multisignal/aggregate.py rename to src/postprocessor/core/reshapers/aggregate.py diff --git a/src/postprocessor/core/processes/bud_metric.py b/src/postprocessor/core/reshapers/bud_metric.py similarity index 92% rename from src/postprocessor/core/processes/bud_metric.py rename to src/postprocessor/core/reshapers/bud_metric.py index f4d42ab275e8075b82db31a55a502fb9ee7ceefb..be4a978c1c70aafa84c5157de49d447a6c587571 100644 --- a/src/postprocessor/core/processes/bud_metric.py +++ b/src/postprocessor/core/reshapers/bud_metric.py @@ -4,7 +4,7 @@ import numpy as np import pandas as pd from agora.utils.lineage import mb_array_to_dict -from postprocessor.core.processes.lineageprocess import ( +from postprocessor.core.lineageprocess import ( LineageProcess, LineageProcessParameters, ) @@ -74,10 +74,3 @@ class bud_metric(LineageProcess): df = pd.DataFrame(mothers_mat, index=md.keys(), columns=signal.columns) df.index.names = signal.index.names return df - - def load_lineage(self, lineage): - """ - Reshape the lineage information if needed - """ - - self.lineage = lineage diff --git a/src/postprocessor/core/processes/buddings.py b/src/postprocessor/core/reshapers/buddings.py similarity index 97% rename from src/postprocessor/core/processes/buddings.py rename to src/postprocessor/core/reshapers/buddings.py index 839f0fdc0c2e04863ae33ac8b522120ac9256023..0b01dad70c5eab067f26b430e5451fdf70bcb4e2 100644 --- a/src/postprocessor/core/processes/buddings.py +++ b/src/postprocessor/core/reshapers/buddings.py @@ -6,7 +6,7 @@ from itertools import product import numpy as np import pandas as pd -from postprocessor.core.processes.lineageprocess import ( +from postprocessor.core.lineageprocess import ( LineageProcess, LineageProcessParameters, ) diff --git a/src/postprocessor/core/processes/merger.py b/src/postprocessor/core/reshapers/merger.py similarity index 94% rename from src/postprocessor/core/processes/merger.py rename to src/postprocessor/core/reshapers/merger.py index 2f31724e88bd0fec06f35368db4cda27fa38b152..16d2b598ea720027a2d0ca566adfc1b73cc5c31a 100644 --- a/src/postprocessor/core/processes/merger.py +++ b/src/postprocessor/core/reshapers/merger.py @@ -24,7 +24,7 @@ class mergerParameters(ParametersABC): class merger(PostProcessABC): """ - TODO check why it needs to be run a few times to complete the merging + Combines rows of tracklet that are likely to be the same. 
""" def __init__(self, parameters): diff --git a/src/postprocessor/core/processes/picker.py b/src/postprocessor/core/reshapers/picker.py similarity index 53% rename from src/postprocessor/core/processes/picker.py rename to src/postprocessor/core/reshapers/picker.py index cf06dd3643c415e1a16e5cb59fee36a6d2f25f91..2fc1ece75e5d078dbe05a309bf3a8e30aed06f07 100644 --- a/src/postprocessor/core/processes/picker.py +++ b/src/postprocessor/core/reshapers/picker.py @@ -198,175 +198,9 @@ class picker(PostProcessABC): "nonstoply_present": lambda s, thresh: s.apply(thresh, axis=1) > thresh, "growing": lambda s, thresh: s.diff(axis=1).sum(axis=1) > thresh, - "mb_guess": lambda s, p1, p2: self.mb_guess_wrap(s, p1, p2) - # "quantile": [np.quantile(signals.values[signals.notna()], threshold)], } return set(signals.index[case_mgr[condition](signals, *threshold)]) - def mb_guess(self, df, ba, trap, min_budgrowth_t, min_mobud_ratio): - """ - Parameters - ---------- - signals : pd.DataFrame - ba : list of cell_labels that come from bud assignment - trap : Trap id (used to fetch raw bud) - min_budgrowth_t: Minimal number of timepoints we lock reassignment after assigning bud - min_initial_size: Minimal mother-bud ratio when it was first identified - add_ba: Bool that incorporates bud_assignment data after the normal assignment - - Thinking this problem as the Movie Scheduling problem (Skiena's the algorithm design manual chapter 1.2), - we will try to pick the set of filtered cells that grow the fastest and don't overlap within 5 time points - TODO adjust overlap to minutes using metadata - """ - - # if trap == 21: # Use this to check specific trap problems through a debugger - # print("stop") - ntps = df.notna().sum(axis=1) - mother_id = df.index[ntps.argmax()] - nomother = df.drop(mother_id) - if not len(nomother): - return [] - nomother = nomother.loc[ # Clean short-lived cells outside our mother cell's timepoints - nomother.apply( - lambda x: x.first_valid_index() - >= df.loc[mother_id].first_valid_index() - and x.first_valid_index() - <= df.loc[mother_id].last_valid_index(), - axis=1, - ) - ] - - score = -nomother.apply( # Get slope of candidate daughters - lambda x: self.get_slope(x.dropna()), axis=1 - ) - start = nomother.apply(pd.Series.first_valid_index, axis=1) - - # clean duplicates - duplicates = start.duplicated(False) - if duplicates.any(): - score = self.get_nodup_idx(start, score, duplicates, nomother) - nomother = nomother.loc[score.index] - nomother.index = nomother.index.astype("int") - start = start.loc[score.index] - start.index = start.index.astype(int) - - d_to_mother = ( - nomother[start] - df.loc[mother_id, start] * min_mobud_ratio - ).sort_index(axis=1) - size_filter = d_to_mother[ - d_to_mother.apply(lambda x: x.dropna().iloc[0], axis=1) < 0 - ] - cols_sorted = ( - size_filter.sort_index(axis=1) - .apply(pd.Series.first_valid_index, axis=1) - .sort_values() - ) - score = score.loc[cols_sorted.index] - if not len(cols_sorted): - bud_candidates = pd.DataFrame() - else: - # Find the set with the highest number of growing cells and highest avg growth rate for this # - mivs = self.max_ind_vertex_sets( - cols_sorted.values, min_budgrowth_t - ) - best_set = list( - mivs[np.argmin([sum(score.iloc[list(s)]) for s in mivs])] - ) - best_indices = cols_sorted.index[best_set] - - start = start.loc[best_indices] - bud_candidates = cols_sorted.loc[best_indices] - # bud_candidates = cols_sorted.loc[ - # [True, *(np.diff(cols_sorted.values) > min_budgrowth_t)] - # ] - - # Add random-forest bud 
assignment information here - new_ba_cells = [] - if ( - ba - ): # Use the mother-daughter rf information to prioritise tracks over others - # TODO add merge application to indices and see if that recovers more cells - ba = set(ba).intersection(nomother.index) - ba_df = nomother.loc[ba, :] - start_ba = ba_df.apply(pd.Series.first_valid_index, axis=1) - new_ba_cells = list(set(start_ba.index).difference(start.index)) - - distances = np.subtract.outer( - start.values, start_ba.loc[new_ba_cells].values - ) - todrop, _ = np.where(abs(distances) < min_budgrowth_t) - bud_candidates = bud_candidates.drop(bud_candidates.index[todrop]) - - return [mother_id] + bud_candidates.index.tolist() + new_ba_cells - - @staticmethod - def max_ind_vertex_sets(values, min_distance): - """ - Generates an adjacency matrix from multiple points, joining neighbours closer than min_distance - Then returns the maximal independent vertex sets - values: list of int values - min_distance: int minimal distance to cluster - """ - adj = np.zeros((len(values), len(values))).astype(bool) - dist = abs(np.subtract.outer(values, values)) - adj[dist <= min_distance] = True - - g = ig.Graph.Adjacency(adj, mode="undirected") - miv_sets = g.maximal_independent_vertex_sets() - return miv_sets - - def get_nodup_idx(self, start, score, duplicates, nomother): - """ - Return the start DataFrame without duplicates - - :start: pd.Series indicating the first valid time point - :score: pd.Series containing a score to minimise - :duplicates: Dataframe containing duplicated entries - :nomother: Dataframe with non-mother cells - """ - dup_tps = np.unique(start[duplicates]) - idx, tps = zip( - *[ - (score.loc[nomother.loc[start == tp, tp].index].idxmin(), tp) - for tp in dup_tps - ] - ) - score = score[~duplicates] - score = pd.concat( - (score, pd.Series(tps, index=idx, dtype="int", name="cell_label")) - ) - return score - - def mb_guess_wrap(self, signals, *args): - if not len(signals): - return pd.Series([]) - ids = [] - mothers, buds = self.get_mothers_daughters() - mothers = np.array(mothers) - buds = np.array(buds) - ba = [] - # if buds.any(): - # ba_bytrap = { - # i: np.where(buds[:, 0] == i) for i in range(buds[:, 0].max() + 1) - # } - for trap in signals.index.unique(level="trap"): - # ba = list( - # set(mothers[ba_bytrap[trap], 1][0].tolist()).union( - # buds[ba_bytrap[trap], 1][0].tolist() - # ) - # ) - df = signals.loc[trap] - selected_ids = self.mb_guess(df, ba, trap, *args) - ids += [(trap, i) for i in selected_ids] - - idx_srs = pd.Series(False, signals.index).astype(bool) - idx_srs.loc[ids] = True - return idx_srs - - @staticmethod - def get_slope(x): - return np.polyfit(range(len(x)), x, 1)[0] - def _as_int(threshold: Union[float, int], ntps: int): if type(threshold) is float:
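With the mother-bud guessing removed from `picker`, lineage information now comes from the results file via `Signal.lineage` and is consumed by `LineageProcess` subclasses such as `bud_metric`; a hedged end-to-end sketch (the file and dataset paths are invented, and it assumes the subclass's `run` accepts the `lineage` keyword, as the new `as_function` implies):

```python
from agora.io.signal import Signal
from postprocessor.core.reshapers.bud_metric import bud_metric

s = Signal("experiment/position0.h5")           # invented path to a results file
lineage = s.lineage()                           # rows of (trap, mother_label, daughter_label)
volume = s["extraction/general/None/volume"]    # invented dataset path
bud_volume = bud_metric.as_function(volume, lineage=lineage)
```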