diff --git a/dataloader.py b/dataloader.py index fcc02efe9677c890372905bcdb31aa5a4024f5b2..5438f4434bda6106a96a1dc518554405fc8a938d 100644 --- a/dataloader.py +++ b/dataloader.py @@ -267,6 +267,8 @@ class dataloader: self.update_dictionaries(extra_g2a_dict, overwrite_dict) # create instance of grouper grouper = self.get_grouper(dataname) + # find time interval between images + self.dt = grouper.tinterval print("\n---\n" + dataname + "\n---") if bud_fluorescence: # call postprocessor to add bud fluorescence to h5 files @@ -275,12 +277,13 @@ class dataloader: for signal in grouper.available: print(" ", signal) print() - # find time interval between images - self.dt = grouper.tinterval + # get indices for all buds + bud_indices = self.get_bud_indices(grouper, key_index) # get key_index data r_df = self.get_key_index_data( grouper=grouper, key_index=key_index, + bud_indices=bud_indices, cutoff=cutoff, tmax_in_mins_dict=tmax_in_mins_dict, ) @@ -288,6 +291,7 @@ class dataloader: r_df = self.load_h5( grouper=grouper, key_index=key_index, + bud_indices=bud_indices, r_df=r_df, interpolate_list=interpolate_list, tmax_in_mins_dict=tmax_in_mins_dict, @@ -314,6 +318,7 @@ class dataloader: self, grouper, key_index, + bud_indices, r_df, interpolate_list, tmax_in_mins_dict, @@ -332,6 +337,7 @@ class dataloader: cutoff=0, tmax_in_mins_dict=tmax_in_mins_dict, ) + record = self.remove_buds(record, bud_indices) # interpolate to remove internal NaNs for signals from mothers if ( interpolate_list @@ -355,7 +361,7 @@ class dataloader: return r_df def get_key_index_data( - self, grouper, key_index, cutoff, tmax_in_mins_dict + self, grouper, key_index, bud_indices, cutoff, tmax_in_mins_dict ): """ Find data for the key record. @@ -367,12 +373,25 @@ class dataloader: record = grouper.concat_signal( key_index_path, cutoff=cutoff, tmax_in_mins_dict=tmax_in_mins_dict ) + record = self.remove_buds(record, bud_indices) if record is not None: r_df = self.long_df_with_id(record, key_index) return r_df else: raise Exception(f"{key_index_path} cannot be found.") + def get_bud_indices(self, grouper, key_index): + """Use key_index to get a multi-index for all buds.""" + key_index_path = self.a2g_dict[key_index] + record = grouper.concat_signal(key_index_path, mode="raw_daughters") + bud_indices = record.droplevel("mother_label").index + return bud_indices + + def remove_buds(self, df, bud_indices): + """Remove rows that are buds from a data frame.""" + new_index = df.index.difference(bud_indices) + return df.loc[new_index] + def load_bud_data( self, grouper, @@ -383,6 +402,8 @@ class dataloader: """ Load buddings, bud volume, and any other bud signals. + Bud signals are indexed by their mother cells. + Drop adjacent buddings and interpolate specified bus signals to remove NaNs. """