From faf0b6f5ef110e7a4bb2c09768ca070e12fa1d74 Mon Sep 17 00:00:00 2001
From: pswain <peter.swain@ed.ac.uk>
Date: Fri, 22 Dec 2023 18:01:29 +0000
Subject: [PATCH] before refactoring key_index

---
 correct_buds.py |   1 +
 dataloader.py   | 134 +++++++++++++++++++-----------------------------
 2 files changed, 54 insertions(+), 81 deletions(-)

diff --git a/correct_buds.py b/correct_buds.py
index 201f705..c56ff80 100644
--- a/correct_buds.py
+++ b/correct_buds.py
@@ -94,6 +94,7 @@ def skip_buddings(buddings, bud_data):
                 else:
                     # ignore later budding
                     new_buddings.loc[ind].iloc[ib_end] = 0
+
     return new_buddings, new_bud_data
 
 
diff --git a/dataloader.py b/dataloader.py
index 0ee623d..6640ac5 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -1,6 +1,4 @@
 import pprint
-from collections import OrderedDict
-from operator import itemgetter
 from pathlib import Path
 
 import numpy as np
@@ -156,7 +154,7 @@ class dataloader:
         grouper = Grouper(self.h5dirpath / dataname)
         return grouper
 
-    def fix_dictionaries(self, extra_g2a_dict, overwrite_dict):
+    def update_dictionaries(self, extra_g2a_dict, overwrite_dict):
         """Update conversion dictionaries."""
         if extra_g2a_dict and not overwrite_dict:
             self.g2a_dict = {**self.g2a_dict, **extra_g2a_dict}
@@ -173,6 +171,25 @@ class dataloader:
                 self.g2a_dict[key] = value
         self.a2g_dict = {v: k for (k, v) in self.g2a_dict.items()}
 
