from copy import copy

import matplotlib.cm
import matplotlib.pylab as plt
import numpy as np
import numpy.matlib


def kymograph(
    df,
    hue="median_GFP",
    x="time",
    y="id",
    xtick_step_in_hours=5,
    vmax=None,
    vmin=None,
    cmap=matplotlib.cm.Greens,
    figsize=(6, 10),
    title=None,
    returnfig=False,
):
    """
    Plot a heatmap.

    Typically each row is a single cell and the x-axis shows time.

    Time is assumed to be in hours.

    Parameters
    ----------
    df: dataframe
        Long-format data; it is pivoted so that each unique value of
        column `y` becomes a row and each unique value of column `x`
        becomes a column of the heatmap.
    hue: string
        Name of the column whose values colour the heatmap. If "buddings",
        the colour map is forced to "Greys".
    x: string
        Name of the column used for the horizontal axis.
    y: string
        Name of the column used for the vertical axis.
    xtick_step_in_hours: float
        Spacing between labelled ticks on the horizontal axis, in the
        units of column `x`.
    vmax: float, optional
        Upper limit of the colour scale, passed to imshow.
    vmin: float, optional
        Lower limit of the colour scale, passed to imshow.
    cmap: matplotlib colormap or string
        Colour map passed to imshow.
    figsize: tuple of two floats
        The size of the figure.
    title: string, optional
        Title for the plot.
    returnfig: boolean
        If True, return the figure and axes.

    Returns
    -------
    fig, ax: matplotlib figure and axes
        Only returned if returnfig is True; otherwise None is returned.

    Examples
    --------
    >>> from wela.plotting import kymograph
    >>> kymograph(dl.df, hue="median_GFP")
    >>> kymograph(dl.df, hue="bud_volume", title="2% Gal")
    >>> kymograph(dl.df, hue="buddings")
    """
    if hue == "buddings":
        cmap = "Greys"
    # keyword arguments are required by DataFrame.pivot in pandas >= 2.0
    wdf = df.pivot(index=y, columns=x, values=hue)
    # smallest spacing between the distinct values of the x column
    dt = np.min(np.diff(np.sort(df[x].unique())))
    # from Arin
    data = wdf.to_numpy()
    n_columns = data.shape[1]
    # define horizontal axis ticks and labels; column j of the image is
    # taken to sit at x = j * dt, i.e. evenly spaced and starting from zero
    xtick_min = 0
    xtick_max = dt * n_columns
    xticklabels = np.arange(xtick_min, xtick_max, xtick_step_in_hours)
    column_positions = dt * np.arange(n_columns)
    # for each label, pick the image column closest to it
    xticks = [
        np.argmin((column_positions - label) ** 2) for label in xticklabels
    ]
    xticklabels = list(map(str, xticklabels.tolist()))
    # plot
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels)
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ax.set_title(title)
    trace_heatmap = ax.imshow(
        data,
        cmap=cmap,
        interpolation="none",
        vmax=vmax,
        vmin=vmin,
    )
    fig.colorbar(
        mappable=trace_heatmap,
        ax=ax,
        label=hue,
        aspect=50,
        shrink=0.5,
    )
    fig.tight_layout()
    plt.show(block=False)
    if returnfig:
        return fig, ax


def plot_random_time_series(time, values, signalname=None, number=5):
    """
    Repeatedly plot randomly chosen rows of values against time.

    A new random selection of `number` rows is drawn after each mouse
    click; pressing a key ends the loop. The x-axis shows time / 60 and is
    labelled in hours, so time is presumably given in minutes.
    """
    fig = plt.figure()
    keep_going = True

    def stop_on_key(event):
        nonlocal keep_going
        keep_going = False

    # a key press anywhere in the figure ends the outer loop
    fig.canvas.mpl_connect("key_press_event", stop_on_key)
    while keep_going:
        # rows are drawn with replacement
        chosen = np.random.randint(0, values.shape[0], number)
        plt.cla()
        for row in chosen:
            plt.plot(time / 60, values[row, :], label=str(row))
        plt.xlabel("time (hours)")
        if signalname:
            plt.ylabel(signalname)
        plt.legend()
        plt.grid()
        plt.draw()
        # poll until a key or mouse button is pressed; a key press also
        # triggers stop_on_key through mpl_connect
        pressed = None
        while pressed is None:
            pressed = plt.waitforbuttonpress(0.2)
        print(".")


