From 12d0f2e5258d22f0107f3ae5511cfdcfc8a1d1a0 Mon Sep 17 00:00:00 2001
From: pswain <peter.swain@ed.ac.uk>
Date: Thu, 21 Dec 2023 19:01:56 +0000
Subject: [PATCH] feature(dataloader): added tmax_in_mins_dict

---
 dataloader.py | 159 +++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 117 insertions(+), 42 deletions(-)

diff --git a/dataloader.py b/dataloader.py
index 929093a..0ee623d 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -1,4 +1,6 @@
 import pprint
+from collections import OrderedDict
+from operator import itemgetter
 from pathlib import Path
 
 import numpy as np
@@ -6,9 +8,9 @@ import pandas as pd
 
 from wela.correct_buds import correct_buds
 
-
 try:
     from postprocessor.grouper import Grouper
+
     from wela.add_bud_fluorescence import add_bud_fluorescence
 except ModuleNotFoundError:
     print("Can only load tsv files - cannot find postprocessor.")
@@ -154,6 +156,23 @@ class dataloader:
         grouper = Grouper(self.h5dirpath / dataname)
         return grouper
 
+    def fix_dictionaries(self, extra_g2a_dict, overwrite_dict):
+        """Update conversion dictionaries."""
+        if extra_g2a_dict and not overwrite_dict:
+            self.g2a_dict = {**self.g2a_dict, **extra_g2a_dict}
+        elif extra_g2a_dict:
+            self.g2a_dict = extra_g2a_dict
+        for key, value in zip(
+            [
+                "postprocessing/buddings/extraction_general_None_volume",
+                "postprocessing/bud_metric/extraction_general_None_volume",
+            ],
+            ["buddings", "bud_volume"],
+        ):
+            if key not in self.g2a_dict:
+                self.g2a_dict[key] = value
+        self.a2g_dict = {v: k for (k, v) in self.g2a_dict.items()}
+
     def load(
         self,
         dataname,
@@ -166,17 +185,16 @@ class dataloader:
         overwrite_dict=False,
         hours=True,
         bud_fluorescence=False,
+        tmax_in_mins_dict=None,
     ):
         """
-        Load either an experiment from h5 files or a tsv data set
-        into a long data frame.
+        Load either an experiment from h5 files or a tsv data set.
 
+        Data are stored in a long data frame with the .df attribute.
         The 'time' variable becomes a column.
 
         New data is added to the existing data frame.
 
-        Data is stored in the .df attribute.
-
         Parameters
         ----------
         dataname: string
@@ -204,6 +222,11 @@ class dataloader:
             If True, convert times to hours (dividing by 60).
         bud_fluorescence: boolean
             If True, add mean and median bud fluorescence to the data frame.
+        tmax_in_mins_dict: dict (optional)
+            A dictionary with positions as keys and maximum times in minutes as
+            values. For example: {"PDR5_GFP_001": 6 * 60}.
+            Data will only be included up to this time point, which is a
+            way to avoid errors in assigning lineages because of clogging.
 
         Returns
         -------
@@ -223,36 +246,40 @@ class dataloader:
         if use_tsv:
             self.load_tsv(dataname)
         else:
-            # update dictionary
-            if extra_g2a_dict and not overwrite_dict:
-                self.g2a_dict = {**self.g2a_dict, **extra_g2a_dict}
-            elif extra_g2a_dict:
-                self.g2a_dict = extra_g2a_dict
-            for key, value in zip(
-                [
-                    "postprocessing/buddings/extraction_general_None_volume",
-                    "postprocessing/bud_metric/extraction_general_None_volume",
-                ],
-                ["buddings", "bud_volume"],
-            ):
-                if key not in self.g2a_dict:
-                    self.g2a_dict[key] = value
-            self.a2g_dict = {v: k for (k, v) in self.g2a_dict.items()}
+            # update conversion dictionaries
+            self.fix_dictionaries(extra_g2a_dict, overwrite_dict)
             # create instance of grouper
             grouper = self.get_grouper(dataname)
+            # update tmax_in_mins_dict
+            if tmax_in_mins_dict:
+                tmax_in_mins_dict = self.generate_full_tmax_in_mins_dict(
+                    grouper, tmax_in_mins_dict
+                )
             print("\n---\n" + dataname + "\n---")
             if bud_fluorescence:
+                # call postprocessor to add bud fluorescence to h5 files
                 self.include_bud_fluorescence(grouper, dataname)
             print("signals available:")
             for signal in grouper.available:
                 print(" ", signal)
             # find time interval between images
             self.dt = grouper.tinterval
-            # get key index for choosing cells
-            key_index_path = self.a2g_dict[key_index]
-            key_index = self.get_key_index(grouper, key_index_path, cutoff)
-            # load from h5 files
-            r_df = self.load_h5(grouper, key_index, cutoff, interpolate_list)
+            # get key index for choosing cells and key-index data
+            index_for_key_index, r_df = self.get_key_index_data(
+                grouper,
+                key_index,
+                cutoff,
+                tmax_in_mins_dict,
+            )
+            # load data from h5 files
+            tdf = self.load_h5(
+                grouper,
+                key_index,
+                index_for_key_index,
+                interpolate_list,
+                tmax_in_mins_dict,
+            )
+            r_df = pd.merge(r_df, tdf, how="left")
             if pxsize:
                 # convert volumes to micron^3
                 for signal in r_df.columns:
@@ -290,12 +317,32 @@ class dataloader:
         # add bud fluorescence to h5 files
         add_bud_fluorescence(self.h5dirpath / dataname, signals)
 
+    def generate_full_tmax_in_mins_dict(self, grouper, tmax_in_mins_dict):
+        """
+        Generate a tmax_in_mins_dict for all positions.
+
+        The first position analysed must have the maximum number of time points
+        to ensure that merging data frames does not lose data.
+        """
+        # define and sort tmax_in_mins_dict
+        full_dict = {
+            position: int(grouper.ntimepoints * grouper.tinterval)
+            for position in grouper.positions
+        }
+        tmax_in_mins_dict = {**full_dict, **tmax_in_mins_dict}
+        # sort to ensure that the dataframe is created with the longest time series
+        tmax_in_mins_dict = OrderedDict(
+            sorted(tmax_in_mins_dict.items(), key=itemgetter(1), reverse=True)
+        )
+        return tmax_in_mins_dict
+
     def load_h5(
         self,
         grouper,
         key_index,
-        cutoff,
+        index_for_key_index,
         interpolate_list,
+        tmax_in_mins_dict,
     ):
         """Load data from h5 files into one long data frame."""
         print("\nLoading...")
@@ -303,15 +350,17 @@ class dataloader:
         print(" bud data")
         r_df = self.load_bud_data(
             grouper,
-            cutoff,
             figs=False,
-            key_index=key_index,
+            index_for_key_index=index_for_key_index,
             interpolate_list=interpolate_list,
+            tmax_in_mins_dict=tmax_in_mins_dict,
         )
         # load other signals
         for i, sigpath in enumerate(self.g2a_dict):
-            if sigpath in grouper.available and not (
-                "buddings" in sigpath or "bud_metric" in sigpath
+            if (
+                sigpath in grouper.available
+                and not ("buddings" in sigpath or "bud_metric" in sigpath)
+                and sigpath != self.a2g_dict[key_index]
             ):
                 print(" " + sigpath)
                 # load all cells
@@ -319,9 +368,14 @@ class dataloader:
                     mode = "raw"
                 else:
                     mode = "mothers"
-                record = grouper.concat_signal(sigpath, cutoff=0, mode=mode)
+                record = grouper.concat_signal(
+                    sigpath,
+                    cutoff=0,
+                    mode=mode,
+                    tmax_in_mins_dict=tmax_in_mins_dict,
+                )
                 # keep cells only in key_index
-                new_record = self.get_key_cells(record, key_index)
+                new_record = self.get_key_cells(record, index_for_key_index)
                 # interpolate to remove internal NaNs for signals from mothers
                 if (
                     interpolate_list
@@ -334,26 +388,41 @@ class dataloader:
                 r_df = pd.merge(r_df, tdf, how="left")
         return r_df
 
-    def get_key_index(self, grouper, key_index_path, cutoff):
+    def get_key_index_data(
+        self, grouper, key_index, cutoff, tmax_in_mins_dict
+    ):
         """
