From d6e76266740dec68bcfed15c5f7c764e383d6e87 Mon Sep 17 00:00:00 2001
From: Peter Swain <peter.swain@ed.ac.uk>
Date: Mon, 19 Feb 2024 15:18:03 +0000
Subject: [PATCH] change(processABC): _log to log

---
 src/agora/abc.py                           | 12 ++++---
 src/agora/io/bridge.py                     |  3 +-
 src/agora/io/cells.py                      | 10 +++---
 src/agora/io/signal.py                     |  5 ++-
 src/agora/io/writer.py                     | 18 +++++-----
 src/aliby/fullpipeline.py                  | 24 +++++++------
 src/aliby/pipeline.py                      | 40 +++++++++++++---------
 src/extraction/core/extractor.py           | 10 +++---
 src/postprocessor/core/reshapers/picker.py |  2 +-
 9 files changed, 69 insertions(+), 55 deletions(-)

diff --git a/src/agora/abc.py b/src/agora/abc.py
index 6f9a701..42c37db 100644
--- a/src/agora/abc.py
+++ b/src/agora/abc.py
@@ -46,9 +46,11 @@ class ParametersABC(ABC):
                 ]
             ):
                 return {
-                    k: v.to_dict()
-                    if hasattr(v, "to_dict")
-                    else self.to_dict(v)
+                    k: (
+                        v.to_dict()
+                        if hasattr(v, "to_dict")
+                        else self.to_dict(v)
+                    )
                     for k, v in iterable.items()
                 }
             else:
@@ -163,8 +165,8 @@ class ProcessABC(ABC):
     def run(self):
         pass
 
-    def _log(self, message: str, level: str = "warning"):
-        # Log messages in the corresponding level
+    def log(self, message: str, level: str = "warning"):
+        """Log messages at the corresponding level."""
         logger = logging.getLogger("aliby")
         getattr(logger, level)(f"{self.__class__.__name__}: {message}")
 
diff --git a/src/agora/io/bridge.py b/src/agora/io/bridge.py
index 3f44541..19fa218 100644
--- a/src/agora/io/bridge.py
+++ b/src/agora/io/bridge.py
@@ -1,6 +1,7 @@
 """
 Tools to interact with h5 files and handle data consistently.
 """
+
 import collections
 import logging
 import typing as t
@@ -28,7 +29,7 @@ class BridgeH5:
                 "cell_info" in self.hdf
             ), "Invalid file. No 'cell_info' found."
 
-    def _log(self, message: str, level: str = "warn"):
+    def log(self, message: str, level: str = "warn"):
         # Log messages in the corresponding level
         logger = logging.getLogger("aliby")
         getattr(logger, level)(f"{self.__class__.__name__}: {message}")
diff --git a/src/agora/io/cells.py b/src/agora/io/cells.py
index 5656dcb..a8ea422 100644
--- a/src/agora/io/cells.py
+++ b/src/agora/io/cells.py
@@ -64,7 +64,7 @@ class Cells:
         """Ensure initiating file is a Path object."""
         return cls(Path(source))
 
-    def _log(self, message: str, level: str = "warn"):
+    def log(self, message: str, level: str = "warn"):
         """Log messages in the corresponding level."""
         logger = logging.getLogger("aliby")
         getattr(logger, level)(f"{self.__class__.__name__}: {message}")
@@ -273,9 +273,9 @@ class Cells:
             (self.ntraps, self["cell_label"].max(), self.ntimepoints),
             dtype=bool,
         )
-        ncells_mat[
-            self["trap"], self["cell_label"] - 1, self["timepoint"]
-        ] = True
+        ncells_mat[self["trap"], self["cell_label"] - 1, self["timepoint"]] = (
+            True
+        )
         return ncells_mat
 
     def cell_tp_where(
@@ -350,7 +350,7 @@ class Cells:
             )
         else:
             mothers_daughters = np.array([])
-            self._log("No mother-daughters assigned")
+            self.log("No mother-daughters assigned")
         return mothers_daughters
 
     @staticmethod
diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py
index e989ca2..d63746d 100644
--- a/src/agora/io/signal.py
+++ b/src/agora/io/signal.py
@@ -204,7 +204,7 @@ class Signal(BridgeH5):
             with h5py.File(self.filename, "r") as f:
                 f.visititems(self.store_signal_path)
         except Exception as e:
-            self._log("Exception when visiting h5: {}".format(e), "exception")
+            self.log("Exception when visiting h5: {}".format(e), "exception")
         return self._available
 
     def get_merged(self, dataset):
@@ -329,8 +329,7 @@ class Signal(BridgeH5):
     def dataset_to_df(self, f: h5py.File, path: str) -> pd.DataFrame:
         """Get data from h5 file as a dataframe."""
         if path not in f:
-            message = f"{path} not in {f}."
-            self._log(message)
+            self.log(f"{path} not in {f}.")
             return None
         else:
             dset = f[path]
diff --git a/src/agora/io/writer.py b/src/agora/io/writer.py
index b36cd4a..567b8da 100644
--- a/src/agora/io/writer.py
+++ b/src/agora/io/writer.py
@@ -55,7 +55,7 @@ class DynamicWriter:
         if Path(file).exists():
             self.metadata = load_meta(file)
 
-    def _log(self, message: str, level: str = "warn"):
+    def log(self, message: str, level: str = "warn"):
         # Log messages in the corresponding level
         logger = logging.getLogger("aliby")
         getattr(logger, level)(f"{self.__class__.__name__}: {message}")
@@ -104,9 +104,11 @@ class DynamicWriter:
                 maxshape=max_shape,
                 dtype=dtype,
                 compression=self.compression,
-                compression_opts=self.compression_opts
-                if self.compression is not None
-                else None,
+                compression_opts=(
+                    self.compression_opts
+                    if self.compression is not None
+                    else None
+                ),
             )
             # write all data, signified by the empty tuple
             hgroup[key][()] = data
@@ -174,7 +176,7 @@ class DynamicWriter:
                             # append or create new dataset
                             self._append(value, key, hgroup)
                     except Exception as e:
-                        self._log(
+                        self.log(
                             f"{key}:{value} could not be written: {e}", "error"
                         )
             # write metadata
@@ -622,9 +624,9 @@ class Writer(BridgeH5):
 
                 # sort indices for h5 indexing
                 incremental_existing = np.argsort(found_indices)
-                self.id_cache[df.index.nlevels][
-                    "found_indices"
-                ] = found_indices[incremental_existing]
+                self.id_cache[df.index.nlevels]["found_indices"] = (
+                    found_indices[incremental_existing]
+                )
                 self.id_cache[df.index.nlevels]["found_multi"] = found_multis[
                     incremental_existing
                 ]
diff --git a/src/aliby/fullpipeline.py b/src/aliby/fullpipeline.py
index 2b96210..c05bc3b 100644
--- a/src/aliby/fullpipeline.py
+++ b/src/aliby/fullpipeline.py
@@ -1,4 +1,5 @@
 """Set up and run pipelines: tiling, segmentation, extraction, and then post-processing."""
+
 import logging
 import os
 import re
@@ -293,7 +294,7 @@ class Pipeline(ProcessABC):
         return cls(pipeline_parameters, store=directory)
 
     @property
-    def _logger(self):
+    def logger(self):
         return logging.getLogger("aliby")
 
     def run(self):
@@ -522,14 +523,14 @@ class Pipeline(ProcessABC):
                                 pipe["steps"]["tiler"].tile_size,
                             )
                             if frac_clogged_traps > 0.3:
