From c23088daf4a7f0471a0162b79fb2371f8a0f8358 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Thu, 2 Mar 2023 00:20:49 +0000 Subject: [PATCH] clean(aliby): Improve docs/cleanup --- src/agora/io/cells.py | 1 - src/aliby/utils/vis_tools.py | 21 +++++++++++++++------ src/postprocessor/chainer.py | 7 ++----- src/postprocessor/core/processor.py | 1 + src/postprocessor/grouper.py | 1 - 5 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/agora/io/cells.py b/src/agora/io/cells.py index 72cc641a..6c783ee3 100644 --- a/src/agora/io/cells.py +++ b/src/agora/io/cells.py @@ -1,6 +1,5 @@ import logging import typing as t -from collections.abc import Iterable from itertools import groupby from pathlib import Path, PosixPath from functools import lru_cache, cached_property diff --git a/src/aliby/utils/vis_tools.py b/src/aliby/utils/vis_tools.py index aa388411..2436b421 100644 --- a/src/aliby/utils/vis_tools.py +++ b/src/aliby/utils/vis_tools.py @@ -35,7 +35,7 @@ def get_tiles_at_times( int, t.List[int], str, t.Callable ] = lambda x: concatenate_dims(x, 1, -1), channel: int = 1, -): +) -> np.ndarray: """Use Image and tiler to get tiled position for specific time points. Parameters @@ -66,11 +66,13 @@ def get_tiles_at_times( return tp_channel_stack -def get_cellmasks_at_times(results_path: str, timepoints: t.List[int] = [0]): +def get_cellmasks_at_times( + results_path: str, timepoints: t.List[int] = [0] +) -> t.List[t.List[np.ndarray]]: return Cells(results_path).at_times(timepoints) -def concatenate_dims(ndarray, axis1: int, axis2: int): +def concatenate_dims(ndarray, axis1: int, axis2: int) -> np.ndarray: axis2 = len(ndarray.shape) + axis2 if axis2 < 0 else axis2 return np.concatenate(np.moveaxis(ndarray, axis1, 0), axis=axis2 - 1) @@ -202,13 +204,20 @@ def overlay_masks_tiles( def _sample_n_tiles_masks( - image_path: str, results_path: str, n: int, seed: int = 0 + image_path: str, + results_path: str, + n: int, + seed: int = 0, + interval=None, ) -> t.Tuple[t.Tuple, t.Tuple[np.ndarray, np.ndarray]]: cells = Cells(results_path) - locations, masks = cells._sample_masks(n, seed=seed) + locations, masks = cells._sample_masks(n, seed=seed, interval=interval) processed_tiles, cropped_masks = overlay_masks_tiles( - image_path, results_path, masks, [locations[i] for i in (0, 2)] + image_path, + results_path, + masks, + [locations[i] for i in (0, 2)], ) return locations, (processed_tiles, cropped_masks) diff --git a/src/postprocessor/chainer.py b/src/postprocessor/chainer.py index 831d5471..b834fb5d 100644 --- a/src/postprocessor/chainer.py +++ b/src/postprocessor/chainer.py @@ -7,10 +7,8 @@ from copy import copy import pandas as pd from agora.io.signal import Signal -from agora.utils.association import validate_association from agora.utils.kymograph import bidirectional_retainment_filter -from postprocessor.core.abc import get_parameters, get_process -from postprocessor.core.lineageprocess import LineageProcessParameters +from postprocessor.core.abc import get_process class Chainer(Signal): @@ -62,13 +60,12 @@ class Chainer(Signal): data = self.common_chains[dataset](**kwargs) else: # use Signal's get_raw - data = self.get_raw(dataset, in_minutes=in_minutes) + data = self.get_raw(dataset, in_minutes=in_minutes, lineage=True) if chain: data = self.apply_chain(data, chain, **kwargs) if retain: # keep data only from early time points data = self.get_retained(data, retain) - # data = data.loc[data.notna().sum(axis=1) > data.shape[1] * retain] if stages and "stage" not in data.columns.names: # return stages as additional column level stages_index = [ diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py index 4b9161c0..7a8ccd79 100644 --- a/src/postprocessor/core/processor.py +++ b/src/postprocessor/core/processor.py @@ -44,6 +44,7 @@ class PostProcessorParameters(ParametersABC): @classmethod def default(cls, kind=[]): + """Sequential postprocesses to be operated""" targets = { "prepost": { "merger": "/extraction/general/None/area", diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py index 6b9fa9ee..be1e85f9 100644 --- a/src/postprocessor/grouper.py +++ b/src/postprocessor/grouper.py @@ -371,7 +371,6 @@ def concat_standard( return combined -# why _ind ? def concat_signal_ind( path: str, chainer: Chainer, -- GitLab