Skip to content
Snippets Groups Projects
Commit 12d0f2e5 authored by pswain's avatar pswain
Browse files

feature(dataloader): added tmax_in_mins_dict

parent e724a177
No related branches found
No related tags found
No related merge requests found
import pprint import pprint
from collections import OrderedDict
from operator import itemgetter
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
...@@ -6,9 +8,9 @@ import pandas as pd ...@@ -6,9 +8,9 @@ import pandas as pd
from wela.correct_buds import correct_buds from wela.correct_buds import correct_buds
try: try:
from postprocessor.grouper import Grouper from postprocessor.grouper import Grouper
from wela.add_bud_fluorescence import add_bud_fluorescence from wela.add_bud_fluorescence import add_bud_fluorescence
except ModuleNotFoundError: except ModuleNotFoundError:
print("Can only load tsv files - cannot find postprocessor.") print("Can only load tsv files - cannot find postprocessor.")
...@@ -154,6 +156,23 @@ class dataloader: ...@@ -154,6 +156,23 @@ class dataloader:
grouper = Grouper(self.h5dirpath / dataname) grouper = Grouper(self.h5dirpath / dataname)
return grouper return grouper
def fix_dictionaries(self, extra_g2a_dict, overwrite_dict):
"""Update conversion dictionaries."""
if extra_g2a_dict and not overwrite_dict:
self.g2a_dict = {**self.g2a_dict, **extra_g2a_dict}
elif extra_g2a_dict:
self.g2a_dict = extra_g2a_dict
for key, value in zip(
[
"postprocessing/buddings/extraction_general_None_volume",
"postprocessing/bud_metric/extraction_general_None_volume",
],
["buddings", "bud_volume"],
):
if key not in self.g2a_dict:
self.g2a_dict[key] = value
self.a2g_dict = {v: k for (k, v) in self.g2a_dict.items()}
def load( def load(
self, self,
dataname, dataname,
...@@ -166,17 +185,16 @@ class dataloader: ...@@ -166,17 +185,16 @@ class dataloader:
overwrite_dict=False, overwrite_dict=False,
hours=True, hours=True,
bud_fluorescence=False, bud_fluorescence=False,
tmax_in_mins_dict=None,
): ):
""" """
Load either an experiment from h5 files or a tsv data set Load either an experiment from h5 files or a tsv data set.
into a long data frame.
Data are stored in a long data frame with the .df attribute.
The 'time' variable becomes a column. The 'time' variable becomes a column.
New data is added to the existing data frame. New data is added to the existing data frame.
Data is stored in the .df attribute.
Parameters Parameters
---------- ----------
dataname: string dataname: string
...@@ -204,6 +222,11 @@ class dataloader: ...@@ -204,6 +222,11 @@ class dataloader:
If True, convert times to hours (dividing by 60). If True, convert times to hours (dividing by 60).
bud_fluorescence: boolean bud_fluorescence: boolean
If True, add mean and median bud fluorescence to the data frame. If True, add mean and median bud fluorescence to the data frame.
tmax_in_mins_dict: dict (optional)
A dictionary with positions as keys and maximum times in minutes as
values. For example: { "PDR5_GFP_001": 6 * 60}.
Data will only be included up to this time point, which is a way to
avoid errors in assigning lineages because of clogging.
Returns Returns
------- -------
...@@ -223,36 +246,40 @@ class dataloader: ...@@ -223,36 +246,40 @@ class dataloader:
if use_tsv: if use_tsv:
self.load_tsv(dataname) self.load_tsv(dataname)
else: else:
# update dictionary # update conversion dictionaries
if extra_g2a_dict and not overwrite_dict: self.fix_dictionaries(extra_g2a_dict, overwrite_dict)
self.g2a_dict = {**self.g2a_dict, **extra_g2a_dict}
elif extra_g2a_dict:
self.g2a_dict = extra_g2a_dict
for key, value in zip(
[
"postprocessing/buddings/extraction_general_None_volume",
"postprocessing/bud_metric/extraction_general_None_volume",
],
["buddings", "bud_volume"],
):
if key not in self.g2a_dict:
self.g2a_dict[key] = value
self.a2g_dict = {v: k for (k, v) in self.g2a_dict.items()}
# create instance of grouper # create instance of grouper
grouper = self.get_grouper(dataname) grouper = self.get_grouper(dataname)
# update tmax_in_mins_dict
if tmax_in_mins_dict:
tmax_in_mins_dict = self.generate_full_tmax_in_mins_dict(
grouper, tmax_in_mins_dict
)
print("\n---\n" + dataname + "\n---") print("\n---\n" + dataname + "\n---")
if bud_fluorescence: if bud_fluorescence:
# call postprocessor to add bud fluorescence to h5 files
self.include_bud_fluorescence(grouper, dataname) self.include_bud_fluorescence(grouper, dataname)
print("signals available:") print("signals available:")
for signal in grouper.available: for signal in grouper.available:
print(" ", signal) print(" ", signal)
# find time interval between images # find time interval between images
self.dt = grouper.tinterval self.dt = grouper.tinterval
# get key index for choosing cells # get key index for choosing cells and key-index data
key_index_path = self.a2g_dict[key_index] index_for_key_index, r_df = self.get_key_index_data(
key_index = self.get_key_index(grouper, key_index_path, cutoff) grouper,
# load from h5 files key_index,
r_df = self.load_h5(grouper, key_index, cutoff, interpolate_list) cutoff,
tmax_in_mins_dict,
)
# load data from h5 files
tdf = self.load_h5(
grouper,
key_index,
index_for_key_index,
interpolate_list,
tmax_in_mins_dict,
)
r_df = pd.merge(r_df, tdf, how="left")
if pxsize: if pxsize:
# convert volumes to micron^3 # convert volumes to micron^3
for signal in r_df.columns: for signal in r_df.columns:
...@@ -290,12 +317,32 @@ class dataloader: ...@@ -290,12 +317,32 @@ class dataloader:
# add bud fluorescence to h5 files # add bud fluorescence to h5 files
add_bud_fluorescence(self.h5dirpath / dataname, signals) add_bud_fluorescence(self.h5dirpath / dataname, signals)
def generate_full_tmax_in_mins_dict(self, grouper, tmax_in_mins_dict):
"""
Generate a tmax_in_mins_dict for all positions.
The first position analysed must have the maximum number of time points
to ensure that merging data frames does not lose data.
"""
# define and sort tmax_in_mins_dict
full_dict = {
position: int(grouper.ntimepoints * grouper.tinterval)
for position in grouper.positions
}
tmax_in_mins_dict = {**full_dict, **tmax_in_mins_dict}
# sort to ensure that the dataframe is created with the longest time series
tmax_in_mins_dict = OrderedDict(
sorted(tmax_in_mins_dict.items(), key=itemgetter(1), reverse=True)
)
return tmax_in_mins_dict
def load_h5( def load_h5(
self, self,
grouper, grouper,
key_index, key_index,
cutoff, index_for_key_index,
interpolate_list, interpolate_list,
tmax_in_mins_dict,
): ):
"""Load data from h5 files into one long data frame.""" """Load data from h5 files into one long data frame."""
print("\nLoading...") print("\nLoading...")
...@@ -303,15 +350,17 @@ class dataloader: ...@@ -303,15 +350,17 @@ class dataloader:
print(" bud data") print(" bud data")
r_df = self.load_bud_data( r_df = self.load_bud_data(
grouper, grouper,
cutoff,
figs=False, figs=False,
key_index=key_index, index_for_key_index=index_for_key_index,
interpolate_list=interpolate_list, interpolate_list=interpolate_list,
tmax_in_mins_dict=tmax_in_mins_dict,
) )
# load other signals # load other signals
for i, sigpath in enumerate(self.g2a_dict): for i, sigpath in enumerate(self.g2a_dict):
if sigpath in grouper.available and not ( if (
"buddings" in sigpath or "bud_metric" in sigpath sigpath in grouper.available
and not ("buddings" in sigpath or "bud_metric" in sigpath)
and sigpath != self.a2g_dict[key_index]
): ):
print(" " + sigpath) print(" " + sigpath)
# load all cells # load all cells
...@@ -319,9 +368,14 @@ class dataloader: ...@@ -319,9 +368,14 @@ class dataloader:
mode = "raw" mode = "raw"
else: else:
mode = "mothers" mode = "mothers"
record = grouper.concat_signal(sigpath, cutoff=0, mode=mode) record = grouper.concat_signal(
sigpath,
cutoff=0,
mode=mode,
tmax_in_mins_dict=tmax_in_mins_dict,
)
# keep cells only in key_index # keep cells only in key_index
new_record = self.get_key_cells(record, key_index) new_record = self.get_key_cells(record, index_for_key_index)
# interpolate to remove internal NaNs for signals from mothers # interpolate to remove internal NaNs for signals from mothers
if ( if (
interpolate_list interpolate_list
...@@ -334,26 +388,41 @@ class dataloader: ...@@ -334,26 +388,41 @@ class dataloader:
r_df = pd.merge(r_df, tdf, how="left") r_df = pd.merge(r_df, tdf, how="left")
return r_df return r_df
def get_key_index(self, grouper, key_index_path, cutoff): def get_key_index_data(
self, grouper, key_index, cutoff, tmax_in_mins_dict
):
""" """
Find index of cells that appear in the key record. Find index and data for cells that appear in the key record.
Cells must be retained at least a cutoff fraction of the Cells must be retained at least a cutoff fraction of the
experiment's duration. experiment's duration.
""" """
record = grouper.concat_signal(key_index_path, cutoff=cutoff) key_index_path = self.a2g_dict[key_index]
record = grouper.concat_signal(
key_index_path, cutoff=cutoff, tmax_in_mins_dict=tmax_in_mins_dict
)
if record is not None: if record is not None:
return record.index r_df = self.long_df_with_id(record, key_index)
return record.index, r_df
else: else:
raise Exception(f"{key_index_path} cannot be found.") raise Exception(f"{key_index_path} cannot be found.")
def get_key_cells(self, df, key_index): def get_key_cells(self, df, key_index):
"""Find a smaller data frame with only cells from the key record.""" """
Find a smaller multi-index data frame.
The data frame will only have cells from the key record.
"""
sdf = df.loc[df.index.intersection(key_index)] sdf = df.loc[df.index.intersection(key_index)]
return sdf return sdf
def load_bud_data( def load_bud_data(
self, grouper, cutoff, figs, key_index, interpolate_list self,
grouper,
figs,
index_for_key_index,
interpolate_list,
tmax_in_mins_dict,
): ):
""" """
Load buddings, bud volume, and any other bud signals. Load buddings, bud volume, and any other bud signals.
...@@ -377,11 +446,17 @@ class dataloader: ...@@ -377,11 +446,17 @@ class dataloader:
bud_interpolate_indices = None bud_interpolate_indices = None
# load buddings # load buddings
buddings = grouper.concat_signal( buddings = grouper.concat_signal(
self.a2g_dict["buddings"], cutoff=cutoff self.a2g_dict["buddings"],
cutoff=0,
tmax_in_mins_dict=tmax_in_mins_dict,
) )
# bud_volume and any other signals; missing signals return None # bud_volume and any other signals; missing signals return None
bud_data = [ bud_data = [
grouper.concat_signal(self.a2g_dict[bud_signal], cutoff=0) grouper.concat_signal(
self.a2g_dict[bud_signal],
cutoff=0,
tmax_in_mins_dict=tmax_in_mins_dict,
)
for bud_signal in bud_signals for bud_signal in bud_signals
] ]
# perform correction # perform correction
...@@ -389,9 +464,9 @@ class dataloader: ...@@ -389,9 +464,9 @@ class dataloader:
buddings, bud_data, figs, bud_interpolate_indices buddings, bud_data, figs, bud_interpolate_indices
) )
# keep cells only in key_index # keep cells only in key_index
new_buddings = self.get_key_cells(new_buddings, key_index) new_buddings = self.get_key_cells(new_buddings, index_for_key_index)
new_bud_data = [ new_bud_data = [
self.get_key_cells(new_bud_dataset, key_index) self.get_key_cells(new_bud_dataset, index_for_key_index)
if new_bud_dataset is not None if new_bud_dataset is not None
else None else None
for new_bud_dataset in new_bud_data for new_bud_dataset in new_bud_data
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment