#!/usr/bin/env jupyter
"""
Functions to efficiently merge rows in DataFrames.
"""
import typing as t
from copy import copy

import numpy as np
import pandas as pd
from utils_find_1st import cmp_larger, find_1st

from agora.utils.indexing import (
    index_isin,
    compare_indices,
    validate_association,
)


def apply_merges(data: pd.DataFrame, merges: np.ndarray) -> pd.DataFrame:
    """
    Split data in two, one subset for rows relevant for merging and one
    without them.

    Use an array of source tracklets and target tracklets
    to efficiently merge them.

    Parameters
    ----------
    data : pd.DataFrame
        Input DataFrame; its index holds the single-cell identifiers
        referenced by ``merges`` (optionally with an extra
        "mother_label" level, which is ignored here).
    merges : np.ndarray
        3-D ndarray where dimensions are (X,2,2): nmerges, source-target
        pair and single-cell identifiers, respectively.

    Returns
    -------
    pd.DataFrame
        ``data`` with every valid merge applied: each retained row
        contains the joined track values and the absorbed rows are
        dropped.
    """

    indices = data.index
    # validate_association compares bare cell identifiers, so strip the
    # lineage level if present.
    if "mother_label" in indices.names:
        indices = indices.droplevel("mother_label")
    # valid_merges: boolean mask over merges whose tracks exist in data;
    # indices: boolean mask over data rows involved in any valid merge.
    valid_merges, indices = validate_association(
        merges, np.array(list(indices))
    )

    # Assign non-merged
    merged = data.loc[~indices]

    # Implement the merges and drop source rows.
    # TODO Use matrices to perform merges in batch
    # for efficiency
    if valid_merges.any():
        to_merge = data.loc[indices].copy()
        # NOTE(review): zip(*pairs) yields merges[:, 0] first, so
        # "targets" here is merges[:, 0] and "sources" is merges[:, 1] —
        # the docstring calls that axis (source, target); confirm which
        # of the two labels is swapped.
        targets, sources = zip(*merges[valid_merges])
        for source, target in zip(sources, targets):
            target = tuple(target)
            # Splice the source track's values into the target row.
            to_merge.loc[target] = join_tracks_pair(
                to_merge.loc[target].values,
                to_merge.loc[tuple(source)].values,
            )
        # Source rows have been absorbed into their targets; drop them.
        to_merge.drop(map(tuple, sources), inplace=True)

        merged = pd.concat((merged, to_merge), names=data.index.names)
    return merged


def join_tracks_pair(target: np.ndarray, source: np.ndarray) -> np.ndarray:
    """
    Join two tracks and return the new value of the target.

    The target's trailing run of zeros (time points after the target
    track ends) is filled with the corresponding tail of the source
    track. The input arrays are left unmodified.

    Parameters
    ----------
    target : np.ndarray
        1-D values of the track to keep, zero-padded after it ends.
    source : np.ndarray
        1-D values of the track being merged in; assumed the same
        length as ``target``.

    Returns
    -------
    np.ndarray
        A new array with the joined values.
    """
    # Copy first: the original code aliased `target` and wrote through
    # it, mutating the caller's array in place.
    target_copy = copy(target)
    # Number of trailing zeros = index of the first value > 0 when
    # scanning from the end.
    end = find_1st(target_copy[::-1], 0, cmp_larger)
    if end > 0:
        # Overwrite only the trailing-zero tail. Guarding end > 0
        # matters: the unguarded slice [-0:] would have replaced the
        # entire target when it has no trailing zeros.
        target_copy[-end:] = source[-end:]
    return target_copy


def group_merges(merges: np.ndarray) -> t.List[t.Tuple]:
    """
    Convert merges into a list of merges for traps requiring multiple
    merges and then for traps requiring single merges.
    """
    lefts = merges[:, 0]
    rights = merges[:, 1]
    # a track appearing on both sides marks a trap needing chained merges
    left_in_right = merges[index_isin(lefts, rights).flatten(), :]
    right_in_left = merges[index_isin(rights, lefts).flatten(), :]
    # deduplicate and impose a canonical ordering on the chained merges
    multi = np.unique(np.concatenate((left_in_right, right_in_left)), axis=0)
    # every remaining trap needs exactly one merge
    needs_multi = index_isin(merges, multi).all(axis=1).flatten()
    single = merges[~needs_multi, :]
    # one entry per trap with chained merges, followed by the singles
    grouped_multi = [
        multi[multi[:, 0, 0] == trap, ...]
        for trap in np.unique(multi[:, 0, 0])
    ]
    grouped_single = [[row] for row in single]
    return [*grouped_multi, *grouped_single]


