Skip to content
Snippets Groups Projects
Forked from Swain Lab / aliby / aliby-mirror
7 commits behind, 45 commits ahead of the upstream repository.
indexing_new.py 9.44 KiB
#!/usr/bin/env jupyter
"""
Utilities based on association are used to efficiently acquire indices of
tracklets with some kind of relationship.
This can be:
    - Cells that are to be merged.
    - Cells that have a lineage relationship.
"""

import numpy as np
import typing as t


def validate_association(
    association: np.ndarray,
    indices: np.ndarray,
    match_column: t.Optional[int] = None,
) -> t.Tuple[np.ndarray, np.ndarray]:
    """
    Identify mother-bud pairs that exist both in lineage and a Signal's indices.

    The result is obtained from ``validate_lineage`` and then cross-checked
    against an independent, broadcasting-based reference computation.

    Parameters
    ----------
    association : np.ndarray
        2D array of lineage associations where columns are
        (trap, mother, daughter), or a 3D array of 2 x 2 arrays comprising
        [[trap_id, mother_label], [trap_id, daughter_label]]. A 2D input is
        converted to the 3D form with ``_assoc_indices_to_3d``.
    indices : np.ndarray
        2D array where each column is a different level, such as
        (trap_id, cell_label).
    match_column : int, optional
        Only used by the reference computation:
        if 0, matches indicate mothers from mother-bud pairs;
        if 1, matches indicate daughters from mother-bud pairs;
        if None, both members of a pair have to be present in indices.

    Returns
    -------
    valid_association : np.ndarray
        Mask over association, as returned by ``validate_lineage``.
    valid_indices : np.ndarray
        Mask over indices, as returned by ``validate_lineage``.

    Raises
    ------
    AssertionError
        If the output of ``validate_lineage`` differs from the reference.
    """
    if association.ndim == 2:
        # reshape into 3D array for broadcasting
        # for each trap, [trap, mother, daughter] becomes
        #  [[trap, mother], [trap, daughter]]
        association = _assoc_indices_to_3d(association)

    # NOTE(review): match_column is not forwarded to validate_lineage, which
    # is defined outside this block - confirm that this is intended.
    valid_association, valid_indices = validate_lineage(association, indices)

    # Reference computation used only to cross-check the result above.
    # Compare every [trap, label] entry of association with every row of
    # indices; the trailing axis runs over the rows of indices.
    valid_ndassociation_a = association[..., None] == indices.T[None, ...]
    # an entry matches a row of indices only if both trap and label agree
    valid_cell_ids_a = valid_ndassociation_a.all(axis=2)
    if match_column is None:
        # a pair is valid when mother and daughter are both found in indices
        valid_association_a = valid_cell_ids_a.any(axis=2).all(axis=1)
        # a row of indices is valid when it matches either member of a valid pair
        valid_indices_a = valid_cell_ids_a[valid_association_a].any(
            axis=(0, 1)
        )
    else:
        # restrict the comparison to mothers (0) or daughters (1)
        column_matches = valid_cell_ids_a[:, match_column]
        valid_indices_a = column_matches.any(axis=0)
        # use the reference's own mask here, so that the cross-check stays
        # independent of the values it is meant to verify
        valid_association_a = (column_matches & valid_indices_a).any(axis=1)

    assert np.array_equal(
        valid_association, valid_association_a
    ), "valid_association error"
    assert np.array_equal(
        valid_indices, valid_indices_a
    ), "valid_indices error"

    return valid_association, valid_indices