+    def include_bud_fluorescence(self, grouper, dataname):
+        """Add mean and median bud fluorescence to the h5 files."""
+        # find fluorescence channels
+        channels = list(grouper.channels)
+        channels.remove("Brightfield")
+        signals = [
+            signal
+            for two_signal in [
+                [
+                    f"/extraction/{channel}/max/median",
+                    f"/extraction/{channel}/max/mean",
+                ]
+                for channel in channels
+            ]
+            for signal in two_signal
+        ]
+        # add bud fluorescence to h5 files
+        add_bud_fluorescence(self.h5dirpath / dataname, signals)
+
     def load(
         self,
         dataname,
@@ -247,14 +264,9 @@ class dataloader:
             self.load_tsv(dataname)
         else:
             # update conversion dictionaries
-            self.fix_dictionaries(extra_g2a_dict, overwrite_dict)
+            self.update_dictionaries(extra_g2a_dict, overwrite_dict)
             # create instance of grouper
             grouper = self.get_grouper(dataname)
-            # update tmax_in_mins_dict
-            if tmax_in_mins_dict:
-                tmax_in_mins_dict = self.generate_full_tmax_in_mins_dict(
-                    grouper, tmax_in_mins_dict
-                )
             print("\n---\n" + dataname + "\n---")
             if bud_fluorescence:
                 # call postprocessor to add bud fluorescence to h5 files
@@ -262,24 +274,25 @@ class dataloader:
             print("signals available:")
             for signal in grouper.available:
                 print(" ", signal)
+            print()
             # find time interval between images
             self.dt = grouper.tinterval
-            # get key index for choosing cells and key-index data
+            # get multi index for choosing cells and key_index data
             index_for_key_index, r_df = self.get_key_index_data(
-                grouper,
-                key_index,
-                cutoff,
-                tmax_in_mins_dict,
+                grouper=grouper,
+                key_index=key_index,
+                cutoff=cutoff,
+                tmax_in_mins_dict=tmax_in_mins_dict,
             )
-            # load data from h5 files
-            tdf = self.load_h5(
-                grouper,
-                key_index,
-                index_for_key_index,
-                interpolate_list,
-                tmax_in_mins_dict,
+            # add data for other signals to data for key_index
+            r_df = self.load_h5(
+                grouper=grouper,
+                key_index=key_index,
+                index_for_key_index=index_for_key_index,
+                r_df=r_df,
+                interpolate_list=interpolate_list,
+                tmax_in_mins_dict=tmax_in_mins_dict,
             )
-            r_df = pd.merge(r_df, tdf, how="left")
             if pxsize:
                 # convert volumes to micron^3
                 for signal in r_df.columns:
@@ -292,69 +305,22 @@ class dataloader:
                 self.df = pd.merge(self.df, r_df, how="left")
             else:
                 self.df = r_df
-        print(f" data size is {self.df.shape}")
+        print(f"\n data size is {self.df.shape}")
         # define ids
         self.ids = list(self.df.id.unique())
         if not use_tsv:
             return grouper
 
-    def include_bud_fluorescence(self, grouper, dataname):
-        """Add mean and median bud fluorescence to the h5 files."""
-        # find fluorescence channels
-        channels = list(grouper.channels)
-        channels.remove("Brightfield")
-        signals = [
-            signal
-            for two_signal in [
-                [
-                    f"/extraction/{channel}/max/median",
-                    f"/extraction/{channel}/max/mean",
-                ]
-                for channel in channels
-            ]
-            for signal in two_signal
-        ]
-        # add bud fluorescence to h5 files
-        add_bud_fluorescence(self.h5dirpath / dataname, signals)
-
-    def generate_full_tmax_in_mins_dict(self, grouper, tmax_in_mins_dict):
-        """
-        Generate a tmax_in_mins_dict for all positions.
-
-        The first position analysed must have the maximum number of time points
-        to ensure that merging data frames does not lose data.
-        """
-        # define and sort tmax_in_mins_dict
-        full_dict = {
-            position: int(grouper.ntimepoints * grouper.tinterval)
-            for position in grouper.positions
-        }
-        tmax_in_mins_dict = {**full_dict, **tmax_in_mins_dict}
-        # sort to ensure that the dataframe is created with the longest time series
-        tmax_in_mins_dict = OrderedDict(
-            sorted(tmax_in_mins_dict.items(), key=itemgetter(1), reverse=True)
-        )
-        return tmax_in_mins_dict
-
     def load_h5(
         self,
         grouper,
         key_index,
         index_for_key_index,
+        r_df,
         interpolate_list,
         tmax_in_mins_dict,
     ):
         """Load data from h5 files into one long data frame."""
-        print("\nLoading...")
-        # load and correct buddings and bud_volume
-        print(" bud data")
-        r_df = self.load_bud_data(
-            grouper,
-            figs=False,
-            index_for_key_index=index_for_key_index,
-            interpolate_list=interpolate_list,
-            tmax_in_mins_dict=tmax_in_mins_dict,
-        )
         # load other signals
         for i, sigpath in enumerate(self.g2a_dict):
             if (
@@ -363,19 +329,14 @@ class dataloader:
                 and sigpath != self.a2g_dict[key_index]
             ):
                 print(" " + sigpath)
-                # load all cells
-                if "cy5" in sigpath:
-                    mode = "raw"
-                else:
-                    mode = "mothers"
                 record = grouper.concat_signal(
                     sigpath,
                     cutoff=0,
-                    mode=mode,
                     tmax_in_mins_dict=tmax_in_mins_dict,
                 )
                 # keep cells only in key_index
                 new_record = self.get_key_cells(record, index_for_key_index)
+                # NOTE(review): removed stray `new_record = record`, which clobbered the key-index filtering above
                 # interpolate to remove internal NaNs for signals from mothers
                 if (
                     interpolate_list
@@ -386,13 +347,23 @@ class dataloader:
                 tdf = self.long_df_with_id(new_record, self.g2a_dict[sigpath])
                 # merge into one data set
                 r_df = pd.merge(r_df, tdf, how="left")
+        print("\n Loading bud data.")
+        # load and correct buddings and bud_volume
+        b_df = self.load_bud_data(
+            grouper=grouper,
+            figs=False,
+            index_for_key_index=index_for_key_index,
+            interpolate_list=interpolate_list,
+            tmax_in_mins_dict=tmax_in_mins_dict,
+        )
+        r_df = pd.merge(r_df, b_df, how="left")
         return r_df
 
     def get_key_index_data(
         self, grouper, key_index, cutoff, tmax_in_mins_dict
     ):
         """
-        Find index and data for cells that appear in the key record.
+        Find multi-index and data for the key record.
 
         Cells must be retained at least a cutoff fraction of the
         experiment's duration.
@@ -402,18 +373,19 @@ class dataloader:
             key_index_path, cutoff=cutoff, tmax_in_mins_dict=tmax_in_mins_dict
         )
         if record is not None:
+            index_for_key_index = record.index
             r_df = self.long_df_with_id(record, key_index)
-            return record.index, r_df
+            return index_for_key_index, r_df
         else:
             raise Exception(f"{key_index_path} cannot be found.")
 
-    def get_key_cells(self, df, key_index):
+    def get_key_cells(self, df, index_for_key_index):
         """
         Find a smaller multi-index data frame.
 
         The data frame will only have cells from the key record.
         """
-        sdf = df.loc[df.index.intersection(key_index)]
+        sdf = df.loc[df.index.intersection(index_for_key_index)]
         return sdf
 
     def load_bud_data(
-- 
GitLab