import matplotlib.pylab as plt
import matplotlib.cm
import seaborn as sns
import numpy as np
import genutils as gu


def kymograph(
    df,
    hue="median_GFP",
    x="time",
    y="id",
    xtick_step=4,
    vmax=None,
    vmin=None,
    cmap=matplotlib.cm.Greens,
    figsize=(6, 10),
    title=None,
    returnfig=False,
):
    """
    Plot a heatmap.

    Typically each row is a single cell and the x-axis shows time.

    The dataframe is pivoted so that each unique value of y becomes a row,
    each unique value of x a column, and the colour shows the value of hue.
    Column k is labelled as x = k * dt, where dt is the smallest spacing
    between the values of x, and labels are divided by 60 (hours if x is
    in minutes).

    Parameters
    ----------
    df: dataframe
        Long-format data containing the columns x, y, and hue.
    hue: string
        Column shown as colour. "births" forces a grey colour map and any
        name containing "growth_rate" uses magma, overriding cmap.
    x: string
        Column for the horizontal axis, typically time.
    y: string
        Column for the rows, typically a cell id.
    xtick_step: float
        Spacing between x-axis tick labels, in units of x / 60.
    vmax, vmin: float (optional)
        Limits of the colour scale.
    cmap: colormap
        Colour map used unless overridden by hue as described above.
    figsize: tuple of two floats
        Size of the figure.
    title: string (optional)
        Title for the plot.
    returnfig: boolean
        If True, return the figure and axes.

    Returns
    -------
    (fig, ax) if returnfig is True, otherwise None.
    """
    # some signals get a fixed colour map, overriding cmap
    if hue == "births":
        cmap = "Greys"
    elif "growth_rate" in hue:
        cmap = sns.color_palette("magma", as_cmap=True)
    # one row per unique y, one column per unique x; pandas >= 2.0 only
    # accepts these arguments by keyword
    data = df.pivot(index=y, columns=x, values=hue).to_numpy()
    ncols = data.shape[1]
    # smallest spacing between x values, taken as the sampling interval
    dt = np.min(np.diff(np.sort(df[x].unique())))
    # tick labels in units of x / 60; column k is taken to be at x = k * dt
    # (tick placement adapted from Arin)
    xticklabels = np.arange(0, dt / 60 * ncols, xtick_step)
    # place each tick at its (possibly fractional) column coordinate rather
    # than searching for an exact floating-point match, which can fail
    xticks = xticklabels * 60 / dt
    # plot
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_xticks(xticks)
    ax.set_xticklabels([str(label) for label in xticklabels.tolist()])
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ax.set_title(title)
    trace_heatmap = ax.imshow(
        data,
        cmap=cmap,
        interpolation="none",
        vmax=vmax,
        vmin=vmin,
    )
    ax.figure.colorbar(
        mappable=trace_heatmap,
        ax=ax,
        label=hue,
        aspect=50,
        shrink=0.5,
    )
    plt.tight_layout()
    plt.show()
    if returnfig:
        return fig, ax


def plot_random_time_series(time, values, signalname=None, number=5):
    """
    Plot random time series on mouse click and terminate on a key press.

    Each mouse click redraws the axes with a fresh random selection of rows
    of values. Pressing a key, or closing the window, ends the loop.

    Parameters
    ----------
    time: array
        Time points; divided by 60 for plotting against "time (hours)".
    values: 2D array
        One time series per row, one column per time point.
    signalname: string (optional)
        Label for the y-axis.
    number: integer
        How many time series to draw at once; capped at the number of rows.
    """
    time = np.asarray(time)
    nseries = values.shape[0]
    # never ask for more distinct series than exist
    nshow = min(number, nseries)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    while True:
        # sample without replacement so no series is drawn twice
        chosen = np.random.choice(nseries, size=nshow, replace=False)
        ax.cla()
        for i in chosen:
            ax.plot(time / 60, values[i, :], label=str(i))
        ax.set_xlabel("time (hours)")
        if signalname:
            ax.set_ylabel(signalname)
        ax.legend()
        ax.grid()
        fig.canvas.draw_idle()
        # waitforbuttonpress returns True for a key, False for a mouse
        # button, and None on timeout
        pressed = None
        while pressed is None:
            # stop if the window was closed; otherwise pyplot would silently
            # create a new figure that no key press could ever stop
            if not plt.fignum_exists(fig.number):
                return
            pressed = fig.waitforbuttonpress(0.2)
        if pressed:
            # a key press ends the loop
            return
        print(".")