-        Find index of cells that appear in the key record.
+        Find index and data for cells that appear in the key record.
 
         Cells must be retained at least a cutoff fraction of the
         experiment's duration.
         """
-        record = grouper.concat_signal(key_index_path, cutoff=cutoff)
+        key_index_path = self.a2g_dict[key_index]
+        record = grouper.concat_signal(
+            key_index_path, cutoff=cutoff, tmax_in_mins_dict=tmax_in_mins_dict
+        )
         if record is not None:
-            return record.index
+            r_df = self.long_df_with_id(record, key_index)
+            return record.index, r_df
         else:
             raise Exception(f"{key_index_path} cannot be found.")
 
     def get_key_cells(self, df, key_index):
-        """Find a smaller data frame with only cells from the key record."""
+        """
+        Find a smaller multi-index data frame.
+
+        The data frame will only have cells from the key record.
+        """
         sdf = df.loc[df.index.intersection(key_index)]
         return sdf
 
     def load_bud_data(
-        self, grouper, cutoff, figs, key_index, interpolate_list
+        self,
+        grouper,
+        figs,
+        index_for_key_index,
+        interpolate_list,
+        tmax_in_mins_dict,
     ):
         """
         Load buddings, bud volume, and any other bud signals.
@@ -377,11 +446,17 @@ class dataloader:
             bud_interpolate_indices = None
         # load buddings
         buddings = grouper.concat_signal(
-            self.a2g_dict["buddings"], cutoff=cutoff
+            self.a2g_dict["buddings"],
+            cutoff=0,
+            tmax_in_mins_dict=tmax_in_mins_dict,
         )
         # bud_volume and any other signals; missing signals return None
         bud_data = [
-            grouper.concat_signal(self.a2g_dict[bud_signal], cutoff=0)
+            grouper.concat_signal(
+                self.a2g_dict[bud_signal],
+                cutoff=0,
+                tmax_in_mins_dict=tmax_in_mins_dict,
+            )
             for bud_signal in bud_signals
         ]
         # perform correction
@@ -389,9 +464,9 @@ class dataloader:
             buddings, bud_data, figs, bud_interpolate_indices
         )
         # keep cells only in key_index
-        new_buddings = self.get_key_cells(new_buddings, key_index)
+        new_buddings = self.get_key_cells(new_buddings, index_for_key_index)
         new_bud_data = [
-            self.get_key_cells(new_bud_dataset, key_index)
+            self.get_key_cells(new_bud_dataset, index_for_key_index)
             if new_bud_dataset is not None
             else None
             for new_bud_dataset in new_bud_data
-- 
GitLab