Skip to content
Snippets Groups Projects
Commit c30761fb authored by Alán Muñoz's avatar Alán Muñoz
Browse files

fix(postprocessing): enforce float16

parent 17eac08b
No related branches found
No related tags found
No related merge requests found
...@@ -549,7 +549,7 @@ class Writer(BridgeH5): ...@@ -549,7 +549,7 @@ class Writer(BridgeH5):
compression=kwargs.get("compression", None), compression=kwargs.get("compression", None),
) )
dset = f[values_path] dset = f[values_path]
dset[()] = df.values dset[()] = df.values.astype("float16")
# create dateset and write indices # create dateset and write indices
if not len(df): # Only write more if not empty if not len(df): # Only write more if not empty
......
...@@ -110,22 +110,30 @@ def _assoc_indices_to_3d(ndarray: np.ndarray): ...@@ -110,22 +110,30 @@ def _assoc_indices_to_3d(ndarray: np.ndarray):
This is useful when converting a signal multiindex before comparing association. This is useful when converting a signal multiindex before comparing association.
""" """
columns = np.arange(ndarray.shape[1]) result = ndarray
if len(ndarray) and ndarray.ndim > 1:
return np.stack( columns = np.arange(ndarray.shape[1])
(
ndarray[:, np.delete(columns, -1)], result = np.stack(
ndarray[:, np.delete(columns, -2)], (
), ndarray[:, np.delete(columns, -1)],
axis=1, ndarray[:, np.delete(columns, -2)],
) ),
axis=1,
)
return result
def _3d_index_to_2d(array: np.ndarray): def _3d_index_to_2d(array: np.ndarray):
""" """
Opposite to _assoc_indices_to_3d. Opposite to _assoc_indices_to_3d.
""" """
return np.concatenate((array[:, 0, :], array[:, 1, 1, np.newaxis]), axis=1) result = array
if len(array):
result = np.concatenate(
(array[:, 0, :], array[:, 1, 1, np.newaxis]), axis=1
)
return result
def compare_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray: def compare_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray:
......
import logging
import typing as t import typing as t
from itertools import takewhile from itertools import takewhile
from typing import Dict, List, Union from typing import Dict, List, Union
import h5py
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from tqdm import tqdm from tqdm import tqdm
...@@ -12,13 +12,16 @@ from agora.io.cells import Cells ...@@ -12,13 +12,16 @@ from agora.io.cells import Cells
from agora.io.signal import Signal from agora.io.signal import Signal
from agora.io.writer import Writer from agora.io.writer import Writer
from agora.utils.indexing import ( from agora.utils.indexing import (
_assoc_indices_to_3d,
_3d_index_to_2d, _3d_index_to_2d,
_assoc_indices_to_3d,
compare_indices, compare_indices,
) )
from agora.utils.kymograph import get_index_as_np from agora.utils.kymograph import get_index_as_np
from postprocessor.core.abc import get_parameters, get_process from postprocessor.core.abc import get_parameters, get_process
from postprocessor.core.lineageprocess import LineageProcessParameters from postprocessor.core.lineageprocess import (
LineageProcess,
LineageProcessParameters,
)
from postprocessor.core.reshapers.merger import Merger, MergerParameters from postprocessor.core.reshapers.merger import Merger, MergerParameters
from postprocessor.core.reshapers.picker import Picker, PickerParameters from postprocessor.core.reshapers.picker import Picker, PickerParameters
...@@ -160,9 +163,7 @@ class PostProcessor(ProcessABC): ...@@ -160,9 +163,7 @@ class PostProcessor(ProcessABC):
"modifiers/merges", data=[np.array(x) for x in merges] "modifiers/merges", data=[np.array(x) for x in merges]
) )
lineage = self.picker.cells.mothers_daughters lineage = _assoc_indices_to_3d(self.picker.cells.mothers_daughters)
if lineage.any():
lineage = _assoc_indices_to_3d(lineage)
lineage_merged = [] lineage_merged = []
indices = get_index_as_np(record) indices = get_index_as_np(record)
...@@ -188,7 +189,7 @@ class PostProcessor(ProcessABC): ...@@ -188,7 +189,7 @@ class PostProcessor(ProcessABC):
lineage_merged = np.unique(flat_indices.reshape(-1, 2, 2), axis=0) lineage_merged = np.unique(flat_indices.reshape(-1, 2, 2), axis=0)
self.lineage = _3d_index_to_2d( self.lineage = _3d_index_to_2d(
np.array(lineage_merged if len(lineage_merged) else lineage) lineage_merged if len(lineage_merged) else lineage
) )
self._writer.write( self._writer.write(
...@@ -198,17 +199,14 @@ class PostProcessor(ProcessABC): ...@@ -198,17 +199,14 @@ class PostProcessor(ProcessABC):
picked_indices = self.picker.run( picked_indices = self.picker.run(
self._signal[self.targets["prepost"]["picker"][0]] self._signal[self.targets["prepost"]["picker"][0]]
) )
self._writer.write( if picked_indices.any():
"modifiers/picks", self._writer.write(
data=pd.MultiIndex.from_arrays( "modifiers/picks",
# TODO Check if multiindices are still repeated data=pd.MultiIndex.from_arrays(
np.unique(picked_indices, axis=0).T picked_indices, names=["trap", "cell_label"]
if indices.any() ),
else [[], []], overwrite="overwrite",
names=["trap", "cell_label"], )
),
overwrite="overwrite",
)
@staticmethod @staticmethod
def pick_mother(a, b): def pick_mother(a, b):
...@@ -236,12 +234,13 @@ class PostProcessor(ProcessABC): ...@@ -236,12 +234,13 @@ class PostProcessor(ProcessABC):
else: else:
parameters = self.parameters_classfun[process].default() parameters = self.parameters_classfun[process].default()
if process == "buddings":
print("stop")
loaded_process = self.classfun[process](parameters) loaded_process = self.classfun[process](parameters)
if isinstance(parameters, LineageProcessParameters): if isinstance(parameters, LineageProcessParameters):
loaded_process.lineage = self.lineage loaded_process.lineage = self.lineage
if process == "bud_metric":
print("stop")
for dataset in datasets: for dataset in datasets:
if isinstance(dataset, list): # multisignal process if isinstance(dataset, list): # multisignal process
signal = [self._signal[d] for d in dataset] signal = [self._signal[d] for d in dataset]
...@@ -250,7 +249,10 @@ class PostProcessor(ProcessABC): ...@@ -250,7 +249,10 @@ class PostProcessor(ProcessABC):
else: else:
raise Exception("Unavailable record") raise Exception("Unavailable record")
if len(signal): if len(signal) and (
not isinstance(loaded_process, LineageProcess)
or len(loaded_process.lineage)
):
result = loaded_process.run(signal) result = loaded_process.run(signal)
else: else:
result = pd.DataFrame( result = pd.DataFrame(
...@@ -306,6 +308,8 @@ class PostProcessor(ProcessABC): ...@@ -306,6 +308,8 @@ class PostProcessor(ProcessABC):
result: Union[List, pd.DataFrame, np.ndarray], result: Union[List, pd.DataFrame, np.ndarray],
metadata: Dict, metadata: Dict,
): ):
if not result.any().any():
logging.getLogger("aliby").warning(f"Record {path} is empty")
self._writer.write(path, result, meta=metadata, overwrite="overwrite") self._writer.write(path, result, meta=metadata, overwrite="overwrite")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment