import logging
import typing as t
from copy import copy
from functools import cached_property, lru_cache
from pathlib import Path

import bottleneck as bn
import h5py
import numpy as np
import pandas as pd

import aliby.global_parameters as global_parameters
from agora.io.bridge import BridgeH5
from agora.io.decorators import _first_arg_str_to_raw_df
from agora.utils.indexing import validate_lineage
from agora.utils.kymograph import add_index_levels
from agora.utils.merge import apply_merges


class Signal(BridgeH5):
    """
    Fetch data from h5 files for post-processing.

    Signal assumes that the metadata and data are accessible to
    perform time-adjustments and apply previously recorded
    post-processes.
    """

    def __init__(self, file: t.Union[str, Path]):
        """Initialise defining index names for the dataframe."""
        super().__init__(file, flag=None)
        self.index_names = (
            "experiment",
            "position",
            "trap",
            "cell_label",
            "mother_label",
        )
        self.candidate_channels = global_parameters.possible_imaging_channels

    def get(
        self,
        dset: t.Union[str, t.Collection],
        tmax_in_mins: t.Optional[int] = None,
    ):
        """Get Signal after merging and picking."""
        if isinstance(dset, str):
            record = self.get_raw(dset, tmax_in_mins=tmax_in_mins)
            if record is not None:
                picked_merged = self.apply_merging_picking(record)
                return self.add_name(picked_merged, dset)
        elif isinstance(dset, list):
            return [self.get(d, tmax_in_mins=tmax_in_mins) for d in dset]
        else:
            raise TypeError(
                f"Signal.get cannot handle a dset of type {type(dset)}."
            )

    @staticmethod
    def add_name(df, name):
        """Add name of the Signal as an attribute to its data frame."""
        df.name = name
        return df

    def cols_in_mins(self, df: pd.DataFrame):
        """Convert numerical columns in a data frame to minutes."""
        df.columns = (df.columns * self.tinterval // 60).astype(int)
        return df

    @cached_property
    def ntimepoints(self):
        """Find the number of time points for one position, or one h5 file."""
        with h5py.File(self.filename, "r") as f:
            return f["extraction/general/None/area/timepoint"][-1] + 1

    @cached_property
    def tinterval(self) -> int:
        """Find the interval between time points (seconds)."""
        tinterval_location = "time_settings/timeinterval"
        with h5py.File(self.filename, "r") as f:
            if tinterval_location in f.attrs:
                res = f.attrs[tinterval_location]
                # the attribute may be stored as a list or an array
                if isinstance(res, (list, np.ndarray)):
                    return res[0]
                else:
                    return res
            else:
                logging.getLogger("aliby").warning(
                    f"{Path(self.filename).name}: using default time interval of 300 seconds."
                )
                return 300

    def retained(
        self, signal, cutoff: float = 0, tmax_in_mins: t.Optional[int] = None
    ):
        """Get retained cells for a Signal or list of Signals."""
        if isinstance(signal, str):
            # get data frame
            signal = self.get(signal, tmax_in_mins=tmax_in_mins)
        if isinstance(signal, pd.DataFrame):
            return self.get_retained(signal, cutoff)
        elif isinstance(signal, list):
            return [
                self.retained(d, cutoff=cutoff, tmax_in_mins=tmax_in_mins)
                for d in signal
            ]

    @staticmethod
    def get_retained(df, cutoff):
        """
        Return sub data frame with retained cells.

        Cells must be present for at least cutoff fraction of the total number
        of time points.
        """
        return df.loc[bn.nansum(df.notna(), axis=1) > df.shape[1] * cutoff]

    @property
    def channels(self) -> t.Collection[str]:
        """Get channels as an array of strings."""
        with h5py.File(self.filename, "r") as f:
            return list(f.attrs["channels"])

    @lru_cache(2)
    def lineage(
        self, lineage_location: t.Optional[str] = None, merged: bool = False
    ) -> np.ndarray:
        """
        Get lineage data from a given location in the h5 file.

        Returns an array with three columns: the tile id, the mother label,
        and the daughter label.
        """
        if lineage_location is None:
            lineage_location = "modifiers/lineage_merged"
        with h5py.File(self.filename, "r") as f:
            if lineage_location not in f:
                lineage_location = "postprocessing/lineage"
            traps_mothers_daughters = f[lineage_location]
            if isinstance(traps_mothers_daughters, h5py.Dataset):
                lineage = traps_mothers_daughters[()]
            else:
                lineage = np.array(
                    (
                        traps_mothers_daughters["trap"],
                        traps_mothers_daughters["mother_label"],
                        traps_mothers_daughters["daughter_label"],
                    )
                ).T
        return lineage

    @_first_arg_str_to_raw_df
    def apply_merging_picking(
        self,
        data: t.Union[str, pd.DataFrame],
        merges: t.Union[np.ndarray, bool] = True,
        picks: t.Union[t.Collection, bool] = True,
    ):
        """
        Apply picking and merging to a Signal data frame.

        Parameters
        ----------
        data : t.Union[str, pd.DataFrame]
            A data frame or the name of a data set in the h5 file.
        merges : t.Union[np.ndarray, bool]
            (optional) An array of pairs of (trap, cell) indices to merge.
            If True, fetch merges from file.
        picks : t.Union[t.Collection, bool]
            (optional) A collection of (trap, cell) indices to keep.
            If True, fetch picks from file.
        """
        if isinstance(merges, bool):
            merges = self.load_merges() if merges else np.array([])
        if merges.any():
            merged = apply_merges(data, merges)
        else:
            merged = copy(data)
        if isinstance(picks, bool):
            if picks is True:
                # load picks from h5
                picks = self.get_picks(
                    names=merged.index.names, path="modifiers/picks/"
                )
            else:
                return merged
        if len(picks):
            picked_indices = list(
                set(picks).intersection([tuple(x) for x in merged.index])
            )
            return merged.loc[picked_indices]
        else:
            return merged

    @property
    def print_available(self):
        """Print data sets available in h5 file."""
        if not hasattr(self, "_available"):
            self._available = []
            with h5py.File(self.filename, "r") as f:
                f.visititems(self.store_signal_path)
        for sig in self._available:
            print(sig)

    @cached_property
    def available(self):
        """Get data sets available in h5 file."""
        try:
            if not hasattr(self, "_available"):
                self._available = []
                with h5py.File(self.filename, "r") as f:
                    f.visititems(self.store_signal_path)
        except Exception as e:
            self.log(f"Exception when visiting h5: {e}", "exception")
        return self._available

    def get_merged(self, dataset):
        """Run merging."""
        return self.apply_merging_picking(dataset, picks=False)

    @cached_property
    def merges(self) -> np.ndarray:
        """Get merges."""
        with h5py.File(self.filename, "r") as f:
            dsets = f.visititems(self._if_merges)
        # visititems returns None if _if_merges never matches
        return dsets if dsets is not None else np.array([])

    @cached_property
    def n_merges(self):
        """Get number of merges."""
        return len(self.merges)

    @cached_property
    def picks(self) -> np.ndarray:
        """Get picks."""
        with h5py.File(self.filename, "r") as f:
            dsets = f.visititems(self._if_picks)
        # visititems returns None if _if_picks never matches
        return dsets if dsets is not None else np.array([])

    def get_raw(
        self,
        dataset: t.Union[str, t.List[str]],
        in_minutes: bool = True,
        lineage: bool = False,
        tmax_in_mins: t.Optional[int] = None,
        **kwargs,
    ) -> t.Union[pd.DataFrame, t.List[pd.DataFrame]]:
        """
        Get raw Signal with neither merging nor picking applied.

        Parameters
        ----------
        dataset: str or list of str
            The name of a data set in the h5 file or a list of such names.
        in_minutes: boolean
            If True, convert column headings to times in minutes.
        lineage: boolean
            If True, add mother_label to index.
        tmax_in_mins: int (optional)
            Discard data for times > tmax_in_mins. Cells with all NaNs will also
            be discarded to help with assigning lineages.
            Setting tmax_in_mins is a way to ignore parts of the experiment with
            incorrect lineages generated by clogging.
        """
        if isinstance(dataset, str):
            with h5py.File(self.filename, "r") as f:
                df = self.dataset_to_df(f, dataset)
                if df is not None:
                    df = df.sort_index()
                    if in_minutes:
                        df = self.cols_in_mins(df)
                    # limit data by time and discard NaNs
                    if (
                        in_minutes
                        and tmax_in_mins
                        and type(tmax_in_mins) is int
                    ):
                        df = df[df.columns[df.columns <= tmax_in_mins]]
                        df = df.dropna(how="all")
                    # add mother label to data frame
                    if lineage:
                        if "mother_label" in df.index.names:
                            df = df.droplevel("mother_label")
                        mother_label = np.zeros(len(df), dtype=int)
                        # avoid shadowing the lineage flag with the data
                        lineage_data = self.lineage()
                        (
                            valid_lineage,
                            valid_indices,
                            lineage_data,
                        ) = validate_lineage(
                            lineage_data,
                            indices=np.array(df.index.to_list()),
                            how="daughters",
                        )
                        mother_label[valid_indices] = lineage_data[
                            valid_lineage, 1
                        ]
                        df = add_index_levels(
                            df, {"mother_label": mother_label}
                        )
                    return df
        elif isinstance(dataset, list):
            return [
                self.get_raw(
                    dset,
                    in_minutes=in_minutes,
                    lineage=lineage,
                    tmax_in_mins=tmax_in_mins,
                )
                for dset in dataset
            ]

    def load_merges(self):
        """Get merge events going up to the first level."""
        with h5py.File(self.filename, "r") as f:
            merges = f.get("modifiers/merges", np.array([]))
            if not isinstance(merges, np.ndarray):
                merges = merges[()]
        return merges

    def get_picks(
        self,
        names: t.Tuple[str, ...] = ("trap", "cell_label"),
        path: str = "modifiers/picks/",
    ) -> t.Set[t.Tuple[int, ...]]:
        """Get picks from the h5 file."""
        with h5py.File(self.filename, "r") as f:
            if path in f:
                picks = set(
                    zip(*[f[path + name] for name in names if name in f[path]])
                )
            else:
                picks = set()
            return picks

    def dataset_to_df(
        self, f: h5py.File, path: str
    ) -> t.Optional[pd.DataFrame]:
        """Get data from h5 file as a dataframe."""
        if path not in f:
            self.log(f"{path} not in {f}.")
            return None
        else:
            dset = f[path]
            values, index, columns = [], [], []
            index_names = copy(self.index_names)
            valid_names = [lbl for lbl in index_names if lbl in dset.keys()]
            if valid_names:
                index = pd.MultiIndex.from_arrays(
                    [dset[lbl] for lbl in valid_names], names=valid_names
                )
                columns = dset.attrs.get("columns", None)
                if "timepoint" in dset:
                    columns = f[path + "/timepoint"][()]
                values = f[path + "/values"][()]
            df = pd.DataFrame(values, index=index, columns=columns)
            return df

    @property
    def stem(self):
        """Get the name of the h5 file without its suffix."""
        return Path(self.filename).stem

    def store_signal_path(
        self,
        fullname: str,
        node: t.Union[h5py.Dataset, h5py.Group],
    ):
        """Store the name of a signal if it is a leaf node and if it starts with extraction."""
        if isinstance(node, h5py.Group) and np.all(
            [isinstance(x, h5py.Dataset) for x in node.values()]
        ):
            self._if_ext_or_post(fullname, self._available)

    @staticmethod
    def _if_ext_or_post(name: str, siglist: list):
        if name.startswith("extraction") or name.startswith("postprocessing"):
            siglist.append(name)

    @staticmethod
    def _if_merges(name: str, obj):
        if isinstance(obj, h5py.Dataset) and name.startswith(
            "modifiers/merges"
        ):
            return obj[()]

    @staticmethod
    def _if_picks(name: str, obj):
        if isinstance(obj, h5py.Group) and name.endswith("picks"):
            # a group cannot be dereferenced directly; stack its data sets
            return np.array([obj[key][()] for key in obj.keys()]).T

    @property
    def ntps(self) -> int:
        """Get number of time points from the metadata."""
        return self.meta_h5["time_settings/ntimepoints"][0]