-                                self._log(
+                                self.log(
                                     f"{name}:Clogged_traps:{frac_clogged_traps}"
                                 )
                                 frac = np.round(frac_clogged_traps * 100)
                                 progress_bar.set_postfix_str(f"{frac} Clogged")
                         else:
                             # stop if too many traps are clogged
-                            self._log(
+                            self.log(
                                 f"{name}:Stopped early at time {i} with {frac_clogged_traps} clogged traps"
                             )
                             pipe["meta"].add_fields({"end_status": "Clogged"})
@@ -541,7 +542,7 @@ class Pipeline(ProcessABC):
                         pipe["config"]["postprocessing"]
                     )
                     PostProcessor(pipe["filename"], post_proc_params).run()
-                    self._log("Analysis finished successfully.", "info")
+                    self.log("Analysis finished successfully.", "info")
                     return 1
         except Exception as e:
             # catch bugs during setup or run time
@@ -637,7 +638,7 @@ class Pipeline(ProcessABC):
                 os.remove(pipe["filename"])
             # if the file exists with no previous segmentation use its tiler
             if pipe["filename"].exists():
-                self._log("Result file exists.", "info")
+                self.log("Result file exists.", "info")
                 if not overwrite["tiler"]:
                     tiler_params_dict = TilerParameters.default().to_dict()
                     tiler_params_dict["position_name"] = name.split(".")[0]
@@ -668,7 +669,7 @@ class Pipeline(ProcessABC):
                             "tiler"
                         ].parameters.to_dict()
                     except Exception:
-                        self._log("Overwriting tiling data")
+                        self.log("Overwriting tiling data")
 
             if config["general"]["use_explog"]:
                 pipe["meta"].run()
@@ -679,9 +680,11 @@ class Pipeline(ProcessABC):
                     "aliby_version": version("aliby"),
                     "baby_version": version("aliby-baby"),
                     "omero_id": config["general"]["id"],
-                    "image_id": image_id
-                    if isinstance(image_id, int)
-                    else str(image_id),
+                    "image_id": (
+                        image_id
+                        if isinstance(image_id, int)
+                        else str(image_id)
+                    ),
                     "parameters": PipelineParameters.from_dict(
                         config
                     ).to_yaml(),
@@ -727,8 +730,7 @@ def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
     )
     # find tiles with cells covering too great a fraction of the tiles' area
     traps_above_athresh = (
-        cells_used.groupby("trap").sum().apply(np.mean, axis=1)
-        / tile_size**2
+        cells_used.groupby("trap").sum().apply(np.mean, axis=1) / tile_size**2
         > es_parameters["thresh_trap_area"]
     )
     return (traps_above_nthresh & traps_above_athresh).mean()
diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py
index eb7932b..24a19d7 100644
--- a/src/aliby/pipeline.py
+++ b/src/aliby/pipeline.py
@@ -1,6 +1,7 @@
 """Set up and run pipelines: tiling, segmentation, extraction, and then post-processing."""
 
 import logging
+import multiprocessing
 import os
 import re
 import typing as t
@@ -10,7 +11,6 @@ from pprint import pprint
 import baby
 import baby.errors
 import numpy as np
-import multiprocessing
 import tensorflow as tf
 from pathos.multiprocessing import Pool
 from tqdm import tqdm
@@ -207,20 +207,16 @@ class Pipeline(ProcessABC):
         """Get meta data and identify each position."""
         config = self.parameters.to_dict()
         # print configuration
-        logging.getLogger("aliby").info(f"Using alibylite.")
+        self.log("Using alibylite.", "info")
         try:
-            logging.getLogger("aliby").info(f"Using Baby {baby.__version__}.")
+            self.log(f"Using Baby {baby.__version__}.", "info")
         except AttributeError:
-            logging.getLogger("aliby").info("Using original Baby.")
-        for step in config:
-            print("\n---\n" + step + "\n---")
-            pprint(config[step])
-        print()
+            self.log("Using original Baby.", "info")
         # extract from configuration
         root_dir = Path(config["general"]["directory"])
         dispatcher = dispatch_dataset(self.expt_id, **self.server_info)
