#!/usr/bin/env python3 import typing as t from abc import ABC from collections import Counter from functools import cached_property as property from pathlib import Path from typing import Union import h5py import numpy as np import pandas as pd from pathos.multiprocessing import Pool from agora.io.signal import Signal class Grouper(ABC): """Base grouper class.""" def __init__(self, dir: Union[str, Path], name_inds=(0, -4)): """Find h5 files and load each one.""" path = Path(dir) assert path.exists(), f"{str(dir)} does not exist" self.name = path.name self.files = list(path.glob("*.h5")) assert len(self.files), "No valid h5 files in dir" self.load_positions() self.positions_groups = { name: name[name_inds[0] : name_inds[1]] for name in self.positions.keys() } def load_positions(self) -> None: """Load a Signal for each position, or h5 file.""" self.positions = {f.name[:-3]: Signal(f) for f in sorted(self.files)} @property def first_signal(self) -> Signal: """Get Signal for the first position.""" return list(self.positions.values())[0] @property def ntimepoints(self) -> int: """Find number of time points.""" return max([s.ntimepoints for s in self.positions.values()]) @property def tinterval(self) -> float: """Find the time interval for all positions.""" tintervals = list( np.unique([s.tinterval / 60 for s in self.positions.values()]) ) assert ( len(tintervals) == 1 ), "Not all positions have the same time interval." return tintervals[0] @property def available(self) -> t.Collection[str]: """Generate list of available signals from the first position.""" return self.first_signal.available @property def available_grouped(self) -> None: """Display available signals and the number of positions with these signals.""" if not hasattr(self, "available_grouped"): self._available_grouped = Counter( [x for s in self.positions.values() for x in s.available] ) for s, n in self._available_grouped.items(): print(f"{s} - {n}") def concat_signal( self, path: str, pool: t.Optional[int] = None, mode: str = "retained", selected_positions: t.List[str] = None, tmax_in_mins_dict: dict = None, **kwargs, ): """ Concatenate data for one signal from different h5 files. Each h5 files corresponds to a different position. Parameters ---------- path : str Signal location within h5 file. pool : int (optional) Number of threads used; if 0 or None only one core is used. mode: str If "retained" (default), return Signal with merging, picking, and lineage information applied but only for cells present for at least some cutoff fraction of the movie. If "raw", return Signal without merging, picking, lineage information, or a cutoff applied. Each of the first three options can be re-selected. A raw Signal with all three selected is the same as a retained Signal with a 0 cutoff. If "daughters", return Signal with only daughters - cells with an identified mother. If "families", get Signal with merging, picking, and lineage information applied. selected_positions: list[str] (optional) If defined, get signals for only these positions. tmax_in_mins_dict: dict (optional) A dictionary with positions as keys and maximum times in minutes as values. For example: { "PDR5_GFP_001": 6 * 60}. Data will only be include up to this time point, which is a way to avoid errors in assigning lineages because of clogging. **kwargs : key, value pairings Named arguments for concat_ind_function Examples -------- >>> record = grouper.concat_signal("extraction/GFP/max/median") """ if path.startswith("/"): path = path.strip("/") good_positions = self.filter_positions(path) if selected_positions is not None: good_positions = { key: value for key, value in good_positions.items() if key in selected_positions } if good_positions: kwargs["mode"] = mode records = self.pool_function( path=path, f=concat_one_signal, pool=pool, positions=good_positions, tmax_in_mins_dict=tmax_in_mins_dict, **kwargs, ) # check for errors errors = [ position for record, position in zip(records, good_positions.keys()) if record is None ] records = [record for record in records if record is not None] if len(errors): print(f"Warning: Positions ({errors}) contain errors.") assert len(records), "All data sets contain errors" # combine into one data frame concat = pd.concat(records, axis=0) if len(concat.index.names) > 4: # reorder levels in the multi-index data frame # when mother_label is present concat = concat.reorder_levels( ("group", "position", "trap", "cell_label", "mother_label") ) concat_sorted = concat.sort_index() return concat_sorted else: print("No data found.") def filter_positions(self, path: str) -> t.Dict[str, Signal]: """Filter positions to those whose data is available in the h5 file.""" good_positions = { k: v for k, v in self.positions.items() if path in [*v.available] } no_positions_dif = len(self.positions) - len(good_positions) if no_positions_dif: print( f"Grouper:Warning: some positions ({no_positions_dif}) do not" f" contain {path}." ) return good_positions def pool_function( self, path: str, f: t.Callable, pool: t.Optional[int] = None, positions: t.Dict[str, Signal] = None, tmax_in_mins_dict: dict = None, **kwargs, ): """ Enable different threads for different positions. Particularly useful when aggregating multiple elements. """ positions = positions or self.positions if pool: with Pool(pool) as p: records = p.map( lambda x: f( path=path, position=x[1], group=self.positions_groups[x[0]], position_name=x[0], tmax_in_mins_dict=tmax_in_mins_dict, **kwargs, ), positions.items(), ) else: records = [ f( path=path, position=position, group=self.positions_groups[name], position_name=name, tmax_in_mins_dict=tmax_in_mins_dict, **kwargs, ) for name, position in positions.items() ] return records @property def no_tiles(self): """Get total number of tiles per position (h5 file).""" for pos, s in self.positions.items(): with h5py.File(s.filename, "r") as f: print(pos, f["/trap_info/trap_locations"].shape[0]) @property def tile_locs(self) -> t.Dict[str, np.ndarray]: """Get the locations of the tiles for each position as a dictionary.""" d = {} for pos, s in self.positions.items(): with h5py.File(s.filename, "r") as f: d[pos] = f["/trap_info/trap_locations"][()] return d def no_cells( self, path="extraction/general/None/area", mode="retained", **kwargs, ) -> t.Dict[str, int]: """Get number of cells retained per position in base channel as a dictionary.""" return ( self.concat_signal(path=path, mode=mode, **kwargs) .groupby("group") .apply(len) .to_dict() ) @property def no_retained(self) -> t.Dict[str, int]: """Get number of cells retained per position in base channel as a dictionary.""" return self.no_cells() @property def channels(self): """Get channels available over all positions as a set.""" return set( [ channel for position in self.positions.values() for channel in position.channels ] ) @property def no_members(self) -> t.Dict[str, int]: """Get the number of positions belonging to each group.""" return Counter(self.positions_groups.values()) @property def no_tiles_by_group(self) -> t.Dict[str, int]: """Get total number of tiles per group.""" no_tiles = {} for pos, s in self.positions.items(): with h5py.File(s.filename, "r") as f: no_tiles[pos] = f["/trap_info/trap_locations"].shape[0] no_tiles_by_group = {k: 0 for k in self.groups} for posname, vals in no_tiles.items(): no_tiles_by_group[self.positions_groups[posname]] += vals return no_tiles_by_group @property def groups(self) -> t.Tuple[str]: """Get groups, sorted alphabetically, as a list.""" return list(sorted(set(self.positions_groups.values()))) def concat_one_signal( path: str, position: Signal, group: str, mode: str = "retained", position_name=None, tmax_in_mins_dict=None, cutoff: float = 0, **kwargs, ) -> pd.DataFrame: """ Retrieve a signal for one position. kwargs passed to signal.get_raw. """ if tmax_in_mins_dict and position_name in tmax_in_mins_dict: tmax_in_mins = tmax_in_mins_dict[position_name] else: tmax_in_mins = None if position_name is None: # name of h5 file position_name = position.stem if tmax_in_mins: print( f" Loading {path} for {position_name} up to time {tmax_in_mins}." ) else: print(f" Loading {path} for {position_name}.") if mode == "retained": # applies picking and merging via Signal.get combined = position.retained( path, tmax_in_mins=tmax_in_mins, cutoff=cutoff ) elif mode == "raw": # no picking and merging combined = position.get_raw(path, tmax_in_mins=tmax_in_mins, **kwargs) elif mode == "raw_daughters": combined = position.get_raw( path, lineage=True, tmax_in_mins=tmax_in_mins, **kwargs ) if combined is not None: combined = combined.loc[ combined.index.get_level_values("mother_label") > 0 ] elif mode == "raw_mothers": combined = position.get_raw( path, lineage=True, tmax_in_mins=tmax_in_mins, **kwargs ) if combined is not None: combined = combined.loc[ combined.index.get_level_values("mother_label") == 0 ] combined = combined.droplevel("mother_label") elif mode == "families": # applies picking and merging combined = position.get(path, tmax_in_mins=tmax_in_mins) else: raise Exception(f"concat_one_signal: {mode} not recognised.") if combined is not None: # add position and group as indices combined["position"] = position_name combined["group"] = group combined.set_index(["group", "position"], inplace=True, append=True) combined.index = combined.index.swaplevel(-2, 0).swaplevel(-1, 1) return combined