def plot_lineages(
    irange,
    dl,
    signals=["volume", "growth_rate"],
    show=True,
    figsize=(10, 5),
    plot_budding_pts=True,
    shade_times=None,
    shade_colour="gold",
):
    """
    Wrapper of plot_lineage to plot a range of idx values.

    One figure is produced for each index in irange.

    Parameters
    ----------
    irange: int or iterable of int
        Indices into dl.ids of the lineages to plot. A single integer
        (Python or NumPy) is treated as a one-element list.
    dl: dataloader instance
        Contains data to plot; dl.ids and dl.df are used.
    signals: string or list of strings
        Signals to plot, passed to plot_lineage.
    show: boolean
        Passed to plot_lineage.
    figsize: tuple of two floats
        Passed to plot_lineage.
    plot_budding_pts: boolean
        Passed to plot_lineage.
    shade_times: list of tuples of two floats, optional
        Passed to plot_lineage.
    shade_colour: string, optional
        Passed to plot_lineage.

    Example
    -------
    >>> import numpy as np
    >>> plot_lineages(np.arange(1, 20), dl, ["flavin"])
    """
    # np.integer covers NumPy scalars such as np.int64, which are not
    # instances of the built-in int and are not iterable
    if isinstance(irange, (int, np.integer)):
        irange = [irange]
    for i in irange:
        plot_lineage(
            dl.ids[i],
            dl.df,
            signals=signals,
            show=show,
            figsize=figsize,
            plot_budding_pts=plot_budding_pts,
            shade_times=shade_times,
            shade_colour=shade_colour,
        )


def plot_lineage(
    idx,
    df,
    signals=["volume", "growth_rate"],
    show=True,
    figsize=(10, 5),
    plot_budding_pts=True,
    shade_times=None,
    shade_colour="gold",
):
    """
    Plot the signals for one cell lineage.

    Each signal gets its own subplot and all subplots share the x-axis.
    If the dataframe has a "bud_" + signal column, the bud signal is
    plotted between consecutive budding events together with the mother
    signal.

    Arguments
    ---------
    idx: integer
        One of df.ids
    df: dataframe
        Typically dl.df or a sub dataframe. Must have "id" and "time"
        columns; a "buddings" column is optional.
    signals: string or list of strings
        Signals to plot.
    show: boolean
        If True, display figure.
    figsize: tuple of two floats
        The size of figure, eg (10, 5)
    plot_budding_pts: boolean
        If True, plot births as black dots (not done for the "volume" and
        "growth_rate" signals).
    shade_times: list of tuples of two floats, optional
        Add vertical shading between each pair of times for all subplots.
    shade_colour: string, optional
        Colour for vertical shading.

    Raises
    ------
    ValueError
        If idx is not one of the ids in df.

    Examples
    --------
    >>> from wela.plotting import plot_lineage
    >>> plot_lineage(dl.df.id[3], dl.df)
    >>> plot_lineage(dl.df.id[3], dl.df, "median_GFP")

    and for both mother and bud volumes

    >>> plot_lineage(dl.df.id[3], dl.df, "volume")

    as well as

    >>> plot_lineage(dl.ids[23], dl.df,
            ["flavin", "total_mCherry", "bud_volume"],
            shade_times=[(7, 15)])
    """
    if idx not in df.id.unique():
        raise ValueError("idx not part of dataframe")
    if isinstance(signals, str):
        signals = [signals]
    nosubplots = len(signals)
    # rows for this lineage only
    cell_df = df[df.id == idx]
    t = cell_df["time"].to_numpy()
    # show buddings if possible; without a "buddings" column there are no
    # budding points and no bud segments
    has_buddings = "buddings" in df.columns
    b_pts = np.array([], dtype=int)
    nb_pts = b_pts
    if has_buddings:
        buddings = cell_df["buddings"].to_numpy()
        b_pts = np.where(buddings)[0]
        if len(b_pts) == 1:
            # a single bud is drawn from its birth to the last time point
            nb_pts = np.concatenate((b_pts, [len(buddings) - 1]))
        else:
            # NOTE(review): with several buddings only the segments between
            # consecutive births are drawn, so the final bud is not shown
            # - confirm this is intended
            nb_pts = b_pts
    # generate figure
    fig = plt.figure(figsize=figsize)
    fig.suptitle(idx)
    # the first subplot; later subplots share its x-axis
    first_ax = None
    # plot signals
    for splt, signal in enumerate(signals, start=1):
        ax = plt.subplot(nosubplots, 1, splt, sharex=first_ax)
        if first_ax is None:
            first_ax = ax
        # column holding the mother's signal; kept separate from `signal`
        # so that the name checks below see the requested signal
        mother_col = signal
        if "bud_" + signal in df.columns:
            bud_sig = cell_df["bud_" + signal].to_numpy()
            if signal == "growth_rate" and "growth_rate" not in df.columns:
                mother_col = "mother_" + signal
            m_sig = cell_df[mother_col].to_numpy()
            for start, end in zip(nb_pts, nb_pts[1:]):
                # plot bud signal
                ax.plot(t[start:end], bud_sig[start:end], ".-")
            # plot mother signal
            ax.plot(t, m_sig, "k-")
        else:
            m_sig = cell_df[mother_col].to_numpy()
            ax.plot(t, m_sig)
        ax.set_ylabel(mother_col)
        # add_shading draws on the current axes, which is ax
        add_shading(shade_times, shade_colour)
        ax.grid()
        # plot buddings
        if (
            plot_budding_pts
            and has_buddings
            and signal not in ["volume", "growth_rate"]
            and len(b_pts) > 0
        ):
            ax.plot(t[b_pts], m_sig[b_pts], "k.")
    plt.xlabel("time (hours)")
    if show:
        plt.show(block=False)