def validate_association_old(
    association: np.ndarray,
    indices: np.ndarray,
    match_column: t.Optional[int] = None,
) -> t.Tuple[np.ndarray, np.ndarray]:
    """
    Identify mother-bud pairs that exist both in lineage and a Signal's indices.

    Parameters
    ----------
    association : np.ndarray
        2D array of lineage associations where columns are
        (trap, mother, daughter), or a 3D array, which is an array of 2 x 2
        arrays comprising
        [[trap_id, mother_label], [trap_id, daughter_label]].
    indices : np.ndarray
        A 2D array where each column is a different level, such as
        (trap_id, cell_label), which typically is an index of a Signal
        dataframe. This array should not include mother_label.
    match_column : int, optional
        If 0, matches indicate mothers from mother-bud pairs;
        if 1, matches indicate daughters from mother-bud pairs;
        if None, both the mother and the daughter of a pair have to be
        present in indices for the pair to count as a match.

    Returns
    -------
    valid_association : boolean np.ndarray
        1D array indicating elements in association with matches.
    valid_indices : boolean np.ndarray
        1D array indicating elements in indices with matches.

    Examples
    --------
    >>> import numpy as np
    >>> association = np.array([
    ...     [[0, 1], [0, 3]],
    ...     [[0, 1], [0, 4]],
    ...     [[0, 1], [0, 6]],
    ...     [[0, 4], [0, 7]],
    ... ])
    >>> indices = np.array([[0, 1], [0, 2], [0, 3]])
    >>> valid_association, valid_indices = validate_association_old(
    ...     association, indices
    ... )
    >>> valid_association
    array([ True, False, False, False])
    >>> valid_indices
    array([ True, False,  True])

    and

    >>> association = np.array([[[0, 3], [0, 1]], [[0, 2], [0, 4]]])
    >>> indices = np.array([[0, 1], [0, 2], [0, 3]])
    >>> valid_association, valid_indices = validate_association_old(
    ...     association, indices
    ... )
    >>> valid_association
    array([ True, False])
    >>> valid_indices
    array([ True, False,  True])

    Matching on mothers only:

    >>> valid_association, valid_indices = validate_association_old(
    ...     association, indices, match_column=0
    ... )
    >>> valid_association
    array([ True,  True])
    >>> valid_indices
    array([False,  True,  True])
    """
    if association.ndim == 2:
        # reshape into 3D array for broadcasting
        # for each trap, [trap, mother, daughter] becomes
        #  [[trap, mother], [trap, daughter]]
        association = _assoc_indices_to_3d(association)
    # Compare each of [[trap, mother], [trap, daughter]] in association with
    # every [trap, cell_label] row of indices. The transpose lines up the
    # trap/label axis of both arrays, so the comparison has shape
    # n_association x 2 x 2 x n_indices.
    valid_ndassociation = (
        association[..., np.newaxis] == indices.T[np.newaxis, ...]
    )
    # an entry matches a row of indices only if both trap and label agree
    valid_cell_ids = valid_ndassociation.all(axis=2)
    if match_column is None:
        # a pair is valid when mother and bud both match some row in indices
        valid_association = valid_cell_ids.any(axis=2).all(axis=1)
        # a row in indices is valid when it matches either member of a
        # valid pair
        valid_indices = valid_cell_ids[valid_association].any(axis=(0, 1))
    else:
        # match_column selects mothers if 0 and daughters if 1
        column_matches = valid_cell_ids[:, match_column]
        valid_association = column_matches.any(axis=1)
        # reduce over all associations rather than reading only the first
        # matching one; this also copes with there being no match at all
        valid_indices = column_matches.any(axis=0)
    return valid_association, valid_indices


def _assoc_indices_to_3d(ndarray: np.ndarray):
    """
    Reorganise an array of shape (N, 3) into one of shape (N, 2, 2).

    Reorganise an array so that the last entry of each row is removed
    and generates a new row. This new row retains all other entries of
    the original row.

    Example:
    [ [0, 1, 3], [0, 1, 4] ]
    becomes
    [ [[0, 1], [0, 3]], [[0, 1], [0, 4]] ]
    """
    result = ndarray
    if len(ndarray) and ndarray.ndim > 1:
        if ndarray.shape[1] == 3:
            # faster indexing for single positions
            result = np.transpose(
                np.hstack((ndarray[:, [0]], ndarray)).reshape(-1, 2, 2),
                axes=[0, 2, 1],
            )
        else:
            # 20% slower, but more general indexing
            columns = np.arange(ndarray.shape[1])
            result = np.stack(
                (
                    ndarray[:, np.delete(columns, -1)],
                    ndarray[:, np.delete(columns, -2)],
                ),
                axis=1,
            )
    return result


def _3d_index_to_2d(array: np.ndarray):
    """Revert switch from _assoc_indices_to_3d."""
    result = array
    if len(array):
        result = np.concatenate(
            (array[:, 0, :], array[:, 1, 1, np.newaxis]), axis=1
        )
    return result


def compare_indices(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """
    Compare the rows of two 2D arrays using broadcasting.

    Return a boolean array in which element [i, j] is True when every
    entry of row i of x equals the corresponding entry of row j of y.
    """
    # append an axis to x and prepend one to the transpose of y so that
    # every row of x is compared against every row of y
    left = x[..., None]
    right = y.T[None, ...]
    entrywise_equal = left == right
    # rows are equal only if all of their entries are equal
    return entrywise_equal.all(axis=1)