diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py index 06bae6e166156f644523490feb37ebb44f1cca50..20c7b950d3a3e4dbf1cb21dec29947467b938bd7 100644 --- a/src/agora/io/signal.py +++ b/src/agora/io/signal.py @@ -186,31 +186,19 @@ class Signal(BridgeH5): merged = copy(data) if isinstance(picks, bool): picks = ( - self.get_picks(names=merged.index.names) + self.get_picks( + names=merged.index.names, path="modifiers/picks/" + ) if picks - else set(merged.index) + else merged.index ) - # TODO : the following needs clarifying - with h5py.File(self.filename, "r") as f: - if "modifiers/picks" in f and picks: - if picks: - return merged.loc[ - set(picks).intersection( - [tuple(x) for x in merged.index] - ) - ] - else: - if isinstance(merged.index, pd.MultiIndex): - empty_lvls = [[] for i in merged.index.names] - index = pd.MultiIndex( - levels=empty_lvls, - codes=empty_lvls, - names=merged.index.names, - ) - else: - index = pd.Index([], name=merged.index.name) - merged = pd.DataFrame([], index=index) - return merged + if picks: + picked_indices = set(picks).intersection( + [tuple(x) for x in merged.index] + ) + return merged.loc[picked_indices] + else: + return merged @cached_property def p_available(self): @@ -272,8 +260,9 @@ class Signal(BridgeH5): dataset: str or list of strs The name of the h5 file or a list of h5 file names. in_minutes: boolean - If True, + If True, convert column headings to times in minutes. lineage: boolean + If True, add mother_label to index. """ try: if isinstance(dataset, str): @@ -316,13 +305,14 @@ class Signal(BridgeH5): names: t.Tuple[str, ...] = ("trap", "cell_label"), path: str = "modifiers/picks/", ) -> t.Set[t.Tuple[int, str]]: - """Get the relevant picks based on names.""" + """Get picks from the h5 file.""" with h5py.File(self.filename, "r") as f: - picks = set() if path in f: picks = set( zip(*[f[path + name] for name in names if name in f[path]]) ) + else: + picks = set() return picks def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame: diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py index 77bcc356273820bdbfc97718b4a5b1c1826d9564..b58302c3e69d64b3b84c0fd0a389425eb5c5dbf4 100644 --- a/src/aliby/pipeline.py +++ b/src/aliby/pipeline.py @@ -27,6 +27,10 @@ from extraction.core.extractor import Extractor, ExtractorParameters from extraction.core.functions.defaults import exparams_from_meta from postprocessor.core.processor import PostProcessor, PostProcessorParameters +# stop warnings from TensorFlow +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" +logging.getLogger("tensorflow").setLevel(logging.ERROR) + class PipelineParameters(ParametersABC): """Define parameters for the steps of the pipeline.""" diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py index 3af88deba41aeaccca184a4a957005ac9a04beaa..277c7018b091e6f993b45d2f0e41b9ececa9229f 100644 --- a/src/postprocessor/core/processor.py +++ b/src/postprocessor/core/processor.py @@ -177,9 +177,7 @@ class PostProcessor(ProcessABC): lineage_merged = [] if merges.any(): # update lineages after merge events - merged_indices = merge_lineage(lineage, merges) - # remove repeated labels post-merging - lineage_merged = np.unique(merged_indices, axis=0) + lineage_merged = merge_lineage(lineage, merges) self.lineage = _3d_index_to_2d( lineage_merged if len(lineage_merged) else lineage ) @@ -194,9 +192,7 @@ class PostProcessor(ProcessABC): self._writer.write( "modifiers/picks", data=pd.MultiIndex.from_arrays( - picked_indices.T, - # names=["trap", "cell_label", "mother_label"], - names=["trap", "cell_label"], + picked_indices.T, names=["trap", "cell_label"] ), overwrite="overwrite", )