Skip to content
Snippets Groups Projects
Commit d89c2cd5 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

fix(post/align): add speed-up changes

parent f68dc2ae
No related branches found
No related tags found
No related merge requests found
...@@ -24,14 +24,35 @@ def df_extend_nan(df, width): ...@@ -24,14 +24,35 @@ def df_extend_nan(df, width):
return out_df return out_df
def df_shift(df, shift_list):
    """Shift every row of a DataFrame sideways by its own interval.

    Row ``i`` is moved ``shift_list[i]`` columns to the LEFT; cells that
    are vacated are filled with NaN.  A negative interval moves the row to
    the right instead.  Rows whose interval is at least as large as the
    number of columns come out entirely NaN.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to shift.  It is not modified.
    shift_list : sequence of int
        One shift interval per row of ``df``, in row order.

    Returns
    -------
    pandas.DataFrame
        New DataFrame with the same index and columns as ``df``.  Integer
        and boolean data are returned as float whenever a non-zero shift is
        requested, because NaN cannot be stored in those dtypes.

    Raises
    ------
    ValueError
        If ``shift_list`` does not hold exactly one interval per row.
    """
    shifts = np.asarray(shift_list)
    if shifts.ndim != 1 or shifts.shape[0] != df.shape[0]:
        raise ValueError(
            "shift_list must contain exactly one shift interval per row of df"
        )

    # Work on the underlying numpy array: far faster than row-wise .loc/.shift
    array = df.to_numpy()

    # Nothing to move (this also covers a DataFrame without rows)
    if not np.any(shifts):
        return pd.DataFrame(array.copy(), index=df.index, columns=df.columns)

    # NaN cannot be written into integer/boolean arrays
    if array.dtype.kind in "biu":
        array = array.astype(float)

    n_cols = array.shape[1]
    # Start from an all-NaN canvas and copy in only the cells that survive
    shifted = np.full(array.shape, np.nan, dtype=array.dtype)

    # Handle all rows that share a shift interval in one vectorised copy
    for shift_value in np.unique(shifts):
        rows = shifts == shift_value
        shift = int(shift_value)
        if shift == 0:
            shifted[rows] = array[rows]
        elif 0 < shift < n_cols:
            # shift to the left: drop the first 'shift' columns
            shifted[rows, : n_cols - shift] = array[rows, shift:]
        elif -n_cols < shift < 0:
            # shift to the right: drop the last '-shift' columns
            shifted[rows, -shift:] = array[rows, : n_cols + shift]
        # otherwise abs(shift) >= n_cols and the rows stay entirely NaN

    return pd.DataFrame(shifted, index=df.index, columns=df.columns)
class alignParameters(ParametersABC): class alignParameters(ParametersABC):
...@@ -94,13 +115,17 @@ class align(PostProcessABC): ...@@ -94,13 +115,17 @@ class align(PostProcessABC):
like a mask. For example, this DataFrame can indicate when birth like a mask. For example, this DataFrame can indicate when birth
events are identified for each cell in a dataset. events are identified for each cell in a dataset.
""" """
# Converts mask_df to float if it hasn't been already
# This is so that df_shift() can add np.nans
mask_df += 0.0
# Remove cells that have less than or equal to events_at_least events, # Remove cells that have less than or equal to events_at_least events,
# i.e. if events_at_least = 1, then cells that have no birth events are # i.e. if events_at_least = 1, then cells that have no birth events are
# deleted. # deleted.
event_mask = mask_df.apply( event_mask = (
lambda x: bn.nansum(x) >= self.events_at_least, axis=1 bn.nansum(mask_df.to_numpy(), axis=1) >= self.events_at_least
) )
mask_df = mask_df.iloc[event_mask.to_list()] mask_df = mask_df.iloc[event_mask.tolist()]
# Match trace and event signals by index, e.g. cellID # Match trace and event signals by index, e.g. cellID
# and discard the cells they don't have in common # and discard the cells they don't have in common
...@@ -124,12 +149,8 @@ class align(PostProcessABC): ...@@ -124,12 +149,8 @@ class align(PostProcessABC):
# Remove bits of traces before first event # Remove bits of traces before first event
if self.slice_before_first_event: if self.slice_before_first_event:
# minus sign in front of shift_list to shift to the left # minus sign in front of shift_list to shift to the left
mask_aligned = df_shift( mask_aligned = df_shift(mask_aligned, shift_list)
mask_aligned, common_index.to_list(), -shift_list trace_aligned = df_shift(trace_aligned, shift_list)
)
trace_aligned = df_shift(
trace_aligned, common_index.to_list(), -shift_list
)
# Do not remove bits of traces before first event # Do not remove bits of traces before first event
else: else:
# Add columns to left, filled with NaNs # Add columns to left, filled with NaNs
...@@ -137,11 +158,7 @@ class align(PostProcessABC): ...@@ -137,11 +158,7 @@ class align(PostProcessABC):
mask_aligned = df_extend_nan(mask_aligned, max_shift) mask_aligned = df_extend_nan(mask_aligned, max_shift)
trace_aligned = df_extend_nan(trace_aligned, max_shift) trace_aligned = df_extend_nan(trace_aligned, max_shift)
# shift each # shift each
mask_aligned = df_shift( mask_aligned = df_shift(mask_aligned, shift_list)
mask_aligned, common_index.to_list(), -shift_list trace_aligned = df_shift(trace_aligned, shift_list)
)
trace_aligned = df_shift(
trace_aligned, common_index.to_list(), -shift_list
)
return trace_aligned, mask_aligned return trace_aligned, mask_aligned
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment