def fix_dictionaries(self, extra_g2a_dict, overwrite_dict):
    """
    Update the signal-path-to-alias conversion dictionaries.

    Merge extra_g2a_dict into self.g2a_dict (or replace self.g2a_dict
    outright when overwrite_dict is True), guarantee that the budding
    and bud-volume entries needed downstream are present, and rebuild
    the reverse alias-to-path mapping self.a2g_dict.

    Parameters
    ----------
    extra_g2a_dict: dict or None
        Extra signal-path -> alias entries supplied by the caller.
    overwrite_dict: boolean
        If True and extra_g2a_dict is given, discard the existing
        g2a_dict instead of merging into it.
    """
    if extra_g2a_dict:
        if overwrite_dict:
            self.g2a_dict = extra_g2a_dict
        else:
            # merge; entries in extra_g2a_dict win on key collisions
            self.g2a_dict = {**self.g2a_dict, **extra_g2a_dict}
    # entries required for bud correction, added only if absent
    required = {
        "postprocessing/buddings/extraction_general_None_volume": "buddings",
        "postprocessing/bud_metric/extraction_general_None_volume": "bud_volume",
    }
    for signal_path, alias in required.items():
        self.g2a_dict.setdefault(signal_path, alias)
    # reverse mapping: alias -> signal path
    self.a2g_dict = {alias: path for path, alias in self.g2a_dict.items()}
- Data is stored in the .df attribute. - Parameters ---------- dataname: string @@ -204,6 +222,11 @@ class dataloader: If True, convert times to hours (dividing by 60). bud_fluorescence: boolean If True, add mean and median bud fluorescence to the data frame. + tmax_in_mins_dict: dict (optional) + A dictionary with positions as keys and maximum times in minutes as + values. For example: { "PDR5_GFP_001": 6 * 60}. + Data will only be include up to this time point, which is a way to + avoid errors in assigning lineages because of clogging. Returns ------- @@ -223,36 +246,40 @@ class dataloader: if use_tsv: self.load_tsv(dataname) else: - # update dictionary - if extra_g2a_dict and not overwrite_dict: - self.g2a_dict = {**self.g2a_dict, **extra_g2a_dict} - elif extra_g2a_dict: - self.g2a_dict = extra_g2a_dict - for key, value in zip( - [ - "postprocessing/buddings/extraction_general_None_volume", - "postprocessing/bud_metric/extraction_general_None_volume", - ], - ["buddings", "bud_volume"], - ): - if key not in self.g2a_dict: - self.g2a_dict[key] = value - self.a2g_dict = {v: k for (k, v) in self.g2a_dict.items()} + # update conversion dictionaries + self.fix_dictionaries(extra_g2a_dict, overwrite_dict) # create instance of grouper grouper = self.get_grouper(dataname) + # update tmax_in_mins_dict + if tmax_in_mins_dict: + tmax_in_mins_dict = self.generate_full_tmax_in_mins_dict( + grouper, tmax_in_mins_dict + ) print("\n---\n" + dataname + "\n---") if bud_fluorescence: + # call postprocessor to add bud fluorescence to h5 files self.include_bud_fluorescence(grouper, dataname) print("signals available:") for signal in grouper.available: print(" ", signal) # find time interval between images self.dt = grouper.tinterval - # get key index for choosing cells - key_index_path = self.a2g_dict[key_index] - key_index = self.get_key_index(grouper, key_index_path, cutoff) - # load from h5 files - r_df = self.load_h5(grouper, key_index, cutoff, interpolate_list) + # get key index 
def generate_full_tmax_in_mins_dict(self, grouper, tmax_in_mins_dict):
    """
    Build a tmax_in_mins_dict covering every position in the experiment.

    Positions missing from the user-supplied dict default to the full
    experiment duration.  The result is ordered by decreasing tmax so
    that the first position analysed has the maximum number of time
    points, which ensures that merging data frames does not lose data.

    Parameters
    ----------
    grouper: Grouper
        Provides ntimepoints, tinterval, and positions for the experiment.
    tmax_in_mins_dict: dict
        Positions as keys and maximum times in minutes as values.

    Returns
    -------
    OrderedDict
        All positions mapped to a tmax in minutes, longest first.
    """
    experiment_tmax = int(grouper.ntimepoints * grouper.tinterval)
    # default every position to the whole experiment, then apply overrides
    complete = {position: experiment_tmax for position in grouper.positions}
    complete.update(tmax_in_mins_dict)
    # longest time series first; the stable sort keeps position order on ties
    by_decreasing_tmax = sorted(
        complete.items(), key=lambda item: item[1], reverse=True
    )
    return OrderedDict(by_decreasing_tmax)
def get_key_index_data(self, grouper, key_index, cutoff, tmax_in_mins_dict):
    """
    Find the index and the data for cells present in the key record.

    Cells must be retained for at least a cutoff fraction of the
    experiment's duration.

    Parameters
    ----------
    grouper: Grouper
        Used to concatenate the key signal across positions.
    key_index: string
        Alias of the signal used to select cells.
    cutoff: float
        Minimum fraction of the experiment a cell must be retained.
    tmax_in_mins_dict: dict or None
        Positions as keys and maximum times in minutes as values.

    Returns
    -------
    tuple
        The index of the key record and its long data frame.

    Raises
    ------
    Exception
        If the key signal cannot be found.
    """
    key_index_path = self.a2g_dict[key_index]
    record = grouper.concat_signal(
        key_index_path,
        cutoff=cutoff,
        tmax_in_mins_dict=tmax_in_mins_dict,
    )
    # guard clause: fail loudly when the key signal is missing
    if record is None:
        raise Exception(f"{key_index_path} cannot be found.")
    r_df = self.long_df_with_id(record, key_index)
    return record.index, r_df
@@ -377,11 +446,17 @@ class dataloader: bud_interpolate_indices = None # load buddings buddings = grouper.concat_signal( - self.a2g_dict["buddings"], cutoff=cutoff + self.a2g_dict["buddings"], + cutoff=0, + tmax_in_mins_dict=tmax_in_mins_dict, ) # bud_volume and any other signals; missing signals return None bud_data = [ - grouper.concat_signal(self.a2g_dict[bud_signal], cutoff=0) + grouper.concat_signal( + self.a2g_dict[bud_signal], + cutoff=0, + tmax_in_mins_dict=tmax_in_mins_dict, + ) for bud_signal in bud_signals ] # perform correction @@ -389,9 +464,9 @@ class dataloader: buddings, bud_data, figs, bud_interpolate_indices ) # keep cells only in key_index - new_buddings = self.get_key_cells(new_buddings, key_index) + new_buddings = self.get_key_cells(new_buddings, index_for_key_index) new_bud_data = [ - self.get_key_cells(new_bud_dataset, key_index) + self.get_key_cells(new_bud_dataset, index_for_key_index) if new_bud_dataset is not None else None for new_bud_dataset in new_bud_data -- GitLab