Skip to content
Snippets Groups Projects
Commit 8ef3726c authored by Alán Muñoz's avatar Alán Muñoz
Browse files

fix(pp): refresh bud_metric and merging

parent 75951753
No related branches found
No related tags found
No related merge requests found
...@@ -31,8 +31,11 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray): ...@@ -31,8 +31,11 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray):
""" """
indices = data.index
if "mother_label" in indices.names:
indices = indices.droplevel("mother_label")
valid_merges, indices = validate_association( valid_merges, indices = validate_association(
merges, np.array(list(data.index)) merges, np.array(list(indices))
) )
# Assign non-merged # Assign non-merged
...@@ -129,14 +132,15 @@ def merge_association( ...@@ -129,14 +132,15 @@ def merge_association(
valid_indices = comparison_mat.any(axis=0) valid_indices = comparison_mat.any(axis=0)
replacement_d = {} if valid_indices.any(): # Where valid, perform transformation
for dataset in grouped_merges: replacement_d = {}
for k in dataset: for dataset in grouped_merges:
replacement_d[tuple(k[0])] = dataset[-1][1] for k in dataset:
replacement_d[tuple(k[0])] = dataset[-1][1]
flat_indices[valid_indices] = [ flat_indices[valid_indices] = [
replacement_d[tuple(i)] for i in flat_indices[valid_indices] replacement_d[tuple(i)] for i in flat_indices[valid_indices]
] ]
merged_indices = flat_indices.reshape(-1, 2, 2) merged_indices = flat_indices.reshape(-1, 2, 2)
return merged_indices return merged_indices
...@@ -204,7 +204,9 @@ class PostProcessor(ProcessABC): ...@@ -204,7 +204,9 @@ class PostProcessor(ProcessABC):
self._writer.write( self._writer.write(
"modifiers/picks", "modifiers/picks",
data=pd.MultiIndex.from_arrays( data=pd.MultiIndex.from_arrays(
picked_indices, names=["trap", "cell_label"] picked_indices.T,
# names=["trap", "cell_label", "mother_label"],
names=["trap", "cell_label"],
), ),
overwrite="overwrite", overwrite="overwrite",
) )
......
...@@ -65,23 +65,36 @@ class BudMetric(LineageProcess): ...@@ -65,23 +65,36 @@ class BudMetric(LineageProcess):
mothers_mat = np.zeros((len(md), signal.shape[1])) mothers_mat = np.zeros((len(md), signal.shape[1]))
cells_were_dropped = 0 # Flag determines if mothers (1), daughters (2) or both were missing (3) cells_were_dropped = 0 # Flag determines if mothers (1), daughters (2) or both were missing (3)
md_index = signal.index.droplevel("mother_label") md_index = signal.index
if md is not None: # Get intersection of lineage and current index if (
"mother_label" not in md_index.names
): # Generate mother label from md dict if unavailable
d = {v: k for k, values in md.items() for v in values}
signal["mother_label"] = list(
map(lambda x: d.get(x, [0])[-1], signal.index)
)
signal.set_index("mother_label", append=True, inplace=True)
related_items = set( related_items = set(
[*md.keys(), *[y for x in md.values() for y in x]] [*md.keys(), *[y for x in md.values() for y in x]]
) )
intersection = md_index.intersection(related_items) md_index = md_index.intersection(related_items)
if len(intersection) < len(signal): elif "mother_label" in md_index.names:
print("Dropped cells before bud_metric") # TODO log md_index = md_index.droplevel("mother_label")
else:
signal = ( raise ("Unavailable relationship information")
signal.reset_index("mother_label")
.loc(axis=0)[intersection] if len(md_index) < len(signal):
.set_index("mother_label", append=True) print("Dropped cells before bud_metric") # TODO log
)
signal = (
signal.reset_index("mother_label")
.loc(axis=0)[md_index]
.set_index("mother_label", append=True)
)
names = list(signal.index.names) names = list(signal.index.names)
del names[-2] del names[-2]
output_df = ( output_df = (
signal.loc[signal.index.get_level_values("mother_label") > 0] signal.loc[signal.index.get_level_values("mother_label") > 0]
.groupby(names) .groupby(names)
...@@ -101,7 +114,7 @@ def _combine_daughter_tracks(tracks: t.Collection[pd.Series]): ...@@ -101,7 +114,7 @@ def _combine_daughter_tracks(tracks: t.Collection[pd.Series]):
prioritising the most recent entity. prioritising the most recent entity.
""" """
sorted_da_ids = tracks.sort_index(level="cell_label") sorted_da_ids = tracks.sort_index(level="cell_label")
tp_fvt = sorted_da_ids.apply(lambda x: x.last_valid_index(), axis=0) tp_fvt = sorted_da_ids.apply(lambda x: x.first_valid_index(), axis=1)
tp_fvt = sorted_da_ids.index.get_indexer(tp_fvt) tp_fvt = sorted_da_ids.index.get_indexer(tp_fvt)
tp_fvt[tp_fvt < 0] = sorted_da_ids.shape[0] - 1 tp_fvt[tp_fvt < 0] = sorted_da_ids.shape[0] - 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment