From c23088daf4a7f0471a0162b79fb2371f8a0f8358 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk>
Date: Thu, 2 Mar 2023 00:20:49 +0000
Subject: [PATCH] clean(aliby): Improve docs/cleanup

---
 src/agora/io/cells.py               |  1 -
 src/aliby/utils/vis_tools.py        | 21 +++++++++++++++------
 src/postprocessor/chainer.py        |  7 ++-----
 src/postprocessor/core/processor.py |  1 +
 src/postprocessor/grouper.py        |  1 -
 5 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/src/agora/io/cells.py b/src/agora/io/cells.py
index 72cc641a..6c783ee3 100644
--- a/src/agora/io/cells.py
+++ b/src/agora/io/cells.py
@@ -1,6 +1,5 @@
 import logging
 import typing as t
-from collections.abc import Iterable
 from itertools import groupby
 from pathlib import Path, PosixPath
 from functools import lru_cache, cached_property
diff --git a/src/aliby/utils/vis_tools.py b/src/aliby/utils/vis_tools.py
index aa388411..2436b421 100644
--- a/src/aliby/utils/vis_tools.py
+++ b/src/aliby/utils/vis_tools.py
@@ -35,7 +35,7 @@ def get_tiles_at_times(
         int, t.List[int], str, t.Callable
     ] = lambda x: concatenate_dims(x, 1, -1),
     channel: int = 1,
-):
+) -> np.ndarray:
     """Use Image and tiler to get tiled position for specific time points.
 
     Parameters
@@ -66,11 +66,13 @@ def get_tiles_at_times(
     return tp_channel_stack
 
 
-def get_cellmasks_at_times(results_path: str, timepoints: t.List[int] = [0]):
+def get_cellmasks_at_times(
+    results_path: str, timepoints: t.List[int] = [0]
+) -> t.List[t.List[np.ndarray]]:
     return Cells(results_path).at_times(timepoints)
 
 
-def concatenate_dims(ndarray, axis1: int, axis2: int):
+def concatenate_dims(ndarray, axis1: int, axis2: int) -> np.ndarray:
     axis2 = len(ndarray.shape) + axis2 if axis2 < 0 else axis2
     return np.concatenate(np.moveaxis(ndarray, axis1, 0), axis=axis2 - 1)
 
@@ -202,13 +204,20 @@ def overlay_masks_tiles(
 
 
 def _sample_n_tiles_masks(
-    image_path: str, results_path: str, n: int, seed: int = 0
+    image_path: str,
+    results_path: str,
+    n: int,
+    seed: int = 0,
+    interval=None,
 ) -> t.Tuple[t.Tuple, t.Tuple[np.ndarray, np.ndarray]]:
 
     cells = Cells(results_path)
-    locations, masks = cells._sample_masks(n, seed=seed)
+    locations, masks = cells._sample_masks(n, seed=seed, interval=interval)
 
     processed_tiles, cropped_masks = overlay_masks_tiles(
-        image_path, results_path, masks, [locations[i] for i in (0, 2)]
+        image_path,
+        results_path,
+        masks,
+        [locations[i] for i in (0, 2)],
     )
     return locations, (processed_tiles, cropped_masks)
diff --git a/src/postprocessor/chainer.py b/src/postprocessor/chainer.py
index 831d5471..b834fb5d 100644
--- a/src/postprocessor/chainer.py
+++ b/src/postprocessor/chainer.py
@@ -7,10 +7,8 @@ from copy import copy
 import pandas as pd
 
 from agora.io.signal import Signal
-from agora.utils.association import validate_association
 from agora.utils.kymograph import bidirectional_retainment_filter
-from postprocessor.core.abc import get_parameters, get_process
-from postprocessor.core.lineageprocess import LineageProcessParameters
+from postprocessor.core.abc import get_process
 
 
 class Chainer(Signal):
@@ -62,13 +60,12 @@ class Chainer(Signal):
             data = self.common_chains[dataset](**kwargs)
         else:
             # use Signal's get_raw
-            data = self.get_raw(dataset, in_minutes=in_minutes)
+            data = self.get_raw(dataset, in_minutes=in_minutes, lineage=True)
             if chain:
                 data = self.apply_chain(data, chain, **kwargs)
         if retain:
             # keep data only from early time points
             data = self.get_retained(data, retain)
-            # data = data.loc[data.notna().sum(axis=1) > data.shape[1] * retain]
         if stages and "stage" not in data.columns.names:
             # return stages as additional column level
             stages_index = [
diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py
index 4b9161c0..7a8ccd79 100644
--- a/src/postprocessor/core/processor.py
+++ b/src/postprocessor/core/processor.py
@@ -44,6 +44,7 @@ class PostProcessorParameters(ParametersABC):
 
     @classmethod
     def default(cls, kind=[]):
+        """Sequential postprocesses to be operated"""
         targets = {
             "prepost": {
                 "merger": "/extraction/general/None/area",
diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py
index 6b9fa9ee..be1e85f9 100644
--- a/src/postprocessor/grouper.py
+++ b/src/postprocessor/grouper.py
@@ -371,7 +371,6 @@ def concat_standard(
     return combined
 
 
-# why _ind ?
 def concat_signal_ind(
     path: str,
     chainer: Chainer,
-- 
GitLab