def union_find(lsts):
    """
    Merge any input lists that share elements.

    Return a list of disjoint sets, where each set is the union of all
    input lists connected (directly or transitively) by common elements.
    Empty input lists are ignored.
    """
    pool = [set(lst) for lst in lsts if lst]
    changed = True
    # keep sweeping until a full pass performs no merge
    while changed:
        changed = False
        collapsed = []
        while pool:
            head = pool[0]
            survivors = []
            for candidate in pool[1:]:
                if head.isdisjoint(candidate):
                    survivors.append(candidate)
                else:
                    changed = True
                    head |= candidate
            collapsed.append(head)
            pool = survivors
        pool = collapsed
    return pool


def sort_association(array: np.ndarray):
    """
    Sort the internal associations of an (n, 2, 2) array of pairs.

    Pairs are reordered using the matches between each pair's first
    element and other pairs' second elements, so that chained
    associations appear in a consistent order.

    Parameters
    ----------
    array : np.ndarray
        (n, 2, 2) array of associated identifier pairs.

    Returns
    -------
    np.ndarray
        The rows of ``array`` re-indexed into sorted order.
    """
    # match[i, j] is True when array[i, 0] equals array[j, 1]
    match = (array[:, 0, ..., None] == array[:, 1].T[None, ...]).all(axis=1)
    order = np.where(match)
    # flatten the flipped (row, col) match indices and keep the first
    # occurrence of each index, preserving order (was a side-effecting
    # list comprehension with append)
    flat = np.flip(order).flatten()
    unique_order = list(dict.fromkeys(flat.tolist()))
    sorted_array = array[np.array(unique_order)]
    return sorted_array


def merge_association(lineage: np.ndarray, merges: np.ndarray) -> np.ndarray:
    """
    Use merges to update lineage information.

    Replace any (mother or bud) track index that is the left side of a
    merge with the index of the rightmost track in its merge chain, then
    drop duplicate lineage entries.

    Parameters
    ----------
    lineage : np.ndarray
        (n, 2, 2) array of (mother, bud) identifier pairs.
        NOTE(review): ``reshape(-1, 2)`` below returns a view for a
        contiguous input, so the assignment further down mutates this
        argument in place — confirm callers expect that.
    merges : np.ndarray
        (X, 2, 2) array of merge pairs, as used by apply_merges.

    Returns
    -------
    np.ndarray
        (m, 2, 2) deduplicated lineage array with merged indices.
    """
    flat_lineage = lineage.reshape(-1, 2)
    # map each bud's identifier to its mother's identifier
    bud_mother_dict = {
        tuple(bud): mother for bud, mother in zip(lineage[:, 1], lineage[:, 0])
    }
    left_tracks = merges[:, 0]
    # find left tracks that are in lineages
    valid_lineages = index_isin(flat_lineage, left_tracks).flatten()
    # group into multi- and then single merges
    grouped_merges = group_merges(merges)
    # perform merges
    if valid_lineages.any():
        # indices of each left track -> indices of rightmost right track
        replacement_dict = {
            tuple(contig_pair[0]): merge[-1][1]
            for merge in grouped_merges
            for contig_pair in merge
        }
        # if both key and value are buds, they must have the same mother
        buds = lineage[:, 1]
        incorrect_merges = [
            key
            for key in replacement_dict
            if np.any(index_isin(buds, replacement_dict[key]).flatten())
            and np.any(index_isin(buds, key).flatten())
            and not np.array_equal(
                bud_mother_dict[key],
                bud_mother_dict[tuple(replacement_dict[key])],
            )
        ]
        # reassign incorrect merges (map each key to itself) so that
        # they have no effect
        for key in incorrect_merges:
            replacement_dict[key] = key
        # correct lineage information
        # replace mother or bud index with index of rightmost track
        flat_lineage[valid_lineages] = [
            replacement_dict[tuple(index)]
            for index in flat_lineage[valid_lineages]
        ]
    # reverse flattening
    new_lineage = flat_lineage.reshape(-1, 2, 2)
    # remove any duplicates
    new_lineage = np.unique(new_lineage, axis=0)
    return new_lineage