def plot_lineage(
    idx,
    df,
    signals=None,
    show=True,
    figsize=(10, 5),
    cyto_pts_signal=None,
    plot_budding_pts=True,
    plot_G1=False,
):
    """
    Plot the signals for one cell lineage.

    If "growth_rate" or "volume" is a signal, plots the signal for the
    mother and the different buds.

    Arguments
    ---------
    idx: integer
        One of df.ids
    df: dataframe
        Typically dl.df or a sub dataframe.
    signals: string or list of strings
        Signals to plot; defaults to ["volume", "growth_rate"].
    show: boolean
        If True, display figure.
    figsize: tuple of two floats
        The size of figure, eg (10, 5)
    cyto_pts_signal: string (optional)
        Column whose non-zero entries mark estimated points of cytokinesis.
        If given and present in df, these are plotted as magenta circles.
    plot_budding_pts: boolean
        If True, plot births as black dots (except on the "volume" and
        "growth_rate" panels).
    plot_G1: boolean
        If True, mark with crosses where the mother cell is in G1.

    Raises
    ------
    ValueError
        If idx is not one of df.id.
    """
    if idx not in df.id.unique():
        raise ValueError("idx not part of dataframe")
    if signals is None:
        signals = ["volume", "growth_rate"]
    signals = gu.makelist(signals)
    nosubplots = len(signals)
    # select this lineage's rows once
    cell = df[df.id == idx]
    # budding points; "births" takes precedence over "buddings" if both exist,
    # and both default to empty so lineages without either still plot
    buddings = np.array([])
    b_pts = np.array([], dtype=int)
    for col in ("buddings", "births"):
        if col in df.columns:
            buddings = cell[col].to_numpy()
            b_pts = np.where(buddings)[0]
    if len(b_pts) == 1:
        # a single budding: draw its bud up to the end of the trace
        nb_pts = np.concatenate((b_pts, [len(buddings) - 1]))
    else:
        nb_pts = b_pts
    # points of cytokinesis, if available
    show_cyto = bool(cyto_pts_signal) and cyto_pts_signal in df.columns
    cyto_pts = (
        np.where(cell[cyto_pts_signal].to_numpy())[0]
        if show_cyto
        else np.array([], dtype=int)
    )
    # G1 points, if available: NaN means unknown, and indices are taken on
    # the unfiltered array so that they align with t
    g1_pts = np.array([], dtype=int)
    if "G1" in df.columns:
        G1 = cell["G1"].to_numpy()
        g1_pts = np.flatnonzero(~np.isnan(G1) & (G1 != 0))
    # time in hours
    t = cell["time"].to_numpy() / 60
    # generate figure
    plt.figure(figsize=figsize)
    plt.suptitle(idx)
    first_ax = None
    for splt, signal in enumerate(signals, start=1):
        # share the x axis with the first subplot
        ax = plt.subplot(nosubplots, 1, splt, sharex=first_ax)
        if first_ax is None:
            first_ax = ax
        if "volume" in signal:
            s = cell["bud_volume"].to_numpy()
            mvol = cell["volume"].to_numpy()
            for start, end in zip(nb_pts, nb_pts[1:]):
                # bud volume between consecutive buddings
                plt.plot(t[start:end], s[start:end], ".-")
            # mother volume
            plt.plot(t, mvol, "k-")
            plt.ylabel("volume")
            g1s, g1col = mvol, "k"
        elif signal == "growth_rate":
            s = cell["bud_growth_rate"].to_numpy()
            mgr = cell["mother_growth_rate"].to_numpy()
            for start, end in zip(nb_pts, nb_pts[1:]):
                # bud growth rate between consecutive buddings
                plt.plot(t[start:end], s[start:end], ".-")
            # mother growth rate
            plt.plot(t, mgr, "k-")
            plt.ylabel("growth rate")
            g1s, g1col = mgr, "k"
        else:
            s = cell[signal].to_numpy()
            (line,) = plt.plot(t, s)
            plt.ylabel(signal)
            g1s, g1col = s, line.get_color()
        plt.grid()
        # budding points
        if (
            plot_budding_pts
            and len(b_pts)
            and signal not in ["volume", "growth_rate"]
        ):
            plt.plot(t[b_pts], s[b_pts], "k.")
        # points of cytokinesis
        if show_cyto and len(cyto_pts):
            plt.plot(t[cyto_pts], s[cyto_pts], "mo")
        # G1
        if plot_G1 and len(g1_pts):
            plt.plot(t[g1_pts], g1s[g1_pts], "x", color=g1col)
    plt.xlabel("time (hours)")
    if show:
        plt.show()


def replicate_summary(data, plotmean=True):
    """
    Summarise an array of replicates at each time point, ignoring NaNs.

    Parameters
    ----------
    data: array
        Signal values, one row per replicate and one column per time point.
    plotmean: boolean
        If True, return the mean and the mean -/+ its standard error, where
        the standard error uses the number of non-NaN replicates at each
        time point.
        If False, return the median and the interquartile range.

    Returns
    -------
    centre, lower, upper: arrays
        One value per column of data.
    """
    data = np.asarray(data, dtype=float)
    if plotmean:
        centre = np.nanmean(data, axis=0)
        # count only the replicates actually measured at each time point
        nvalid = np.sum(~np.isnan(data), axis=0)
        stderr = np.nanstd(data, axis=0) / np.sqrt(nvalid)
        return centre, centre - stderr, centre + stderr
    return (
        np.nanmedian(data, axis=0),
        np.nanquantile(data, 0.25, axis=0),
        np.nanquantile(data, 0.75, axis=0),
    )


def plot_replicate_array(
    data,
    t=None,
    plotmean=True,
    xlabel=None,
    ylabel=None,
    title=None,
    grid=True,
    showdots=False,
    show=True,
):
    """
    Plot summary statistics versus axis 1 (time) for an array of replicates.

    Parameters
    ----------
    data: array
        An array of signal values, with each row a replicate measurement
        and each column a time point. NaNs are ignored.
    t : array (optional)
        An array of time points; defaults to 0, 1, 2, ...
    plotmean: boolean
        If True, plot the mean over replicates with a band of one standard
        error.
        If False, plot the median over replicates with the interquartile
        range.
    xlabel: string
        Label for x-axis.
    ylabel: string
        Label for y-axis.
    title: string
        Title for plot.
    grid: boolean
        If True, draw grid on plot.
    showdots: boolean
        If True, show individual data points.
    show: boolean
        If True, create a new figure and display it immediately; otherwise
        draw on the current axes.
    """
    data = np.asarray(data, dtype=float)
    if t is None:
        t = np.arange(data.shape[1])
    plt_type = "b.-" if showdots else "b-"
    if show:
        plt.figure()
    centre, lower, upper = replicate_summary(data, plotmean)
    plt.plot(t, centre, plt_type)
    plt.fill_between(t, upper, lower, color="b", alpha=0.2)
    if xlabel:
        plt.xlabel(xlabel)
    if ylabel:
        plt.ylabel(ylabel)
    if title:
        plt.title(title)
    if grid:
        plt.grid()
    if show:
        plt.show()