def add_shading(shade_times, shade_colour):
    """
    Add a faint vertical band to the current axes for each time interval.

    Nothing is drawn if shade_times is None. Otherwise each element of
    shade_times is unpacked into a start and an end time.
    """
    if shade_times is None:
        return
    for interval in shade_times:
        tstart, tend = interval
        plt.axvspan(tstart, tend, facecolor=shade_colour, alpha=0.1)


def plot_replicate_array(
    data,
    t=None,
    plotmean=True,
    xlabel=None,
    ylabel=None,
    title=None,
    grid=True,
    showdots=False,
    show=True,
):
    """
    Plot summary statistics versus axis 1 (time) for an array of replicates.

    NaNs are ignored when computing the statistics.

    Parameters
    ----------
    data: array
        An array of signal values, with each row a replicate measurement
        and each column a time point.
    t : array (optional)
        An array of time points. If not given, the column index is used.
    plotmean: boolean
        If True, plot the mean over replicates versus time with the
        standard error, computed at each time point from the replicates
        that are not NaN.
        If False, plot the median and the interquartile range.
    xlabel: string
        Label for x-axis.
    ylabel: string
        Label for y-axis.
    title: string
        Title for plot.
    grid: boolean
        If True, draw grid on plot.
    showdots: boolean
        If True, show individual data points.
    show: boolean
        If True, create a new figure and display it immediately.
        If False, draw on the current figure and do not call show.
    """
    data = np.asarray(data)
    # number of time points
    n = data.shape[1]
    if t is None or np.size(t) == 0:
        t = np.arange(n)
    plt_type = "b.-" if showdots else "b-"
    if show:
        plt.figure()
    if plotmean:
        # mean and standard error
        mean = np.nanmean(data, axis=0)
        # number of replicates contributing to each time point; dividing
        # by the total number of rows would understate the error wherever
        # replicates are missing
        n_valid = np.sum(~np.isnan(data), axis=0).astype(float)
        # avoid dividing by zero where a time point has no data
        n_valid[n_valid == 0] = np.nan
        stderr = np.nanstd(data, axis=0) / np.sqrt(n_valid)
        plt.plot(t, mean, plt_type)
        plt.fill_between(
            t,
            mean + stderr,
            mean - stderr,
            color="b",
            alpha=0.2,
        )
    else:
        # median and interquartile range
        plt.plot(t, np.nanmedian(data, axis=0), plt_type)
        plt.fill_between(
            t,
            np.nanquantile(data, 0.25, axis=0),
            np.nanquantile(data, 0.75, axis=0),
            color="b",
            alpha=0.2,
        )
    if xlabel:
        plt.xlabel(xlabel)
    if ylabel:
        plt.ylabel(ylabel)
    if title:
        plt.title(title)
    if grid:
        plt.grid()
    if show:
        plt.show(block=False)


