diff --git a/correct_buds.py b/correct_buds.py index 16aa6b6772d51637945b022fb2d7a04c84e53a5f..7f05b43a18dc090a7d9005c213be190a99277c12 100644 --- a/correct_buds.py +++ b/correct_buds.py @@ -110,18 +110,37 @@ def plot_buds(ind, buddings, bud_volume, new_bv=False): plt.show(block=False) -def correct_buds(buddings, bud_volume, figs=True): +def correct_buds(buddings, bud_volume, figs=True, interpolate=True): + """ + Correct buddings and bud_volume. + + Drop adjacent buddings and interpolate over NaNs. + + Parameters + ---------- + buddings: pd.DataFrame + Data frame of time series of budding events. + bud_volume: pd.DataFrame + Data frame of time series of bud volumes. + figs: boolean + If True, plot volumes of all buds marking budding events with + a red circle. + interpolate: boolean + If True, use linear interpolation to remove NaNs within the time + series for each bud. + """ # remove mother label bud_volume = drop_label(bud_volume) # check bud volumes and buddings new_buddings, new_bud_volume = skip_budding(buddings, bud_volume) - for ind in new_buddings.index: - success, new_bv = interp_bud_volume( - ind, - new_buddings, - new_bud_volume, - figs, - ) - if success: - new_bud_volume.loc[ind] = new_bv + if interpolate: + for ind in new_buddings.index: + success, new_bv = interp_bud_volume( + ind, + new_buddings, + new_bud_volume, + figs, + ) + if success: + new_bud_volume.loc[ind] = new_bv return new_buddings, new_bud_volume diff --git a/dataloader.py b/dataloader.py index caff02186f51588eed6044cb07c4665adb2b31d6..029fb6bcba5acbbabfc717a43e1be2efcfebebb1 100644 --- a/dataloader.py +++ b/dataloader.py @@ -154,17 +154,18 @@ class dataloader: dataname, key_index="buddings", cutoff=0.8, + interpolate_list=None, extra_g2a_dict=None, pxsize=0.182, use_tsv=False, - over_write=False, + over_write_dict=False, ): """ Load either an experiment or a data set into a long data frame. The 'time' variable becomes a column. - New data is added to the existing dataframe. 
+ New data is added to the existing data frame. Data is stored in the .df attribute. @@ -172,12 +173,13 @@ class dataloader: ---------- dataname: string Either the name of the directory for an experiment or a - file name for a dataset. + file name for a data set. key_index: string Short name for key record that will be used to select cells. - cut_off: float + cutoff: float Select cells for key record that remain for at least cutoff fraction of the experiment's total duration. + interpolate_list: list of strings (optional) Abbreviated names of signals to interpolate, removing internal NaNs; include "buddings" to interpolate bud volumes. g2a_dict: dictionary (optional) A dictionary of extra parameters to extract, which relates the aliby name to an abbreviation @@ -185,7 +187,7 @@ class dataloader: Pixel size in microns, which is used to convert volumes. use_tsv: boolean If True, always load the data from a tsv file. - over_write: boolean + over_write_dict: boolean If True, overwrite the internal dictionary with extra_g2a_dict. Returns @@ -207,12 +209,16 @@ class dataloader: self.load_tsv(dataname) else: # update dictionary - if extra_g2a_dict and not over_write: + if extra_g2a_dict and not over_write_dict: self.g2a_dict = {**self.g2a_dict, **extra_g2a_dict} elif extra_g2a_dict: self.g2a_dict = extra_g2a_dict # create instance of grouper grouper = self.get_grouper(dataname) + print("\n---" + dataname + "\n---") + print("signals available:") + for signal in grouper.available: + print(" ", signal) # find time interval between images self.dt = grouper.tintervals # get key index for choosing cells @@ -221,70 +227,96 @@ class dataloader: else: key_index_path = self.a2g_dict[key_index] key_index = self.get_key_index(grouper, key_index_path, cutoff) - # load data from h5 files - print(dataname + "\n---") - print("signals available:") - signals = grouper.available - for signal in signals: - print(" ", signal) - print("\nloading...") - # load and correct buddings and bud_volume - r_df = self.load_buddings_bud_volume( - grouper, cutoff, figs=False, key_index=key_index - ) - # load other signals - for i, char in
enumerate(self.g2a_dict): - if char in signals: - print(" " + char) - # load all cells - record = grouper.concat_signal(char, cutoff=0) - # keep cells only in key_index - new_record = self.get_key_cells(record, key_index) - # convert to long data frame - tdf = self.long_df_with_id(new_record, self.g2a_dict[char]) - # merge into one data set - r_df = pd.merge(r_df, tdf, how="left") - else: - print(" Warning: " + char, "not found") - if "r_df" in locals(): - if pxsize: - # volumes to micron^3 - for signal in r_df.columns: - if "volume" in signal: - r_df[signal] *= pxsize**3 - if hasattr(self, "df"): - # merge new data to current dataframe - self.df = pd.merge(self.df, r_df, how="left") - else: - # create new dataframe - self.df = r_df + # load from h5 files + r_df = self.load_h5(grouper, key_index, cutoff, interpolate_list) + if pxsize: + # convert volumes to micron^3 + for signal in r_df.columns: + if "volume" in signal: + r_df[signal] *= pxsize**3 + # create new attribute or merge with existing one + if hasattr(self, "df"): + self.df = pd.merge(self.df, r_df, how="left") else: - raise NameError("Dataloader: No data loaded.") + self.df = r_df # define ids self.ids = list(self.df.id.unique()) if not use_tsv: - # return grouper return grouper + def load_h5( + self, + grouper, + key_index, + cutoff, + interpolate_list, + ): + """Load data from h5 files into one long data frame.""" + print("\nloading...") + if interpolate_list and "buddings" in interpolate_list: + bud_interpolate = True + else: + bud_interpolate = False + # load and correct buddings and bud_volume + r_df = self.load_buddings_bud_volume( + grouper, + cutoff, + figs=False, + key_index=key_index, + interpolate=bud_interpolate, + ) + # load other signals + for i, char in enumerate(self.g2a_dict): + if char in grouper.available: + print(" " + char) + # load all cells + record = grouper.concat_signal(char, cutoff=0) + # keep cells only in key_index + new_record = self.get_key_cells(record, key_index) + # 
interpolate to remove internal NaNs + if ( + interpolate_list + and self.g2a_dict[char] in interpolate_list + ): + new_record = self.interp_signal(new_record) + # convert to long data frame + tdf = self.long_df_with_id(new_record, self.g2a_dict[char]) + # merge into one data set + r_df = pd.merge(r_df, tdf, how="left") + else: + print(" Warning: " + char, "not found") + return r_df + def get_key_index(self, grouper, key_index_path, cutoff): - """Find index of cells that have sufficient measurements of the - key record.""" + """ + Find index of cells that appear in the key record. + + Cells must be retained at least a cutoff fraction of the + experiment's duration. + """ record = grouper.concat_signal(key_index_path, cutoff=cutoff) return record.index def get_key_cells(self, df, key_index): - """Find a smaller data frame with only cells present in the key record.""" + """Find a smaller data frame with only cells from the key record.""" sdf = df.loc[df.index.intersection(key_index)] return sdf - def load_buddings_bud_volume(self, grouper, cutoff, figs, key_index): - """Load buddings and bud_volume, interpolate over NaNs, and - drop adjacent buddings.""" + def load_buddings_bud_volume( + self, grouper, cutoff, figs, key_index, interpolate + ): + """ + Load buddings and bud_volume. + + Drop adjacent buddings and interpolate over NaNs. 
+ """ # load buddings and all bud_volumes buddings = grouper.concat_signal(self.buddings_path, cutoff=cutoff) bud_volume = grouper.concat_signal(self.bud_volume_path, cutoff=0) # perform correction - new_buddings, new_bud_volume = correct_buds(buddings, bud_volume, figs) + new_buddings, new_bud_volume = correct_buds( + buddings, bud_volume, figs, interpolate + ) # keep cells only in key_index new_buddings = self.get_key_cells(new_buddings, key_index) new_bud_volume = self.get_key_cells(new_bud_volume, key_index) @@ -318,7 +350,7 @@ class dataloader: def save(self, dataname=None): """ - Save the .df dataframe to a tab separated value (tsv) file. + Save the .df data frame to a tab separated value (tsv) file. Parameters ---------- @@ -333,7 +365,7 @@ class dataloader: print("Saved", dataname) def long_df_with_id(self, df, char_name): - """Convert an aliby multi-index dataframe into a long dataframe.""" + """Convert an aliby multi-index data frame into a long data frame.""" df = self.long_df(df, char_name) # add unique id for each cell if ( @@ -360,7 +392,7 @@ class dataloader: Convert the 'id' column into the standard columns used by aliby. These columns are 'position', 'trap', and 'cell_label' - and vice - versa, either adding three columns to the .df dataframe or removing + versa, either adding three columns to the .df data frame or removing three columns. """ if ( @@ -384,7 +416,7 @@ class dataloader: def wide_df(self, signal, x="time", y="id"): """ - Pivot the .df dataframe to return a standard aliby dataframe. + Pivot the .df data frame to return a standard aliby data frame. Parameters ---------- @@ -403,8 +435,7 @@ class dataloader: def get_time_series(self, signal, group=None): """ - Extract all the data for a particular signal as - a 2D array with each row a time series. + Extract a signal as a 2D array with each row a time series. 
Parameters ---------- @@ -435,15 +466,14 @@ class dataloader: def put_time_series(self, values, signal): """ - Insert a 2D array of data with each column a time series - into the dataframe. + Insert a 2D array of data with each column a time series. Parameters ---------- values: array - The data to be inserted + The data to be inserted into the data frame. signal: string - The name of the signal + The name of the signal. """ # find a suitable column in r to generate a pivot cols = list(self.df.columns) @@ -527,15 +557,14 @@ class dataloader: def get_lineage_data(self, idx, signals): """ - Return signals for a single lineage specified by idx, which is - specified by an element of self.ids. + Return signals for a single lineage specified by idx. - Arguments - --------- + Parameters + ---------- idx: integer - Element of self.ids + Element of self.ids. signals: list of strings - Signals to be returned + Signals to be returned. """ if isinstance(signals, str): signals = [signals] @@ -557,3 +586,10 @@ class dataloader: # melt to create time column df = df.melt(id_vars=hnames, var_name=var_name, value_name=char_name) return df + + @staticmethod + def interp_signal(df): + """Use interpolation to remove internal NaNs in data frame.""" + # wide aliby data frame + df = df.interpolate("linear", axis=1, limit_area="inside") + return df