-        logging.getLogger("aliby").info(
-            f"Fetching data using {dispatcher.__class__.__name__}."
+        self.log(
+            f"Fetching data using {dispatcher.__class__.__name__}.", "info"
         )
         # get log files, either locally or via OMERO
         with dispatcher as conn:
@@ -294,7 +290,7 @@ class Pipeline(ProcessABC):
         if not len(position_ids):
             raise Exception("No images to segment.")
         else:
-            print("\nPositions selected:")
+            print("Positions selected:")
             for pos in position_ids:
                 print("\t" + pos.split(".")[0])
         print(f"Number of CPU cores available: {multiprocessing.cpu_count()}")
@@ -381,14 +377,15 @@ class Pipeline(ProcessABC):
                         meta={"last_processed:": i},
                     )
                     if i == 0:
-                        logging.getLogger("aliby").info(
-                            f"Found {tiler.no_tiles} traps in {image.name}"
+                        self.log(
+                            f"Found {tiler.no_tiles} traps in {image.name}.",
+                            "info",
                         )
                     # run Baby
                     try:
                         result = babyrunner.run_tp(i)
                     except baby.errors.Clogging:
-                        logging.getLogger("aliby").warn(
+                        self.log(
                             "WARNING:Clogging threshold exceeded in BABY."
                         )
                     baby_writer.write(
@@ -411,12 +408,12 @@ class Pipeline(ProcessABC):
                         tiler.tile_size,
                     )
                     if frac_clogged_traps > 0.3:
-                        self._log(f"{name}:Clogged_traps:{frac_clogged_traps}")
+                        self.log(f"{name}:Clogged_traps:{frac_clogged_traps}")
                         frac = np.round(frac_clogged_traps * 100)
                         progress_bar.set_postfix_str(f"{frac} Clogged")
                 else:
                     # stop if too many clogged traps
-                    self._log(
+                    self.log(
                         f"{name}:Stopped early at time {i} with {frac_clogged_traps} clogged traps"
                     )
                     break
@@ -425,9 +422,18 @@ class Pipeline(ProcessABC):
                 out_file,
                 PostProcessorParameters.from_dict(config["postprocessing"]),
             ).run()
-            self._log("Analysis finished successfully.", "info")
+            self.log("Analysis finished successfully.", "info")
             return 1
 
+    @property
+    def display_config(self):
+        """Show all parameters for each step of the pipeline."""
+        config = self.parameters.to_dict()
+        for step in config:
+            print("\n---\n" + step + "\n---")
+            pprint(config[step])
+        print()
+
 
 def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
     """
diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py
index 44677b3..fa379e4 100644
--- a/src/extraction/core/extractor.py
+++ b/src/extraction/core/extractor.py
@@ -335,7 +335,7 @@ class Extractor(StepABC):
             the tile_id and cell labels
         """
         if cell_labels is None:
-            self._log("No cell labels given. Sorting cells using index.")
+            self.log("No cell labels given. Sorting cells using index.")
         cell_fun = True if cell_function in self.all_cell_funs else False
         idx = []
         results = []
@@ -501,9 +501,11 @@ class Extractor(StepABC):
                 list(
                     map(
                         # sum over masks for each cell
-                        lambda x: np.sum(x, axis=0)
-                        if np.any(x)
-                        else np.zeros((tile_size, tile_size)),
+                        lambda x: (
+                            np.sum(x, axis=0)
+                            if np.any(x)
+                            else np.zeros((tile_size, tile_size))
+                        ),
                         masks,
                     )
                 )
diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py
index ccd5869..d9942e8 100644
--- a/src/postprocessor/core/reshapers/picker.py
+++ b/src/postprocessor/core/reshapers/picker.py
@@ -107,7 +107,7 @@ class Picker(LineageProcess):
                 # number of indices reduces for each iteration of the loop
                 indices = indices.intersection(new_indices)
         else:
-            self._log("No lineage assignment")
+            self.log("No lineage assignment")
             indices = np.array([])
         # return as list
         indices_arr = [tuple(x) for x in indices]
-- 
GitLab