def plot2Dhist(
    x,
    y,
    bins=[20, 20],
    title=None,
    figsize=None,
    xlabel="time",
    ylabel=None,
    xmax=None,
    ymax=None,
    cmap=None,
    **kwargs,
):
    """
    Plot two dimensional histograms.

    Typically, time is on the x-axis; the support of the distribution is on
    the y-axis; and shading represents the height of the distribution at each
    y value.

    Parameters
    ----------
    x: 1D or 2D array
        If 1D, we assume an array of time points, which is repeated for
        every row of y.
    y: 2D array
        Each row contains time-series data from a single cell.
    bins: list of arrays (optional)
        Specifies the bins, either explicitly or as a number of bins, and can
        be used to plot different data sets on axes with the same range.
    title: str (optional)
        Title for the plot.
    figsize: tuple (optional)
        Sets the width and height of the figure.
    xlabel: str (optional)
        Label for the x-axis.
    ylabel: str (optional)
        Label for the y-axis.
    xmax: float (optional)
        The maximal value on the x-axis.
    ymax: float (optional)
        The maximal value on the y-axis.
    cmap: matplotlib.colormap or str (optional)
        Color map for the shading; viridis if not given. A copy is used,
        so the color map passed in is not modified.
    **kwargs:
        Passed to plt.pcolormesh.

    Returns
    -------
    edges: tuple of 1D arrays
        The x and y bin edges.
    h: 2D array
        The number of data points in each bin.

    Examples
    --------
    Load data:
    >>> from wela.dataloader import dataloader
    >>> from wela.plotting import plot2Dhist
    >>> dl = dataloader()
    >>> dl.load("1334_2023_03_28_pyrTo2Gal_01glc_00", use_tsv=True)
    >>> dlc = dataloader()
    >>> dlc.load("1005_2023_03_09_exp_00_pyrToGalInduction", use_tsv=True)

    Get time-series data:
    >>> t, d = dl.get_time_series("median_GFP")
    >>> tc, dc = dlc.get_time_series("median_GFP")

    Plot both data sets using axes with the same range:
    >>> bins = plot2Dhist(tc, dc, title="2% Gal", figsize=(4, 3))[0]
    >>> plot2Dhist(t, d, title="2% Gal and 0.1% Glu", bins=bins, figsize=(4, 3))
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.ndim == 1:
        # make into a 2D array with one copy of x for each row of y
        xa = np.tile(x, (y.shape[0], 1))
    else:
        xa = x
    # find real data; boolean indexing already returns 1D arrays
    select = ~np.isnan(xa) & ~np.isnan(y)
    xn = xa[select]
    yn = y[select]
    # bin
    h, xedges, yedges = np.histogram2d(xn, yn, bins=bins)
    # work on a copy of the color map so that set_bad does not alter
    # the caller's (or matplotlib's registered) instance
    if cmap is None:
        cmap = plt.cm.viridis
    elif isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)
    cmap = copy(cmap)
    # draw masked or invalid bins in the color of the lowest value
    cmap.set_bad(cmap(0))
    # plot using pcolormesh; figsize=None gives the default figure size
    plt.figure(figsize=figsize)
    pcm = plt.pcolormesh(
        xedges,
        yedges,
        h.T,
        cmap=cmap,
        rasterized=True,
        **kwargs,
    )
    plt.colorbar(pcm, label="number of cells", pad=0.02)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if ymax is not None:
        plt.ylim(top=ymax)
    if xmax is not None:
        plt.xlim(right=xmax)
    if title:
        plt.title(title)
    plt.tight_layout()
    plt.show(block=False)
    return (xedges, yedges), h


def plot_cuml_divisions_per_cell(t, buddings, nboots=30, col="b", label=None):
    """
    Plot the mean cumulative number of divisions per cell over time.

    The value plotted at the i-th time point is the number of budding
    events in all earlier time points (the i-th itself is not included),
    divided by the number of cells. The shaded band is plus and minus two
    standard deviations of this curve over bootstrap resamples of the cells.

    No figure is automatically generated. See example.

    Parameters
    ----------
    t: 1D array
        An array of time points.
    buddings: 2D array
        A binary array where a budding event is denoted by a one and each row is
        a time series of budding events for one cell. NaNs are ignored.
    nboots: integer (optional)
        The number of bootstraps to use to estimate errors.
    col: str (optional)
        Color of the line to plot.
    label: str (optional)
        Label for the legend.

    Example
    -------
    >>> import matplotlib.pylab as plt
    >>> from wela.plotting import plot_cuml_divisions_per_cell
    >>> plt.figure()
    >>> plot_cuml_divisions_per_cell(t, b, label="Gal Glu")
    >>> plt.legend()
    >>> plt.show(block=False)
    """
    buddings = np.asarray(buddings)

    def find_cuml(b):
        # buddings summed over cells at each time point, ignoring NaNs
        per_time = np.nansum(b, axis=0)
        # shift the running total by one so that entry i counts only the
        # events before time point i; one pass instead of re-summing the
        # whole prefix for every i
        cuml = np.zeros(b.shape[1])
        cuml[1:] = np.cumsum(per_time)[:-1]
        return cuml / b.shape[0]

    def sample(b):
        # resample cells (rows) with replacement
        return b[np.random.randint(0, b.shape[0], b.shape[0])]

    cuml = find_cuml(buddings)
    err_cuml = 2 * np.std(
        [find_cuml(sample(buddings)) for _ in range(nboots)], axis=0
    )
    plt.plot(t, cuml, color=col, label=label)
    plt.fill_between(t, cuml - err_cuml, cuml + err_cuml, color=col, alpha=0.2)