From c30761fbb8ccb8bb1e7f81446b9968ce9c41d4ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20Mu=C3=B1oz?= <alan.munoz@ed.ac.uk> Date: Fri, 17 Mar 2023 19:21:20 +0000 Subject: [PATCH] fix(postprocessing): enforce float16 --- src/agora/io/writer.py | 2 +- src/agora/utils/indexing.py | 28 +++++++++++------- src/postprocessor/core/processor.py | 46 ++++++++++++++++------------- 3 files changed, 44 insertions(+), 32 deletions(-) diff --git a/src/agora/io/writer.py b/src/agora/io/writer.py index d5dcde8f..0913308f 100644 --- a/src/agora/io/writer.py +++ b/src/agora/io/writer.py @@ -549,7 +549,7 @@ class Writer(BridgeH5): compression=kwargs.get("compression", None), ) dset = f[values_path] - dset[()] = df.values + dset[()] = df.values.astype("float16") # create dateset and write indices if not len(df): # Only write more if not empty diff --git a/src/agora/utils/indexing.py b/src/agora/utils/indexing.py index 2d5a640c..08b4ad13 100644 --- a/src/agora/utils/indexing.py +++ b/src/agora/utils/indexing.py @@ -110,22 +110,30 @@ def _assoc_indices_to_3d(ndarray: np.ndarray): This is useful when converting a signal multiindex before comparing association. """ - columns = np.arange(ndarray.shape[1]) - - return np.stack( - ( - ndarray[:, np.delete(columns, -1)], - ndarray[:, np.delete(columns, -2)], - ), - axis=1, - ) + result = ndarray + if len(ndarray) and ndarray.ndim > 1: + columns = np.arange(ndarray.shape[1]) + + result = np.stack( + ( + ndarray[:, np.delete(columns, -1)], + ndarray[:, np.delete(columns, -2)], + ), + axis=1, + ) + return result def _3d_index_to_2d(array: np.ndarray): """ Opposite to _assoc_indices_to_3d. """ - return np.concatenate((array[:, 0, :], array[:, 1, 1, np.newaxis]), axis=1) + result = array + if len(array): + result = np.concatenate( + (array[:, 0, :], array[:, 1, 1, np.newaxis]), axis=1 + ) + return result def compare_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray: diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py index 47529665..6c58011a 100644 --- a/src/postprocessor/core/processor.py +++ b/src/postprocessor/core/processor.py @@ -1,8 +1,8 @@ +import logging import typing as t from itertools import takewhile from typing import Dict, List, Union -import h5py import numpy as np import pandas as pd from tqdm import tqdm @@ -12,13 +12,16 @@ from agora.io.cells import Cells from agora.io.signal import Signal from agora.io.writer import Writer from agora.utils.indexing import ( - _assoc_indices_to_3d, _3d_index_to_2d, + _assoc_indices_to_3d, compare_indices, ) from agora.utils.kymograph import get_index_as_np from postprocessor.core.abc import get_parameters, get_process -from postprocessor.core.lineageprocess import LineageProcessParameters +from postprocessor.core.lineageprocess import ( + LineageProcess, + LineageProcessParameters, +) from postprocessor.core.reshapers.merger import Merger, MergerParameters from postprocessor.core.reshapers.picker import Picker, PickerParameters @@ -160,9 +163,7 @@ class PostProcessor(ProcessABC): "modifiers/merges", data=[np.array(x) for x in merges] ) - lineage = self.picker.cells.mothers_daughters - if lineage.any(): - lineage = _assoc_indices_to_3d(lineage) + lineage = _assoc_indices_to_3d(self.picker.cells.mothers_daughters) lineage_merged = [] indices = get_index_as_np(record) @@ -188,7 +189,7 @@ class PostProcessor(ProcessABC): lineage_merged = np.unique(flat_indices.reshape(-1, 2, 2), axis=0) self.lineage = _3d_index_to_2d( - np.array(lineage_merged if len(lineage_merged) else lineage) + lineage_merged if len(lineage_merged) else lineage ) self._writer.write( @@ -198,17 +199,14 @@ class PostProcessor(ProcessABC): picked_indices = self.picker.run( self._signal[self.targets["prepost"]["picker"][0]] ) - self._writer.write( - "modifiers/picks", - data=pd.MultiIndex.from_arrays( - # TODO Check if multiindices are still repeated - np.unique(picked_indices, axis=0).T - if indices.any() - else [[], []], - names=["trap", "cell_label"], - ), - overwrite="overwrite", - ) + if picked_indices.any(): + self._writer.write( + "modifiers/picks", + data=pd.MultiIndex.from_arrays( + picked_indices, names=["trap", "cell_label"] + ), + overwrite="overwrite", + ) @staticmethod def pick_mother(a, b): @@ -236,12 +234,13 @@ class PostProcessor(ProcessABC): else: parameters = self.parameters_classfun[process].default() + if process == "buddings": + print("stop") + loaded_process = self.classfun[process](parameters) if isinstance(parameters, LineageProcessParameters): loaded_process.lineage = self.lineage - if process == "bud_metric": - print("stop") for dataset in datasets: if isinstance(dataset, list): # multisignal process signal = [self._signal[d] for d in dataset] @@ -250,7 +249,10 @@ class PostProcessor(ProcessABC): else: raise Exception("Unavailable record") - if len(signal): + if len(signal) and ( + not isinstance(loaded_process, LineageProcess) + or len(loaded_process.lineage) + ): result = loaded_process.run(signal) else: result = pd.DataFrame( @@ -306,6 +308,8 @@ class PostProcessor(ProcessABC): result: Union[List, pd.DataFrame, np.ndarray], metadata: Dict, ): + if not result.any().any(): + logging.getLogger("aliby").warning(f"Record {path} is empty") self._writer.write(path, result, meta=metadata, overwrite="overwrite") -- GitLab