diff --git a/src/agora/abc.py b/src/agora/abc.py
index bf54d8a23baad3fd4b742ebc24676c0c711dff3b..6bdb93b854eb7062e360da042b3cc686a0d88f82 100644
--- a/src/agora/abc.py
+++ b/src/agora/abc.py
@@ -17,16 +17,14 @@ atomic = t.Union[int, float, str, bool]
 class ParametersABC(ABC):
     """
-    Defines parameters as attributes and allows parameters to
+    Define parameters as attributes and allow parameters to
     be converted to either a dictionary or to yaml.

     No attribute should be called "parameters"!
     """

     def __init__(self, **kwargs):
-        """
-        Defines parameters as attributes
-        """
+        """Define parameters as attributes."""
         assert (
             "parameters" not in kwargs
         ), "No attribute should be named parameters"
@@ -35,8 +33,9 @@ class ParametersABC(ABC):

     def to_dict(self, iterable="null") -> t.Dict:
         """
-        Recursive function to return a nested dictionary of the
-        attributes of the class instance.
+        Return a nested dictionary of the attributes of the class instance.
+
+        Uses recursion.
         """
         if isinstance(iterable, dict):
             if any(
@@ -62,7 +61,8 @@ class ParametersABC(ABC):

     def to_yaml(self, path: Union[Path, str] = None):
         """
-        Returns a yaml stream of the attributes of the class instance.
+        Return a yaml stream of the attributes of the class instance.
+
         If path is provided, the yaml stream is saved there.

         Parameters
@@ -81,9 +81,7 @@ class ParametersABC(ABC):

     @classmethod
     def from_yaml(cls, source: Union[Path, str]):
-        """
-        Returns instance from a yaml filename or stdin
-        """
+        """Return instance from a yaml filename or stdin."""
         is_buffer = True
         try:
             if Path(source).exists():
@@ -107,7 +105,8 @@ class ParametersABC(ABC):

     def update(self, name: str, new_value):
         """
-        Update values recursively
+        Update values recursively.
+
         if name is a dictionary, replace data where existing found
         or add if not. It warns against type changes.
@@ -179,7 +178,8 @@ def add_to_collection(
 class ProcessABC(ABC):
     """
     Base class for processes.
-    Defines parameters as attributes and requires run method to be defined.
+
+    Define parameters as attributes and require a run method.
     """

     def __init__(self, parameters):
@@ -243,11 +243,9 @@ class StepABC(ProcessABC):

     @timer
     def run_tp(self, tp: int, **kwargs):
-        """
-        Time and log the timing of a step.
-        """
+        """Time and log the timing of a step."""
         return self._run_tp(tp, **kwargs)

     def run(self):
         # Replace run with run_tp
-        raise Warning("Steps use run_tp instead of run")
+        raise Warning("Steps use run_tp instead of run.")
diff --git a/src/agora/io/cells.py b/src/agora/io/cells.py
index 8dd1389eb93c200b482ffe2841e3871ed78154c2..9735d4f8cc4601c60c06e6227d4f3f241e717877 100644
--- a/src/agora/io/cells.py
+++ b/src/agora/io/cells.py
@@ -14,18 +14,23 @@ from utils_find_1st import cmp_equal, find_1st

 class Cells:
     """
-    Extracts information from an h5 file. This class accesses:
+    Extract information from an h5 file.
+
+    This class accesses, within the h5 file:

     'cell_info', which contains 'angles', 'cell_label', 'centres',
     'edgemasks', 'ellipse_dims', 'mother_assign', 'mother_assign_dynamic',
     'radii', 'timepoint', 'trap'. All of these except for 'edgemasks'
     are a 1D ndarray.
+    and
+
+    'trap_info', which contains 'drifts', 'trap_locations'
     """

     def __init__(self, filename, path="cell_info"):
+        """Initialise from a filename."""
         self.filename: t.Optional[t.Union[str, Path]] = filename
         self.cinfo_path: t.Optional[str] = path
         self._edgemasks: t.Optional[str] = None
@@ -36,7 +41,7 @@ class Cells:
         return cls(Path(source))

     def _log(self, message: str, level: str = "warn"):
-        # Log messages in the corresponding level
+        """Log messages in the corresponding level."""
         logger = logging.getLogger("aliby")
         getattr(logger, level)(f"{self.__class__.__name__}: {message}")
@@ -48,42 +53,46 @@ class Cells:

     @staticmethod
     def _astype(array: np.ndarray, kind: str):
-        # Convert sparse arrays if needed and if kind is 'mask' it fills the outline
+        """Convert sparse arrays if needed; if kind is 'mask' fill the outline."""
         array = Cells._asdense(array)
         if kind == "mask":
             array = ndimage.binary_fill_holes(array).astype(bool)
         return array

     def _get_idx(self, cell_id: int, trap_id: int):
-        # returns boolean array of time points where both the cell with cell_id and the trap with trap_id exist
+        """Return boolean array of time points where both the cell with cell_id and the trap with trap_id exist."""
         return (self["cell_label"] == cell_id) & (self["trap"] == trap_id)

     @property
     def max_labels(self) -> t.List[int]:
+        """Return the maximum cell label per tile."""
         return [max((0, *self.labels_in_trap(i))) for i in range(self.ntraps)]

     @property
     def max_label(self) -> int:
+        """Return the sum of the maximum cell labels over all tiles."""
         return sum(self.max_labels)

     @property
     def ntraps(self) -> int:
-        # find the number of traps from the h5 file
+        """Find the number of tiles, or traps."""
         with h5py.File(self.filename, mode="r") as f:
             return len(f["trap_info/trap_locations"][()])

     @property
     def tinterval(self):
+        """Return time interval in seconds."""
         with h5py.File(self.filename, mode="r") as f:
             return f.attrs["time_settings/timeinterval"]

     @property
     def traps(self) -> t.List[int]:
-        # returns a list of traps
+        """List tile, or trap, IDs."""
         return list(set(self["trap"]))

     @property
     def tile_size(self) -> t.Union[int, t.Tuple[int], None]:
+        """Give the x- and y- sizes of a tile."""
         if self._tile_size is None:
             with h5py.File(self.filename, mode="r") as f:
                 # self._tile_size = f["trap_info/tile_size"][0]
@@ -91,12 +100,12 @@ class Cells:
         return self._tile_size

     def nonempty_tp_in_trap(self, trap_id: int) -> set:
-        # given a trap_id returns time points in which cells are available
+        """Given a trap_id, return time points for which cells are available."""
         return set(self["timepoint"][self["trap"] == trap_id])

     @property
     def edgemasks(self) -> t.List[np.ndarray]:
-        # returns the masks per tile
+        """Return a 3D array of masks for every cell."""
         if self._edgemasks is None:
             edgem_path: str = "edgemasks"
             self._edgemasks = self._fetch(edgem_path)
@@ -105,14 +114,15 @@ class Cells:

     @property
     def labels(self) -> t.List[t.List[int]]:
         """
-        Return all cell labels in object
-        We use mother_assign to list traps because it is the only property that appears even
-        when no cells are found
+        Return all cell labels per tile.
+
+        We use mother_assign to list tiles because it is the only property
+        that appears even when no cells are found.
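+
+        For example (hypothetical values), [[1, 2], [1], []] describes a
+        position with two cells in tile 0, one cell in tile 1, and none
+        in tile 2.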
""" return [self.labels_in_trap(trap) for trap in range(self.ntraps)] def max_labels_in_frame(self, frame: int) -> t.List[int]: - # Return the maximum label for each trap in the given frame + """Get the maximal cell label for each tile.""" max_labels = [ self["cell_label"][ (self["timepoint"] <= frame) & (self["trap"] == trap_id) @@ -125,16 +135,16 @@ class Cells: """ Parameters ---------- - cell_id: int - Cell index - trap_id: int - Trap index + cell_id: int + Cell index + trap_id: int + Trap index Returns ---------- - indices int array - boolean mask array - edge_ix int array + indices: int array + boolean mask array + edge_ix int array """ indices = self._get_idx(cell_id, trap_id) edgem_ix = self._edgem_where(cell_id, trap_id) @@ -146,7 +156,7 @@ class Cells: def mask(self, cell_id, trap_id): """ - Returns the times and the binary masks of a given cell in a given tile. + Return the times and the binary masks of a given cell in a given tile. Parameters ---------- @@ -170,7 +180,7 @@ class Cells: self, timepoint: t.Iterable[int], kind="mask" ) -> t.List[t.List[np.ndarray]]: """ - Returns a list of lists of binary masks in a given list of time points. + Return a list of lists of binary masks in a given list of time points. Parameters ---------- @@ -183,9 +193,7 @@ class Cells: ------- List[List[np.ndarray]] A list of lists with binary masks grouped by tile IDs. - """ - ix = self["timepoint"] == timepoint traps = self["trap"][ix] edgemasks = self._edgem_from_masking(ix) @@ -200,7 +208,7 @@ class Cells: self, timepoints: t.Iterable[int], kind="mask" ) -> t.List[t.List[np.ndarray]]: """ - Returns a list of lists of binary masks for a given list of time points. + Return a list of lists of binary masks for a given list of time points. Parameters ---------- @@ -226,28 +234,25 @@ class Cells: def group_by_traps( self, traps: t.Collection, cell_labels: t.Collection ) -> t.Dict[int, t.List[int]]: - """ - Returns a dict with traps as keys and list of labels as value. - Note that the total number of traps are calculated from Cells.traps. 
-
-        """
+        """Return a dict with traps as keys and lists of labels as values."""
         iterator = groupby(zip(traps, cell_labels), lambda x: x[0])
         d = {key: [x[1] for x in group] for key, group in iterator}
         d = {i: d.get(i, []) for i in self.traps}
         return d

     def labels_in_trap(self, trap_id: int) -> t.Set[int]:
-        # return set of cell ids for a given trap
+        """Return set of cell ids for a given trap."""
         return set((self["cell_label"][self["trap"] == trap_id]))

     def labels_at_time(self, timepoint: int) -> t.Dict[int, t.List[int]]:
+        """Return cell labels for each tile at the specified time point."""
         labels = self["cell_label"][self["timepoint"] == timepoint]
         traps = self["trap"][self["timepoint"] == timepoint]
         return self.group_by_traps(traps, labels)

     def __getitem__(self, item):
+        """Define and return item as an underscored attribute."""
         assert item != "edgemasks", "Edgemasks must not be loaded as a whole"
-
         _item = "_" + item
         if not hasattr(self, _item):
             setattr(self, _item, self._fetch(item))
@@ -269,23 +274,23 @@ class Cells:
         return edgem

     def outline(self, cell_id: int, trap_id: int):
+        """Get times and masks for when cell_id is in trap_id."""
         id_mask = self._get_idx(cell_id, trap_id)
         times = self["timepoint"][id_mask]
-
         return times, self._edgem_from_masking(id_mask)

     @property
     def ntimepoints(self) -> int:
+        """Total number of time points in the experiment."""
         return self["timepoint"].max() + 1

     @cached_property
     def _cells_vs_tps(self):
-        # Binary matrix showing the presence of all cells in all time points
+        """Binary matrix showing all cells in all time points."""
         ncells_per_tile = [len(x) for x in self.labels]
         cells_vs_tps = np.zeros(
             (sum(ncells_per_tile), self.ntimepoints), dtype=bool
         )
-
         cells_vs_tps[
             self._cell_cumsum[self["trap"]] + self["cell_label"] - 1,
             self["timepoint"],
@@ -294,17 +299,16 @@ class Cells:

     @cached_property
     def _cell_cumsum(self):
-        # Cumulative sum indicating the number of cells per tile
+        """Cumulative sum indicating the number of cells per tile."""
         ncells_per_tile = [len(x) for x in self.labels]
         cumsum = np.roll(np.cumsum(ncells_per_tile), shift=1)
         cumsum[0] = 0
         return cumsum

     def _flat_index_to_tuple_location(self, idx: int) -> t.Tuple[int, int]:
-        # Convert a cell index to a tuple
-        # Note that it assumes tiles and cell labels are flattened, but
-        # it is agnostic to tps
-
+        """Convert a cell index to a tuple."""
+        # Note that we assume tiles and cell labels are flattened, but
+        # are agnostic to tps.
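+        # An illustrative example (hypothetical values): with tiles
+        # holding 2, 3, and 1 cells, _cell_cumsum is [0, 2, 5], so a
+        # flat idx of 3 maps to tile_id 1 and cell_label 2.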
         tile_id = int(np.where(idx + 1 > self._cell_cumsum)[0][-1])
         cell_label = idx - self._cell_cumsum[tile_id] + 1
         return tile_id, cell_label
@@ -328,17 +332,13 @@ class Cells:
         window = sliding_window_view(
             self._cells_vs_tps, min_consecutive_tps, axis=1
         )
-
         tp_min = window.sum(axis=-1) == min_consecutive_tps
-
-        # Apply an interval filter to focucs on a slice
+        # apply an interval filter to focus on a slice
         if interval is not None:
             interval = tuple(np.array(interval))
         else:
             interval = (0, window.shape[1])
-
         low_boundary, high_boundary = interval
-
         tp_min[:, :low_boundary] = False
         tp_min[:, high_boundary:] = False
         return tp_min
@@ -349,9 +349,7 @@ class Cells:

     @cached_property
     def mothers(self):
-        """
-        Return nested list with final prediction of mother id for each cell
-        """
+        """Return nested list with final prediction of mother id for each cell in each tile."""
         return self.mother_assign_from_dynamic(
             self["mother_assign_dynamic"],
             self["cell_label"],
@@ -362,20 +360,17 @@ class Cells:

     @cached_property
     def mothers_daughters(self) -> np.ndarray:
         """
-        Return a single array with three columns, containing information about
-        the mother-daughter relationships: tile, mothers and daughters.
+        Return mother-daughter relationships for all tiles.

         Returns
         -------
-        np.ndarray
+        mothers_daughters: np.ndarray
             An array with shape (n, 3) where n is the number of mother-daughter pairs found.
-            The columns contain:
-            - tile: the tile where the mother cell is located.
-            - mothers: the index of the mother cell within the tile.
-            - daughters: the index of the daughter cell within the tile.
+            The first column is the tile_id for the tile where the mother cell is located.
+            The second column is the cell index of a mother cell in the tile.
+            The third column is the index of the corresponding daughter cell.
         """
         nested_massign = self.mothers
-
         if sum([x for y in nested_massign for x in y]):
             mothers_daughters = np.array(
                 [
@@ -389,46 +384,45 @@ class Cells:
         else:
             mothers_daughters = np.array([])
             self._log("No mother-daughters assigned")
-
         return mothers_daughters

     @staticmethod
     def mother_assign_to_mb_matrix(ma: t.List[np.array]):
         """
-        Convert from a list of lists of mother-bud paired assignments to a
-        sparse matrix with a boolean dtype. The rows correspond to
-        to daughter buds. The values are boolean and indicate whether a
-        given cell is a mother cell and a given daughter bud is assigned
-        to the mother cell in the next timepoint.
+        Convert a list of mother-daughter lists into a boolean sparse matrix.
+
+        Each row in the matrix corresponds to a daughter bud.
+        If an entry is True, a given cell is a mother cell and a given
+        daughter bud is assigned to the mother cell in the next timepoint.

         Parameters:
         -----------
         ma : list of lists of integers
-            A list of lists of mother-bud assignments. The i-th sublist contains the
-            bud assignments for the i-th tile. The integers in each sublist
-            represent the mother label, if it is zero no mother was found.
+            A list of lists of mother-bud assignments.
+            The i-th sublist contains the bud assignments for the i-th tile.
+            The integers in each sublist represent the mother label, with zero
+            implying no mother found.

         Returns:
         --------
         mb_matrix : boolean numpy array of shape (n, m)
-            An n x m boolean numpy array where n is the total number of cells (sum
-            of the lengths of all sublists in ma) and m is the maximum number of buds
-            assigned to any mother cell in ma. The value at (i, j) is True if cell i
-            is a daughter cell and cell j is its mother assigned to i.
+            An n x m array where n is the total number of cells (sum
+            of the lengths of all sublists in ma) and m is the maximum
+            number of buds assigned to any mother cell in ma.
+            The value at (i, j) is True if cell i is a daughter cell and
+            cell j is its assigned mother.

         Examples:
         --------
-        ma = [[0, 0, 1], [0, 1, 0]]
-        Cells(None).mother_assign_to_mb_matrix(ma)
-        # array([[False, False, False, False, False, False],
-        #        [False, False, False, False, False, False],
-        #        [ True, False, False, False, False, False],
-        #        [False, False, False, False, False, False],
-        #        [False, False, False, True, False, False],
-        #        [False, False, False, False, False, False]])
-
+        >>> ma = [[0, 0, 1], [0, 1, 0]]
+        >>> Cells(None).mother_assign_to_mb_matrix(ma)
+        array([[False, False, False, False, False, False],
+               [False, False, False, False, False, False],
+               [ True, False, False, False, False, False],
+               [False, False, False, False, False, False],
+               [False, False, False, True, False, False],
+               [False, False, False, False, False, False]])
         """
-
         ncells = sum([len(t) for t in ma])
         mb_matrix = np.zeros((ncells, ncells), dtype=bool)
         c = 0
@@ -436,9 +430,7 @@ class Cells:
             for d, m in enumerate(cells):
                 if m:
                     mb_matrix[c + d, c + m - 1] = True
-
             c += len(cells)
-
         return mb_matrix

     @staticmethod
@@ -466,7 +458,6 @@ class Cells:
         """
         idlist = list(zip(trap, cell_label))
         cell_gid = np.unique(idlist, axis=0)
-
         last_lin_preds = [
             find_1st(
                 ((cell_label[::-1] == lbl) & (trap[::-1] == tr)),
@@ -476,12 +467,10 @@ class Cells:
             for tr, lbl in cell_gid
         ]
         mother_assign_sorted = ma[::-1][last_lin_preds]
-
         traps = cell_gid[:, 0]
         iterator = groupby(zip(traps, mother_assign_sorted), lambda x: x[0])
         d = {key: [x[1] for x in group] for key, group in iterator}
         nested_massign = [d.get(i, []) for i in range(ntraps)]
-
         return nested_massign

     @lru_cache(maxsize=200)
     def labelled_in_frame(
         self, frame: int, global_id: bool = False
     ) -> np.ndarray:
         """
-        Returns labels in a 4D ndarray with the global ids with shape
-        (ntraps, max_nlabels, ysize, xsize) at a given frame.
+        Return labels in a 4D ndarray with potentially global ids.

         Parameters
         ----------
         frame : int
-            The frame number.
+            The frame number (time point).
         global_id : bool, optional
             If True, the returned array contains global ids, otherwise
             it contains only the local ids of the labels. Default is False.
@@ -511,7 +499,6 @@ class Cells:
         Notes
         -----
         This method uses lru_cache to cache the results for faster access.
-
         """
         labels_in_frame = self.labels_at_time(frame)
         n_labels = [
@@ -552,7 +539,7 @@ class Cells:
         self, frame: int, tile_shape: t.Tuple[int]
     ) -> t.List[np.ndarray]:
         """
-        Returns a list of stacked masks, each corresponding to a tile at a given timepoint.
+        Return a list of stacked masks, each corresponding to a tile at a given time point.

         Parameters
         ----------
@@ -582,7 +569,7 @@ class Cells:
         interval=None,
     ) -> t.Tuple[np.ndarray, np.ndarray]:
         """
-        Sample tiles that have a minimum number of cells and are occupied for at least a minimum number of consecutive timepoints.
+        Sample tiles that have a minimum number of cells and are occupied for at least a minimum number of consecutive time points.

         Parameters
         ----------
@@ -591,7 +578,7 @@ class Cells:
         min_ncells: int, optional (default=2)
             The minimum number of cells per tile.
         min_consecutive_ntps: int, optional (default=5)
-            The minimum number of consecutive timepoints a cell must be present in a trap.
+            The minimum number of consecutive time points a cell must be present in a trap.
seed: int, optional (default=0) Random seed value for reproducibility. interval: None or Tuple(int,int), optional (default=None) @@ -612,16 +599,12 @@ class Cells: min_consecutive_tps=min_consecutive_ntps, interval=interval, ) - # Find all valid tiles with min_ncells for at least min_tps index_id, tps = np.where(cell_availability_matrix) - if interval is None: # Limit search interval = (0, cell_availability_matrix.shape[1]) - np.random.seed(seed) choices = np.random.randint(len(index_id), size=size) - linear_indices = np.zeros_like(self["cell_label"], dtype=bool) for cell_index_flat, tp in zip(index_id[choices], tps[choices]): tile_id, cell_label = self._flat_index_to_tuple_location( @@ -634,7 +617,6 @@ class Cells: & (self["timepoint"] == tp) ) ] = True - return linear_indices def _sample_masks( @@ -674,25 +656,22 @@ class Cells: seed=seed, interval=interval, ) - # Sort sampled tiles to use automatic cache when possible tile_ids = self["trap"][sampled_bitmask] cell_labels = self["cell_label"][sampled_bitmask] tps = self["timepoint"][sampled_bitmask] - masks = [] for tile_id, cell_label, tp in zip(tile_ids, cell_labels, tps): local_idx = self.labels_at_time(tp)[tile_id].index(cell_label) tile_mask = self.at_time(tp)[tile_id][local_idx] masks.append(tile_mask) - return (tile_ids, cell_labels, tps), np.stack(masks) def matrix_trap_tp_where( self, min_ncells: int = 2, min_consecutive_tps: int = 5 ): """ - NOTE CURRENLTY UNUSED WITHIN ALIBY THE MOMENT. MAY BE USEFUL IN THE FUTURE. + NOTE CURRENTLY UNUSED BUT USEFUL. Return a matrix of shape (ntraps x ntps - min_consecutive_tps) to indicate traps and time-points where min_ncells are available for at least min_consecutive_tps @@ -708,7 +687,6 @@ class Cells: (ntraps x ( ntps-min_consecutive_tps )) 2D boolean numpy array where rows are trap ids and columns are timepoint windows. If the value in a cell is true its corresponding trap and timepoint contains more than min_ncells for at least min_consecutive time-points. """ - window = sliding_window_view( self._tiles_vs_cells_vs_tps, min_consecutive_tps, axis=2 ) @@ -720,7 +698,7 @@ class Cells: def stack_masks_in_tile( masks: t.List[np.ndarray], tile_shape: t.Tuple[int] ) -> np.ndarray: - # Stack all masks in a trap padding accordingly if no outlines found + """Stack all masks in a trap, padding accordingly if no outlines found.""" result = np.zeros((0, *tile_shape), dtype=bool) if len(masks): result = np.stack(masks) diff --git a/src/agora/io/metadata.py b/src/agora/io/metadata.py index 211d849620809639968fd37908e9b02ca736cc70..56f70b548c3fb7440c65e97061500927a05c489d 100644 --- a/src/agora/io/metadata.py +++ b/src/agora/io/metadata.py @@ -66,7 +66,7 @@ class MetaData: # Needed because HDF5 attributes do not support dictionaries def flatten_dict(nested_dict, separator="/"): """ - Flattens nested dictionary. If empty return as-is. + Flatten nested dictionary. If empty return as-is. 
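+
+    e.g. (an illustrative sketch of the intended behaviour)
+    {"a": {"b": 1}} becomes {"a/b": 1}.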
""" flattened = {} if nested_dict: @@ -79,9 +79,7 @@ def flatten_dict(nested_dict, separator="/"): # Needed because HDF5 attributes do not support datetime objects # Takes care of time zones & daylight saving def datetime_to_timestamp(time, locale="Europe/London"): - """ - Convert datetime object to UNIX timestamp - """ + """Convert datetime object to UNIX timestamp.""" return timezone(locale).localize(time).timestamp() @@ -189,36 +187,37 @@ def parse_swainlab_metadata(filedir: t.Union[str, Path]): Dictionary with minimal metadata """ filedir = Path(filedir) - filepath = find_file(filedir, "*.log") if filepath: + # new log files raw_parse = parse_from_swainlab_grammar(filepath) minimal_meta = get_meta_swainlab(raw_parse) else: + # old log files if filedir.is_file() or str(filedir).endswith(".zarr"): + # log file is in parent directory filedir = filedir.parent legacy_parse = parse_logfiles(filedir) minimal_meta = ( get_meta_from_legacy(legacy_parse) if legacy_parse else {} ) - return minimal_meta def dispatch_metadata_parser(filepath: t.Union[str, Path]): """ - Function to dispatch different metadata parsers that convert logfiles into a - basic metadata dictionary. Currently only contains the swainlab log parsers. + Dispatch different metadata parsers that convert logfiles into a dictionary. + + Currently only contains the swainlab log parsers. Input: -------- - filepath: str existing file containing metadata, or folder containing naming conventions + filepath: str existing file containing metadata, or folder containing naming + conventions """ parsed_meta = parse_swainlab_metadata(filepath) - if parsed_meta is None: parsed_meta = dir_to_meta - return parsed_meta diff --git a/src/agora/io/signal.py b/src/agora/io/signal.py index ae6ae71ed2ed4da9190ce2959b7f0484a97672ae..392b6c7b0db806e05f3a737887cb48add50d570f 100644 --- a/src/agora/io/signal.py +++ b/src/agora/io/signal.py @@ -60,10 +60,10 @@ class Signal(BridgeH5): def get(self, dsets: t.Union[str, t.Collection], **kwargs): """Get and potentially pre-process data from h5 file and return as a dataframe.""" - if isinstance(dsets, str): # no pre-processing - df = self.get_raw(dsets, **kwargs) + if isinstance(dsets, str): + # no pre-processing + dsets = self.get_raw(dsets, **kwargs) prepost_applied = self.apply_prepost(dsets, **kwargs) - return self.add_name(prepost_applied, dsets) @staticmethod @@ -73,7 +73,7 @@ class Signal(BridgeH5): return df def cols_in_mins(self, df: pd.DataFrame): - # Convert numerical columns in a dataframe to minutes + """Convert numerical columns in a dataframe to minutes.""" try: df.columns = (df.columns * self.tinterval // 60).astype(int) except Exception as e: @@ -141,7 +141,6 @@ class Signal(BridgeH5): if lineage_location not in f: lineage_location = "postprocessing/lineage" tile_mo_da = f[lineage_location] - if isinstance(tile_mo_da, h5py.Dataset): lineage = tile_mo_da[()] else: @@ -272,7 +271,7 @@ class Signal(BridgeH5): Parameters ---------- dataset: str or list of strs - The name of the h5 file or a list of h5 file names + The name of the h5 file or a list of h5 file names. 
         in_minutes: boolean
             If True,
         lineage: boolean
         """
@@ -288,15 +287,17 @@ class Signal(BridgeH5):
                     self.get_raw(dset, in_minutes=in_minutes, lineage=lineage)
                     for dset in dataset
                 ]
-            if lineage:  # assume that df is sorted
+            if lineage:
+                # assume that df is sorted
                 mother_label = np.zeros(len(df), dtype=int)
                 lineage = self.lineage()
-                a, b = validate_association(
+                valid_association, valid_indices = validate_association(
                     lineage,
                     np.array(df.index.to_list()),
+                    #: are mothers not match_column=0?
                     match_column=1,
                 )
-                mother_label[b] = lineage[a, 1]
+                mother_label[valid_indices] = lineage[valid_association, 1]
                 df = add_index_levels(df, {"mother_label": mother_label})
             return df
         except Exception as e:
@@ -353,10 +354,7 @@ class Signal(BridgeH5):
         fullname: str,
         node: t.Union[h5py.Dataset, h5py.Group],
     ):
-        """
-        Store the name of a signal if it is a leaf node
-        (a group with no more groups inside) and if it starts with extraction.
-        """
+        """Store the name of a signal if it is a leaf node and if it starts with extraction."""
         if isinstance(node, h5py.Group) and np.all(
             [isinstance(x, h5py.Dataset) for x in node.values()]
         ):
diff --git a/src/agora/utils/indexing_new.py b/src/agora/utils/indexing_new.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c749088c3276d95ccb9c54b8d6b7b99ea6b4fc9
--- /dev/null
+++ b/src/agora/utils/indexing_new.py
@@ -0,0 +1,195 @@
+#!/usr/bin/env jupyter
+"""
+Utilities based on association are used to efficiently acquire indices of
+tracklets with some kind of relationship.
+This can be:
+    - Cells that are to be merged.
+    - Cells that have a lineage relationship.
+"""
+
+import numpy as np
+import typing as t
+
+
+def validate_association_new(
+    association: np.ndarray,
+    indices: np.ndarray,
+    match_column: t.Optional[int] = None,
+) -> t.Tuple[np.ndarray, np.ndarray]:
+    """
+    Identify matches between two arrays by comparing rows.
+
+    We match the mother-bud pairs in the lineage data against the cells
+    identified in the signal, keeping only those cells that are part of
+    mother-bud pairs.
+
+    We use broadcasting for speed.
+
+    Both a mother and bud in association must be in indices.
+
+    Parameters
+    ----------
+    association : np.ndarray
+        2D array of lineage associations where columns are (trap, mother, daughter)
+        or
+        a 3D array, which is an array of 2 X 2 arrays comprising [[trap_id, mother_label], [trap_id, daughter_label]].
+    indices : np.ndarray
+        A 2D array where each column is a different level, such as (trap_id, cell_label), which typically is an index of a Signal
+        dataframe. This array should not include mother_label.
+    match_column: int
+        If 0, matches indicate mothers from mother-bud pairs;
+        If 1, matches indicate daughters from mother-bud pairs;
+        If None, matches indicate either mothers or daughters in mother-bud pairs.
+
+    Returns
+    -------
+    valid_association: boolean np.ndarray
+        1D array indicating elements in association with matches.
+    valid_indices: boolean np.ndarray
+        1D array indicating elements in indices with matches.
+ + Examples + -------- + >>> import numpy as np + >>> from agora.utils.indexing import validate_association + + >>> association = np.array([ [[0, 1], [0, 3]], [[0, 1], [0, 4]], [[0, 1], [0, 6]], [[0, 4], [0, 7]] ]) + >>> indices = np.array([ [0, 1], [0, 2], [0, 3]]) + >>> print(indices.T) + + >>> valid_association, valid_indices = validate_association(association, indices) + + >>> print(valid_association) + array([ True, False, False, False]) + >>> print(valid_indices) + array([ True, False, True]) + + and + + >>> association = np.array([[[0,3], [0,1]], [[0,2], [0,4]]]) + >>> indices = np.array([[0,1], [0,2], [0,3]]) + >>> valid_association, valid_indices = validate_association(association, indices) + >>> print(valid_association) + array([ True, False]) + >>> print(valid_indices) + array([ True, False, True]) + """ + if association.ndim == 2: + # reshape into 3D array for broadcasting + # for each trap, [trap, mother, daughter] becomes + # [[trap, mother], [trap, daughter]] + association = _assoc_indices_to_3d(association) + # use broadcasting to compare association with indices + # swap trap and cell_label axes for correct broadcasting + indicesT = indices.T + # compare each of [[trap, mother], [trap, daughter]] for all traps + # in association with [trap, cell_label] for all traps in indices + valid_ndassociation = ( + association[..., np.newaxis] == indicesT[np.newaxis, ...] + ) + # find matches in association + ### + # make True comparisons have both trap_ids and cell labels matching + valid_cell_ids = valid_ndassociation.all(axis=2) + if match_column is None: + # make True comparisons match at least one row in indices + va_intermediate = valid_cell_ids.any(axis=2) + # make True comparisons have both mother and bud matching rows in indices + valid_association = va_intermediate.all(axis=1) + else: + # match_column selects mothers if 0 and daughters if 1 + # make True match at least one row in indices + valid_association = valid_cell_ids[:, match_column].any(axis=1) + # find matches in indices + ### + # make True comparisons have a validated association for both the mother and bud + # make True comparisons have both trap_ids and cell labels matching + valid_cell_ids_va = valid_ndassociation[valid_association].all(axis=2) + if match_column is None: + # make True comparisons match either a mother or a bud in association + valid_indices = valid_cell_ids_va.any(axis=1)[0] + else: + valid_indices = valid_cell_ids_va[:, match_column][0] + + # Alan's working code + # Compare existing association with available indices + # Swap trap and label axes for the association array to correctly cast + valid_ndassociation_a = association[..., None] == indices.T[None, ...] + # Broadcasting is confusing (but efficient): + # First we check the dimension across trap and cell id, to ensure both match + valid_cell_ids_a = valid_ndassociation_a.all(axis=2) + if match_column is None: + # Then we check the merge tuples to check which cases have both target and source + valid_association_a = valid_cell_ids_a.any(axis=2).all(axis=1) + + # Finally we check the dimension that crosses all indices, to ensure the pair + # is present in a valid merge event. 
+        valid_indices_a = (
+            valid_ndassociation_a[valid_association_a]
+            .all(axis=2)
+            .any(axis=(0, 1))
+        )
+    else:  # We fetch specific indices if we aim for the ones with one present
+        valid_indices_a = valid_cell_ids_a[:, match_column].any(axis=0)
+        # Valid association then becomes a boolean array, true means that there is a
+        # match (match_column) between that cell and the index
+        valid_association_a = (
+            valid_cell_ids_a[:, match_column] & valid_indices
+        ).any(axis=1)
+
+    assert np.array_equal(
+        valid_association, valid_association_a
+    ), "valid_association error"
+    assert np.array_equal(
+        valid_indices, valid_indices_a
+    ), "valid_indices error"
+
+    return valid_association, valid_indices
+
+
+def _assoc_indices_to_3d(ndarray: np.ndarray):
+    """
+    Reorganise an array of shape (N, 3) into one of shape (N, 2, 2).
+
+    Reorganise an array so that the last entry of each row is removed
+    and generates a new row. This new row retains all other entries of
+    the original row.
+
+    Example:
+    [ [0, 1, 3], [0, 1, 4] ]
+    becomes
+    [ [[0, 1], [0, 3]], [[0, 1], [0, 4]] ]
+    """
+    result = ndarray
+    if len(ndarray) and ndarray.ndim > 1:
+        if ndarray.shape[1] == 3:
+            # faster indexing for single positions
+            result = np.transpose(
+                np.hstack((ndarray[:, [0]], ndarray)).reshape(-1, 2, 2),
+                axes=[0, 2, 1],
+            )
+        else:
+            # 20% slower, but more general indexing
+            columns = np.arange(ndarray.shape[1])
+            result = np.stack(
+                (
+                    ndarray[:, np.delete(columns, -1)],
+                    ndarray[:, np.delete(columns, -2)],
+                ),
+                axis=1,
+            )
+    return result
+
+
+def _3d_index_to_2d(array: np.ndarray):
+    """Revert switch from _assoc_indices_to_3d."""
+    result = array
+    if len(array):
+        result = np.concatenate(
+            (array[:, 0, :], array[:, 1, 1, np.newaxis]), axis=1
+        )
+    return result
+
+
+def compare_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray:
+    """
+    Compare two 2D arrays using broadcasting.
+
+    Return a boolean matrix in which entry (i, j) is True when row i
+    of x equals row j of y.
+    """
+    return (x[..., np.newaxis] == y.T[np.newaxis, ...]).all(axis=1)
diff --git a/src/agora/utils/kymograph.py b/src/agora/utils/kymograph.py
index 276bb3cd207222626c60483e9932bdb4fd022e60..62a5a96238db57bae592e85cee3a164b0f47338e 100644
--- a/src/agora/utils/kymograph.py
+++ b/src/agora/utils/kymograph.py
@@ -170,6 +170,7 @@ def slices_from_spans(spans: t.Tuple[int], df: pd.DataFrame) -> t.List[slice]:

 def drop_mother_label(index: pd.MultiIndex) -> np.ndarray:
+    """Remove mother_label level from a MultiIndex."""
     no_mother_label = index
     if "mother_label" in index.names:
         no_mother_label = index.droplevel("mother_label")
diff --git a/src/agora/utils/merge.py b/src/agora/utils/merge.py
index 936699122e85cd84dbf58e3205e0e8a0c7dbb5f2..442bdf0e399e07116713355858680241628296bb 100644
--- a/src/agora/utils/merge.py
+++ b/src/agora/utils/merge.py
@@ -13,8 +13,11 @@ from agora.utils.indexing import compare_indices, validate_association


 def apply_merges(data: pd.DataFrame, merges: np.ndarray):
-    """Split data in two, one subset for rows relevant for merging and one
-    without them. It uses an array of source tracklets and target tracklets
+    """
+    Split data in two, one subset for rows relevant for merging and one
+    without them.
+
+    Use an array of source tracklets and target tracklets
     to efficiently merge them.

     Parameters
@@ -43,7 +46,7 @@ def apply_merges(data: pd.DataFrame, merges: np.ndarray):

     # Implement the merges and drop source rows.
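+    # each row of merges pairs two tracklets, identified by
+    # (trap_id, cell_label) tuples, with one tracklet merged into the
+    # other (an illustrative reading of the docstring above; see
+    # validate_association for how the pairs are matched)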
     # TODO Use matrices to perform merges in batch
-    # for ecficiency
+    # for efficiency
     if valid_merges.any():
         to_merge = data.loc[indices]
         targets, sources = zip(*merges[valid_merges])
diff --git a/src/aliby/io/dataset.py b/src/aliby/io/dataset.py
index 30f2cd6f490db4ed24048f4017d518fe252297f4..28ba59d1b9d690a073785d145b3ad4dc27535eca 100644
--- a/src/aliby/io/dataset.py
+++ b/src/aliby/io/dataset.py
@@ -54,7 +54,7 @@ class DatasetLocalABC(ABC):
     Abstract Base class to find local files, either OME-XML or raw images.
     """

-    _valid_suffixes = ("tiff", "png", "zarr")
+    _valid_suffixes = ("tiff", "png", "zarr", "tif")
     _valid_meta_suffixes = ("txt", "log")

     def __init__(self, dpath: t.Union[str, Path], *args, **kwargs):
diff --git a/src/aliby/io/image.py b/src/aliby/io/image.py
index 282c236140e3198d63fdf2833453e9b9fb540f3e..5af04ca679d92fe9d40ca33f029b9553f42a9dc2 100644
--- a/src/aliby/io/image.py
+++ b/src/aliby/io/image.py
@@ -30,14 +30,14 @@ from agora.io.metadata import dir_to_meta, dispatch_metadata_parser


 def get_examples_dir():
-    """Get examples directory which stores dummy image for tiler"""
+    """Get examples directory that stores dummy image for tiler."""
     return files("aliby").parent.parent / "examples" / "tiler"


 def instantiate_image(
     source: t.Union[str, int, t.Dict[str, str], Path], **kwargs
 ):
-    """Wrapper to instatiate the appropiate image
+    """Wrapper to instantiate the appropriate image

     Parameters
     ----------
@@ -55,26 +55,26 @@ def instantiate_image(


 def dispatch_image(source: t.Union[str, int, t.Dict[str, str], Path]):
-    """
-    Wrapper to pick the appropiate Image class depending on the source of data.
-    """
+    """Pick the appropriate Image class depending on the source of data."""
     if isinstance(source, (int, np.int64)):
         from aliby.io.omero import Image

-        instatiator = Image
+        instantiator = Image
     elif isinstance(source, dict) or (
         isinstance(source, (str, Path)) and Path(source).is_dir()
     ):
         if Path(source).suffix == ".zarr":
-            instatiator = ImageZarr
+            instantiator = ImageZarr
         else:
-            instatiator = ImageDir
+            instantiator = ImageDir
+    elif isinstance(source, Path) and source.is_file():
+        # my addition
+        instantiator = ImageLocalOME
     elif isinstance(source, str) and Path(source).is_file():
-        instatiator = ImageLocalOME
+        instantiator = ImageLocalOME
     else:
         raise Exception(f"Invalid data source at {source}")
-
-    return instatiator
+    return instantiator


 class BaseLocalImage(ABC):
@@ -82,6 +82,7 @@ class BaseLocalImage(ABC):
     Base Image class to set path and provide context management method.
     """

+    # default image order
     _default_dimorder = "tczyx"

     def __init__(self, path: t.Union[str, Path]):
@@ -98,8 +99,7 @@ class BaseLocalImage(ABC):
         return False

     def rechunk_data(self, img):
-        # Format image using x and y size from metadata.
-
+        """Format image using x and y size from metadata."""
         self._rechunked_img = da.rechunk(
             img,
             chunks=(
@@ -145,16 +145,16 @@ class ImageLocalOME(BaseLocalImage):
     in which a multidimensional tiff image contains the metadata.
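+
+    A minimal usage sketch (the file name here is hypothetical and the
+    lazy-loading call assumes the interface defined by BaseLocalImage):
+
+    >>> with ImageLocalOME("position1.ome.tiff") as img:
+    ...     data = img.get_data_lazy()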
""" - def __init__(self, path: str, dimorder=None): + def __init__(self, path: str, dimorder=None, **kwargs): super().__init__(path) self._id = str(path) + self.set_meta(str(path)) - def set_meta(self): + def set_meta(self, path): meta = dict() try: with TiffFile(path) as f: self._meta = xmltodict.parse(f.ome_metadata)["OME"] - for dim in self.dimorder: meta["size_" + dim.lower()] = int( self._meta["Image"]["Pixels"]["@Size" + dim] @@ -165,21 +165,19 @@ class ImageLocalOME(BaseLocalImage): ] meta["name"] = self._meta["Image"]["@Name"] meta["type"] = self._meta["Image"]["Pixels"]["@Type"] - - except Exception as e: # Images not in OMEXML - + except Exception as e: + # images not in OMEXML print("Warning:Metadata not found: {}".format(e)) print( - f"Warning: No dimensional info provided. Assuming {self._default_dimorder}" + "Warning: No dimensional info provided. " + f"Assuming {self._default_dimorder}" ) - - # Mark non-existent dimensions for padding + # mark non-existent dimensions for padding self.base = self._default_dimorder # self.ids = [self.index(i) for i in dimorder] - - self._dimorder = base - + self._dimorder = self.base self._meta = meta + # self._meta["name"] = Path(path).name.split(".")[0] @property def name(self): @@ -246,7 +244,7 @@ class ImageDir(BaseLocalImage): It inherits from BaseLocalImage so we only override methods that are critical. Assumptions: - - One folders per position. + - One folder per position. - Images are flat. - Channel, Time, z-stack and the others are determined by filenames. - Provides Dimorder as it is set in the filenames, or expects order during instatiation @@ -318,7 +316,7 @@ class ImageZarr(BaseLocalImage): print(f"Could not add size info to metadata: {e}") def get_data_lazy(self) -> da.Array: - """Return 5D dask array. For lazy-loading local multidimensional zarr files""" + """Return 5D dask array for lazy-loading local multidimensional zarr files.""" return self._img def add_size_to_meta(self): diff --git a/src/aliby/pipeline.py b/src/aliby/pipeline.py index dcb807cde242eb1415a5aa639ea64ee88cb1dd80..6e37bac203f9283223e633a3b8c5c157289a1694 100644 --- a/src/aliby/pipeline.py +++ b/src/aliby/pipeline.py @@ -154,6 +154,7 @@ class PipelineParameters(ParametersABC): defaults["tiler"]["backup_ref_channel"] = backup_ref_channel defaults["baby"] = BabyParameters.default(**baby).to_dict() + # why are BabyParameters here as an alternative? 
defaults["extraction"] = ( exparams_from_meta(meta_d) or BabyParameters.default(**extraction).to_dict() @@ -320,7 +321,7 @@ class Pipeline(ProcessABC): ) # get log files, either locally or via OMERO with dispatcher as conn: - image_ids = conn.get_images() + position_ids = conn.get_images() directory = self.store or root_dir / conn.unique_name if not directory.exists(): directory.mkdir(parents=True) @@ -330,29 +331,29 @@ class Pipeline(ProcessABC): self.parameters.general["directory"] = str(directory) config["general"]["directory"] = directory self.setLogger(directory) - # pick particular images if desired + # pick particular positions if desired if pos_filter is not None: if isinstance(pos_filter, list): - image_ids = { + position_ids = { k: v for filt in pos_filter - for k, v in self.apply_filter(image_ids, filt).items() + for k, v in self.apply_filter(position_ids, filt).items() } else: - image_ids = self.apply_filter(image_ids, pos_filter) - assert len(image_ids), "No images to segment" + position_ids = self.apply_filter(position_ids, pos_filter) + assert len(position_ids), "No images to segment" # create pipelines if distributed != 0: # multiple cores with Pool(distributed) as p: results = p.map( lambda x: self.run_one_position(*x), - [(k, i) for i, k in enumerate(image_ids.items())], + [(k, i) for i, k in enumerate(position_ids.items())], ) else: # single core results = [] - for k, v in tqdm(image_ids.items()): + for k, v in tqdm(position_ids.items()): r = self.run_one_position((k, v), 1) results.append(r) return results @@ -432,6 +433,7 @@ class Pipeline(ProcessABC): if process_from["extraction"] < tps: # TODO Move this parameter validation into Extractor av_channels = set((*steps["tiler"].channels, "general")) + # overwrite extraction specified by PipelineParameters !! 
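+            # presumably this keeps only the channels available to the
+            # tiler (av_channels above), plus "general"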
config["extraction"]["tree"] = { k: v for k, v in config["extraction"]["tree"].items() @@ -453,13 +455,14 @@ class Pipeline(ProcessABC): steps["extraction"] = Extractor.from_tiler( exparams, store=filename, tiler=steps["tiler"] ) - # set up progress meter + # set up progress bar pbar = tqdm( range(min_process_from, tps), desc=image.name, initial=min_process_from, total=tps, ) + # run through time points for i in pbar: if ( frac_clogged_traps @@ -469,9 +472,12 @@ class Pipeline(ProcessABC): # run through steps for step in self.pipeline_steps: if i >= process_from[step]: + # perform step result = steps[step].run_tp( i, **run_kwargs.get(step, {}) ) + # write to h5 file using writers + # extractor writes to h5 itself if step in loaded_writers: loaded_writers[step].write( data=result, @@ -481,7 +487,7 @@ class Pipeline(ProcessABC): tp=i, meta={"last_processed": i}, ) - # perform step + # clean up if ( step == "tiler" and i == min_process_from @@ -501,7 +507,7 @@ class Pipeline(ProcessABC): tp=i, ) elif step == "extraction": - # remove mask/label after extraction + # remove masks and labels after extraction for k in ["masks", "labels"]: run_kwargs[step][k] = None # check and report clogging @@ -677,12 +683,12 @@ class Pipeline(ProcessABC): for i, step in enumerate(self.step_sequence, 1) } - # Set up + # set up directory = config["general"]["directory"] - trackers_state: t.List[np.ndarray] = [] with dispatch_image(image_id)(image_id, **self.server_info) as image: filename = Path(f"{directory}/{image.name}.h5") + # load metadata meta = MetaData(directory, filename) from_start = True if np.any(ow.values()) else False # remove existing file if overwriting diff --git a/src/aliby/tile/tiler.py b/src/aliby/tile/tiler.py index b0769e1d22ef6306162b18b49af696a7ef55cdb7..27c1e814d4682228149fc0343d9a1a564cc7ad14 100644 --- a/src/aliby/tile/tiler.py +++ b/src/aliby/tile/tiler.py @@ -1,17 +1,29 @@ """ Tiler: Divides images into smaller tiles. -The tasks of the Tiler are selecting regions of interest, or tiles, of images - with one trap per tile, correcting for the drift of the microscope stage over time, and handling errors and bridging between the image data and Aliby’s image-processing steps. +The tasks of the Tiler are selecting regions of interest, or tiles, of +images - with one trap per tile, correcting for the drift of the microscope +stage over time, and handling errors and bridging between the image data +and Aliby’s image-processing steps. Tiler subclasses deal with either network connections or local files. -To find tiles, we use a two-step process: we analyse the bright-field image to produce the template of a trap, and we fit this template to the image to find the tiles' centres. +To find tiles, we use a two-step process: we analyse the bright-field image +to produce the template of a trap, and we fit this template to the image to +find the tiles' centres. -We use texture-based segmentation (entropy) to split the image into foreground -- cells and traps -- and background, which we then identify with an Otsu filter. Two methods are used to produce a template trap from these regions: pick the trap with the smallest minor axis length and average over all validated traps. +We use texture-based segmentation (entropy) to split the image into +foreground -- cells and traps -- and background, which we then identify with +an Otsu filter. Two methods are used to produce a template trap from these +regions: pick the trap with the smallest minor axis length and average over +all validated traps. 
-A peak-identifying algorithm recovers the x and y-axis location of traps in the original image, and we choose the approach to template that identifies the most tiles. +A peak-identifying algorithm recovers the x and y-axis location of traps in +the original image, and we choose the approach to template that identifies +the most tiles. -The experiment is stored as an array with a standard indexing order of (Time, Channels, Z-stack, X, Y). +The experiment is stored as an array with a standard indexing order of +(Time, Channels, Z-stack, X, Y). """ import logging import re @@ -355,9 +367,9 @@ class Tiler(StepABC): full: an array of images """ full = self.image[t, c] - if hasattr(full, "compute"): # If using dask fetch images here + if hasattr(full, "compute"): + # if using dask fetch images full = full.compute(scheduler="synchronous") - return full @property @@ -593,7 +605,10 @@ class Tiler(StepABC): def get_channel_index(self, channel: str or int) -> int or None: """ - Find index for channel using regex. Returns the first matched string. + Find index for channel using regex. + + Return the first matched string. + If self.channels is integers (no image metadata) it returns None. If channel is integer @@ -602,10 +617,8 @@ class Tiler(StepABC): channel: string or int The channel or index to be used. """ - if all(map(lambda x: isinstance(x, int), self.channels)): channel = channel if isinstance(channel, int) else None - if isinstance(channel, str): channel = find_channel_index(self.channels, channel) return channel diff --git a/src/extraction/core/extractor.py b/src/extraction/core/extractor.py index 34ce783cc0f10b951677e702e3ad38e6b1f75766..cebb0b2a1d9229174c18f1ea9a9f5cb0a80383cf 100644 --- a/src/extraction/core/extractor.py +++ b/src/extraction/core/extractor.py @@ -80,7 +80,7 @@ class Extractor(StepABC): Usually the metric is applied to only a tile's masked area, but some metrics depend on the whole tile. - Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the reduction algorithm, such as maximum projection, is the second level; the specific metric, or operation, to apply to the masks, such as mean, is the third level. + Extraction follows a three-level tree structure. Channels, such as GFP, are the root level; the reduction algorithm, such as maximum projection, is the second level; the specific metric, or operation, to apply to the masks, such as mean, is the third or leaf level. """ # TODO Alan: Move this to a location with the SwainLab defaults @@ -202,7 +202,7 @@ class Extractor(StepABC): self._custom_funs[k] = tmp(f) def load_funs(self): - """Define all functions, including custum ones.""" + """Define all functions, including custom ones.""" self.load_custom_funs() self._all_cell_funs = set(self._custom_funs.keys()).union(CELL_FUNS) # merge the two dicts @@ -335,7 +335,7 @@ class Extractor(StepABC): **kwargs, ) -> t.Dict[str, t.Dict[reduction_method, t.Dict[str, pd.Series]]]: """ - Wrapper to apply reduction and then extraction. + Wrapper to reduce to a 2D image and then extract. Parameters ---------- @@ -499,7 +499,6 @@ class Extractor(StepABC): # calculate metrics with subtracted bg ch_bs = ch + "_bgsub" # subtract median background - self.img_bgsub[ch_bs] = np.moveaxis( np.stack( list( @@ -579,7 +578,9 @@ class Extractor(StepABC): **kwargs, ) -> dict: """ - Wrapper to add compatibility with other steps of the pipeline. + Run extraction for one position and for the specified time points. + + Save the results to a h5 file. 
         Parameters
         ----------
@@ -597,7 +598,7 @@ class Extractor(StepABC):
         Returns
         -------
         d: dict
-            A dict of the extracted data with a concatenated string of channel, reduction metric, and cell metric as keys and pd.Series of the extracted data as values.
+            A dict of the extracted data for one position with a concatenated string of channel, reduction metric, and cell metric as keys and pd.DataFrame of the extracted data for all time points as values.
         """
         if tree is None:
             tree = self.params.tree
@@ -633,7 +634,7 @@ class Extractor(StepABC):

     def save_to_hdf(self, dict_series, path=None):
         """
-        Save the extracted data to the h5 file.
+        Save the extracted data for one position to the h5 file.

         Parameters
         ----------
diff --git a/src/extraction/core/functions/defaults.py b/src/extraction/core/functions/defaults.py
index e159842bad00ac61b22c177e1215319c2d9b460a..d4741ca46495488ca31a7c0ddb46405d0e25972a 100644
--- a/src/extraction/core/functions/defaults.py
+++ b/src/extraction/core/functions/defaults.py
@@ -1,10 +1,9 @@
 # File with defaults for ease of use
-import re
 import typing as t
 from pathlib import Path
+
 import h5py

-# should we move these functions here?
 from aliby.tile.tiler import find_channel_name
@@ -59,6 +58,7 @@ def exparams_from_meta(
     for ch in extant_fluorescence_ch:
         base["tree"][ch] = default_reduction_metrics
     base["sub_bg"] = extant_fluorescence_ch
+    # additional extraction defaults if the channels are available
     if "ph" in extras:
         # SWAINLAB specific names
diff --git a/src/postprocessor/chainer.py b/src/postprocessor/chainer.py
index b834fb5d831eeeff75f84d18e5c5eb02775652f4..027e02ee01bdc04d97c530d86512d19b630e6333 100644
--- a/src/postprocessor/chainer.py
+++ b/src/postprocessor/chainer.py
@@ -14,9 +14,10 @@ from postprocessor.core.abc import get_process
 class Chainer(Signal):
     """
     Extend Signal by applying post-processes and allowing composite signals that combine basic signals.
-    It "chains" multiple processes upon fetching a dataset to produce the desired datasets.

-    Instead of reading processes previously applied, it executes
+    Chainer "chains" multiple processes upon fetching a dataset.
+
+    Instead of reading processes previously applied, Chainer executes
     them when called.
     """

     = {
     }

     def __init__(self, *args, **kwargs):
+        """Initialise chainer."""
         super().__init__(*args, **kwargs)

     def replace_path(path: str, bgsub: bool = ""):
@@ -34,7 +36,7 @@ class Chainer(Signal):
             path = re.sub(channel, f"{channel}{suffix}", path)
             return path

-        # Add chain with and without bgsub for composite statistics
+        # add chain with and without bgsub for composite statistics
         self.common_chains = {
             alias
             + bgsub: lambda **kwargs: self.get(
diff --git a/src/postprocessor/core/lineageprocess.py b/src/postprocessor/core/lineageprocess.py
index a3d1f91918f253629e73bef218d3292cda50eafc..c359df8950806b5e56c0b3530923eb2585d2a805 100644
--- a/src/postprocessor/core/lineageprocess.py
+++ b/src/postprocessor/core/lineageprocess.py
@@ -12,19 +12,20 @@ from postprocessor.core.abc import PostProcessABC


 class LineageProcessParameters(ParametersABC):
-    """
-    Parameters
-    """
+    """Parameters - none are necessary."""

     _defaults = {}


 class LineageProcess(PostProcessABC):
     """
-    Lineage process that must be passed a (N,3) lineage matrix (where the columns are trap, mother, daughter respectively)
+    To analyse lineage data.
+
+    Currently bare bones, but extracts lineage information, an (N, 3)
+    matrix whose columns are trap, mother, and daughter, from a Signal
+    or Cells object.
     """

     def __init__(self, parameters: LineageProcessParameters):
+        """Initialise using PostProcessABC."""
         super().__init__(parameters)

     @abstractmethod
@@ -34,6 +35,7 @@ class LineageProcess(PostProcessABC):
         lineage: np.ndarray,
         *args,
     ):
+        """Implement method required by PostProcessABC - undefined."""
         pass

     @classmethod
@@ -45,8 +47,9 @@ class LineageProcess(PostProcessABC):
         **kwargs,
     ):
         """
-        Overrides PostProcess.as_function classmethod.
-        Lineage functions require lineage information to be passed if run as function.
+        Override PostProcessABC.as_function method.
+
+        Lineage processes require lineage information to be passed when run as functions.
         """
         parameters = cls.default_parameters(**kwargs)
         return cls(parameters=parameters).run(
@@ -54,8 +57,9 @@ class LineageProcess(PostProcessABC):
         )

     def get_lineage_information(self, signal=None, merged=True):
-
+        """Get lineage as an array with tile IDs, mother labels, and corresponding bud labels."""
         if signal is not None and "mother_label" in signal.index.names:
+            # from kymograph
             lineage = get_index_as_np(signal)
         elif hasattr(self, "lineage"):
             lineage = self.lineage
@@ -68,5 +72,5 @@ class LineageProcess(PostProcessABC):
         elif self.cells is not None:
             lineage = self.cells.mothers_daughters
         else:
-            raise Exception("No linage information found")
+            raise Exception("No lineage information found")
         return lineage
diff --git a/src/postprocessor/core/processor.py b/src/postprocessor/core/processor.py
index 72bbf3f72e0c94c56c0eaafae715a9d116eedf24..76be13cd9f7e6e339e2a5d02570ac5c21cccc8f8 100644
--- a/src/postprocessor/core/processor.py
+++ b/src/postprocessor/core/processor.py
@@ -1,3 +1,4 @@
+# change "prepost" to "preprocess"; change filename to postprocessor_engine.py ??
 import typing as t
 from itertools import takewhile

@@ -61,34 +62,22 @@ class PostProcessorParameters(ParametersABC):
         kind: list of str
             If "ph_batman" included, add targets for experiments using pHlourin.
         """
-        # each subitem specifies the function to be called and the location
-        # on the h5 file to be written
+        # each subitem specifies the function to be called
+        # and the h5-file location for the results
+        #: why does merger have a string and picker a list?
         targets = {
             "prepost": {
                 "merger": "/extraction/general/None/area",
                 "picker": ["/extraction/general/None/area"],
             },
             "processes": [
-                [
-                    "buddings",
-                    ["/extraction/general/None/volume"],
-                ],
-                [
-                    "dsignal",
-                    [
-                        "/extraction/general/None/volume",
-                    ],
-                ],
-                [
-                    "bud_metric",
-                    [
-                        "/extraction/general/None/volume",
-                    ],
-                ],
+                ["buddings", ["/extraction/general/None/volume"]],
+                ["dsignal", ["/extraction/general/None/volume"]],
+                ["bud_metric", ["/extraction/general/None/volume"]],
                 [
                     "dsignal",
                     [
-                        "/postprocessing/bud_metric/extraction_general_None_volume",
+                        "/postprocessing/bud_metric/extraction_general_None_volume"
                     ],
                 ],
             ],
@@ -129,7 +118,7 @@ class PostProcessorParameters(ParametersABC):
 class PostProcessor(ProcessABC):
     def __init__(self, filename, parameters):
         """
-        Initialise PostProcessor
+        Initialise PostProcessor.

         Parameters
         ----------
@@ -172,31 +161,32 @@ class PostProcessor(ProcessABC):
         self.targets = parameters["targets"]

     def run_prepost(self):
-        """Using picker, get and write lineages, returning mothers and daughters."""
-        """Important processes run before normal post-processing ones"""
+        """
+        Run picker and merger and get lineages.
+
+        Necessary before any processes can run.
+
+        Writes merges to "modifiers/merges" and the merged lineage to
+        "modifiers/lineage_merged" in the h5 file.
+        """
+        # run merger
         record = self._signal.get_raw(self.targets["prepost"]["merger"])
         merges = np.array(self.merger.run(record), dtype=int)
-
         self._writer.write(
             "modifiers/merges", data=[np.array(x) for x in merges]
         )
-
+        # get lineages from picker
         lineage = _assoc_indices_to_3d(self.picker.cells.mothers_daughters)
         lineage_merged = []
-
-        if merges.any():  # Update lineages after merge events
-
+        if merges.any():
+            # update lineages after merge events
             merged_indices = merge_association(lineage, merges)
-            # Remove repeated labels post-merging
+            # remove repeated labels post-merging
             lineage_merged = np.unique(merged_indices, axis=0)
-
         self.lineage = _3d_index_to_2d(
             lineage_merged if len(lineage_merged) else lineage
         )
         self._writer.write(
             "modifiers/lineage_merged", _3d_index_to_2d(lineage_merged)
         )
-
+        # run picker
         picked_indices = self.picker.run(
             self._signal[self.targets["prepost"]["picker"][0]]
         )
@@ -211,26 +201,17 @@ class PostProcessor(ProcessABC):
             overwrite="overwrite",
         )

-    @staticmethod
-    def pick_mother(a, b):
-        """Update the mother id following this priorities:
-        - The mother has a lower id
-        """
-        x = max(a, b)
-        if min([a, b]):
-            x = [a, b][np.argmin([a, b])]
-        return x
-
     def run(self):
         """
         Write the results to the h5 file.
+
         Processes include identifying buddings and finding bud metrics.
         """
         # run merger, picker, and find lineages
         self.run_prepost()
         # run processes
         for process, datasets in tqdm(self.targets["processes"]):
+            # process is a str; datasets is a list of str
             if process in self.parameters["param_sets"].get("processes", {}):
                 # parameters already assigned
                 parameters = self.parameters_classfun[process](
@@ -243,13 +224,12 @@ class PostProcessor(ProcessABC):
             loaded_process = self.classfun[process](parameters)
             if isinstance(parameters, LineageProcessParameters):
                 loaded_process.lineage = self.lineage
-
             # apply process to each dataset
             for dataset in datasets:
                 self.run_process(dataset, process, loaded_process)

     def run_process(self, dataset, process, loaded_process):
-        """Run process on a single dataset and write the result."""
+        """Run process to obtain a single dataset and write the result."""
         # define signal
         if isinstance(dataset, list):
             # multisignal process
@@ -269,7 +249,7 @@ class PostProcessor(ProcessABC):
                 [], columns=signal.columns, index=signal.index
             )
             result.columns.names = ["timepoint"]
-        # define outpath, where result will be written
+        # define outpath to write result
         if process in self.parameters["outpaths"]:
             outpath = self.parameters["outpaths"][process]
         elif isinstance(dataset, list):
@@ -318,3 +298,15 @@ class PostProcessor(ProcessABC):
         metadata: t.Dict,
     ):
         self._writer.write(path, result, meta=metadata, overwrite="overwrite")
+
+    @staticmethod
+    def pick_mother(a, b):
+        """
+        Update the mother id following these priorities:
+
+        The mother has a lower id
+        """
+        x = max(a, b)
+        if min([a, b]):
+            x = [a, b][np.argmin([a, b])]
+        return x
diff --git a/src/postprocessor/core/reshapers/bud_metric.py b/src/postprocessor/core/reshapers/bud_metric.py
index b06bc5b2af3afee727fda2429d42e72b7e5a974e..b8952288c237dd4420acfc25921a95e0afcc8ad1 100644
--- a/src/postprocessor/core/reshapers/bud_metric.py
+++ b/src/postprocessor/core/reshapers/bud_metric.py
@@ -1,5 +1,4 @@
 import typing as t
-from typing import Dict, Tuple

 import numpy as np
 import pandas as pd
@@ -31,7 +30,7 @@ class BudMetric(LineageProcess):
     def run(
         self,
         signal: pd.DataFrame,
-        lineage: Dict[pd.Index, Tuple[pd.Index]] = None,
+        lineage: t.Dict[pd.Index, t.Tuple[pd.Index]] = None,
     ):
         if lineage is None:
hasattr(self, "lineage"): @@ -44,12 +43,12 @@ class BudMetric(LineageProcess): @staticmethod def get_bud_metric( - signal: pd.DataFrame, md: Dict[Tuple, Tuple[Tuple]] = None + signal: pd.DataFrame, md: t.Dict[t.Tuple, t.Tuple[t.Tuple]] = None ): """ signal: Daughter-inclusive dataframe - md: Mother-daughters dictionary where key is mother's index and value a list of daugher indices + md: Mother-daughters dictionary where key is mother's index and its values are a list of daughter indices Get fvi (First Valid Index) for all cells Create empty matrix diff --git a/src/postprocessor/core/reshapers/buddings.py b/src/postprocessor/core/reshapers/buddings.py index acdf0165f4554a982facc3f3fbd165dd3d3db107..4398b9478c62a896322c7ce8fd39ee6500bc6641 100644 --- a/src/postprocessor/core/reshapers/buddings.py +++ b/src/postprocessor/core/reshapers/buddings.py @@ -13,16 +13,10 @@ from postprocessor.core.lineageprocess import ( class buddingsParameters(LineageProcessParameters): - """Parameter class to obtain budding events. - - Parameters - ---------- - LineageProcessParameters : lineage_location - Location of lineage matrix to be used for calculations. + """ + Parameter class to obtain budding events. - Examples - -------- - FIXME: Add docs. + Define the location of lineage information in the h5 file. """ @@ -31,45 +25,44 @@ class buddingsParameters(LineageProcessParameters): class buddings(LineageProcess): """ - Calculate buddings in a trap assuming one mother per trap - returns a pandas series with the buddings. + Calculate buddings in a trap assuming one mother per trap. + + Return a pandas series with the buddings. - We define a budding event as the moment in which a bud was identified for - the first time, even if the bud is not considered one until later - in the experiment. + We define a budding event as when a bud is first identified. + + This bud may not be considered a bud until later in the experiment. 
""" def __init__(self, parameters: buddingsParameters): + """Initialise buddings.""" super().__init__(parameters) def run( self, signal: pd.DataFrame, lineage: np.ndarray = None ) -> pd.DataFrame: + """TODO.""" lineage = lineage or self.lineage - - # Get time of first appearance for all cells - fvi = signal.apply(lambda x: x.first_valid_index(), axis=1) - - # Select mother cells in a given dataset + # select traps and mother cells in a given signal traps_mothers: t.Dict[tuple, list] = { tuple(mo): [] for mo in lineage[:, :2] if tuple(mo) in signal.index } for trap, mother, daughter in lineage: if (trap, mother) in traps_mothers.keys(): traps_mothers[(trap, mother)].append(daughter) - mothers = signal.loc[ set(signal.index).intersection(traps_mothers.keys()) ] - # Create a new dataframe with dimensions (n_mother_cells * n_timepoints) + # create a new dataframe with dimensions (n_mother_cells * n_timepoints) buddings = pd.DataFrame( np.zeros((mothers.shape[0], signal.shape[1])).astype(bool), index=mothers.index, columns=signal.columns, ) buddings.columns.names = ["timepoint"] - - # Fill the budding events + # get time of first appearance for every cell using Pandas + fvi = signal.apply(lambda x: x.first_valid_index(), axis=1) + # fill the budding events for mother_id, daughters in traps_mothers.items(): daughters_idx = set( fvi.loc[ @@ -82,5 +75,4 @@ class buddings(LineageProcess): mother_id, daughters_idx, ] = True - return buddings diff --git a/src/postprocessor/core/reshapers/picker.py b/src/postprocessor/core/reshapers/picker.py index 1796fe653c32be9db384f5e5f629a479cc04985b..f3c35852ede1d3f3b54bed4692257ba9aeb90d24 100644 --- a/src/postprocessor/core/reshapers/picker.py +++ b/src/postprocessor/core/reshapers/picker.py @@ -12,6 +12,14 @@ from postprocessor.core.lineageprocess import LineageProcess class PickerParameters(ParametersABC): + """ + A dictionary specifying the sequence of picks in order. + + "lineage" is further specified by "mothers", "daughters", "families" (mother-bud pairs), and "orphans", where orphans picks cells that are not in families. + + "condition" is further specified by "present", "continuously_present", "any_present", or "growing" and a threshold, either a number of time points or a fraction of the total duration of the experiment. + """ + _defaults = { "sequence": [ ["lineage", "families"], @@ -22,11 +30,8 @@ class PickerParameters(ParametersABC): class Picker(LineageProcess): """ - :cells: Cell object passed to the constructor - :condition: Tuple with condition and associated parameter(s), conditions can be - "present", "nonstoply_present" or "quantile". - Determine the thresholds or fractions of signals to use. - :lineage: str {"mothers", "daughters", "families" (mothers AND daughters), "orphans"}. Mothers/daughters picks cells with those tags, families pick the union of both and orphans the difference between the total and families. + Picker selects cells from a signal using lineage information and + by how and for how long they are retained in the data set. 
""" def __init__( @@ -34,6 +39,7 @@ class Picker(LineageProcess): parameters: PickerParameters, cells: Cells or None = None, ): + """Initialise picker.""" super().__init__(parameters=parameters) self.cells = cells @@ -43,9 +49,10 @@ class Picker(LineageProcess): how: str, mothers_daughters: t.Optional[np.ndarray] = None, ) -> pd.MultiIndex: + """Return rows of a signal corresponding to either mothers, daughters, or mother-daughter pairs using lineage information.""" cells_present = drop_mother_label(signal.index) mothers_daughters = self.get_lineage_information(signal) - valid_indices = slice(None) + #: might be better if match_column defined as a string to make everything one line if how == "mothers": _, valid_indices = validate_association( mothers_daughters, cells_present, match_column=0 @@ -54,79 +61,96 @@ class Picker(LineageProcess): _, valid_indices = validate_association( mothers_daughters, cells_present, match_column=1 ) - elif how == "families": # Mothers and daughters that are still present + elif how == "families": + # mother-daughter pairs _, valid_indices = validate_association( mothers_daughters, cells_present ) + else: + valid_indices = slice(None) return signal.index[valid_indices] - def pick_by_condition(self, signal, condition, thresh): - idx = self.switch_case(signal, condition, thresh) - return idx - def run(self, signal): + """ + Pick indices from the index of a signal's dataframe and return as an array. + + Typically, we first pick by lineage, then by condition. + """ self.orig_signal = signal indices = set(signal.index) lineage = self.get_lineage_information(signal) if len(lineage): - self.mothers = lineage[:, :2] + self.mothers = lineage[:, [0, 1]] self.daughters = lineage[:, [0, 2]] for alg, *params in self.sequence: - new_indices = tuple() if indices: if alg == "lineage": + # pick mothers, buds, or mother-bud pairs param1 = params[0] new_indices = getattr(self, "pick_by_" + alg)( signal.loc[list(indices)], param1 ) else: + # pick by condition param1, *param2 = params new_indices = getattr(self, "pick_by_" + alg)( signal.loc[list(indices)], param1, param2 ) - new_indices = [tuple(x) for x in new_indices] + else: + new_indices = tuple() + # number of indices reduces for each iteration of the loop indices = indices.intersection(new_indices) else: - self._log(f"No lineage assignment") + self._log("No lineage assignment") indices = np.array([]) - return np.array([tuple(map(_str_to_int, x)) for x in indices]) + # convert to array + indices_arr = np.array([tuple(map(_str_to_int, x)) for x in indices]) + return indices_arr + + # def pick_by_condition(self, signal, condition, thresh): + # idx = self.switch_case(signal, condition, thresh) + # return idx - def switch_case( + def pick_by_condition( self, signal: pd.DataFrame, condition: str, threshold: t.Union[float, int, list], ): + """Pick indices from signal by any_present, present, continuously_present, and growing.""" if len(threshold) == 1: threshold = [_as_int(*threshold, signal.shape[1])] + #: is this correct for "growing"? 
-    def switch_case(
+    def pick_by_condition(
        self,
        signal: pd.DataFrame,
        condition: str,
        threshold: t.Union[float, int, list],
    ):
+        """Pick indices from signal by any_present, present, continuously_present, or growing."""
        if len(threshold) == 1:
            threshold = [_as_int(*threshold, signal.shape[1])]
+        #: is this correct for "growing"?
        case_mgr = {
            "any_present": lambda s, thresh: any_present(s, thresh),
            "present": lambda s, thresh: s.notna().sum(axis=1) > thresh,
-            "nonstoply_present": lambda s, thresh: s.apply(thresh, axis=1)
+            #: continuously_present looks incorrect
+            "continuously_present": lambda s, thresh: s.apply(thresh, axis=1)
            > thresh,
            "growing": lambda s, thresh: s.diff(axis=1).sum(axis=1) > thresh,
        }
-        return set(signal.index[case_mgr[condition](signal, *threshold)])
+        # apply condition
+        idx = set(signal.index[case_mgr[condition](signal, *threshold)])
+        new_indices = [tuple(x) for x in idx]
+        return new_indices


 def _as_int(threshold: t.Union[float, int], ntps: int):
+    """Convert a fraction of the experiment's total duration into a number of time points."""
    if type(threshold) is float:
        threshold = ntps * threshold
    return threshold


 def any_present(signal, threshold):
-    """
-    Return a mask for cells, True if there is a cell in that trap that was present for more than :threshold: timepoints.
-    """
+    """Return a pd.Series of cells, True when any cell in the same trap was present for more than threshold time points."""
+    #: why isn't full_traps all we need?
+    full_traps = (signal.notna().sum(axis=1) > threshold).groupby("trap")
    any_present = pd.Series(
        np.sum(
            [
-                np.isin([x[0] for x in signal.index], i) & v
-                for i, v in (signal.notna().sum(axis=1) > threshold)
-                .groupby("trap")
-                .any()
-                .items()
+                np.isin([x[0] for x in signal.index], i) & full
+                for i, full in full_traps.any().items()
            ],
            axis=0,
        ).astype(bool),
diff --git a/src/postprocessor/grouper.py b/src/postprocessor/grouper.py
index 27f905c319b710c343f621272c41da1c035f916e..0844009322d99b119083172df8e791aa6ee32f8a 100644
--- a/src/postprocessor/grouper.py
+++ b/src/postprocessor/grouper.py
@@ -86,8 +86,9 @@ class Grouper(ABC):
        **kwargs,
    ):
        """
-        Concatenate data for one signal from different h5 files, with
-        one h5 file per position, into a dataframe.
+        Concatenate data for one signal from different h5 files into a dataframe.
+
+        Each h5 file corresponds to one position.

        Parameters
        ----------
@@ -267,17 +268,17 @@ class Grouper(ABC):

    @property
    def stages_span(self):
-        # FAILS on my example
+        # TODO: fails on my example
        return self.fsignal.stages_span

    @property
    def max_span(self):
-        # FAILS on my example
+        # TODO: fails on my example
        return self.fsignal.max_span

    @property
    def stages(self):
-        # FAILS on my example
+        # TODO: fails on my example
        return self.fsignal.stages

    @property