from copy import copy

import matplotlib.cm
import matplotlib.pylab as plt
import numpy as np
import numpy.matlib  # noqa: F401  (no longer used here; kept for existing users)


def kymograph(
    df,
    hue="median_GFP",
    x="time",
    y="id",
    xtick_step_in_hours=5,
    vmax=None,
    vmin=None,
    cmap=matplotlib.cm.Greens,
    figsize=(6, 10),
    title=None,
    returnfig=False,
):
    """
    Plot a heatmap.

    Typically each row is a single cell and the x-axis shows time.
    Time is assumed to be in hours.

    Parameters
    ----------
    df: dataframe
        Long-format data with one row per (y, x) pair.
    hue: string
        Column whose values set the colour of the heatmap.
    x: string
        Column that becomes the columns of the heatmap.
    y: string
        Column that becomes the rows of the heatmap.
    xtick_step_in_hours: float
        Spacing between the labelled ticks on the x-axis.
    vmax: float (optional)
        Upper limit of the colour scale, passed to imshow.
    vmin: float (optional)
        Lower limit of the colour scale, passed to imshow.
    cmap: matplotlib colormap or string
        Colour map. Ignored if hue is "buddings", when "Greys" is used.
    figsize: tuple of two floats
        The size of the figure.
    title: string (optional)
        Title for the plot.
    returnfig: boolean
        If True, return the figure and axes.

    Returns
    -------
    fig, ax: matplotlib figure and axes
        Only returned if returnfig is True.

    Examples
    --------
    >>> from wela.plotting import kymograph
    >>> kymograph(dl.df, hue="median_GFP")
    >>> kymograph(dl.df, hue="bud_volume", title="2% Gal")
    >>> kymograph(dl.df, hue="buddings")
    """
    if hue == "buddings":
        cmap = "Greys"
    # pandas >= 2.0 accepts only keyword arguments for pivot
    wdf = df.pivot(index=y, columns=x, values=hue)
    # smallest spacing between the distinct values on the x-axis
    dt = np.min(np.diff(np.sort(df[x].unique())))
    # from Arin
    data = wdf.to_numpy()
    # define horizontal axis ticks and labels
    xtick_min = 0
    xtick_max = dt * data.shape[1]
    xticklabels = np.arange(xtick_min, xtick_max, xtick_step_in_hours)
    # x value of each column, assuming columns start at zero with spacing dt
    column_x = dt * np.arange(data.shape[1])
    # index of the column closest to each label
    xticks = [np.argmin((column_x - label) ** 2) for label in xticklabels]
    xticklabels = [str(label) for label in xticklabels.tolist()]
    # plot
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels)
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ax.set_title(title)
    trace_heatmap = ax.imshow(
        data,
        cmap=cmap,
        interpolation="none",
        vmax=vmax,
        vmin=vmin,
    )
    ax.figure.colorbar(
        mappable=trace_heatmap,
        ax=ax,
        label=hue,
        aspect=50,
        shrink=0.5,
    )
    plt.tight_layout()
    plt.show(block=False)
    if returnfig:
        return fig, ax


def plot_random_time_series(time, values, signalname=None, number=5):
    """
    Plot randomly chosen time series interactively.

    A new random selection is drawn after each mouse click and the
    function terminates on a key press.

    Parameters
    ----------
    time: 1D array
        Time points, which are divided by 60 before plotting and
        labelled as hours.
    values: 2D array
        Each row is one time series.
    signalname: string (optional)
        Label for the y-axis.
    number: integer
        The number of time series drawn for each plot.
    """
    fig = plt.figure()
    go = True

    def no_go(event):
        nonlocal go
        go = False

    # stop while loop if there is key press event
    fig.canvas.mpl_connect("key_press_event", no_go)
    # loop until a key is pressed
    while go:
        irnds = np.random.randint(0, values.shape[0], number)
        plt.cla()
        for i in irnds:
            plt.plot(time / 60, values[i, :], label=str(i))
        plt.xlabel("time (hours)")
        if signalname:
            plt.ylabel(signalname)
        plt.legend()
        plt.grid()
        plt.draw()
        while plt.waitforbuttonpress(0.2) is None:
            # wait for a key or mouse button
            # if a key, no_go called through mpl_connect
            pass
    print(".")


def plot_lineages(
    irange,
    dl,
    signals=["volume", "growth_rate"],
    show=True,
    figsize=(10, 5),
    plot_budding_pts=True,
    shade_times=None,
    shade_colour="gold",
):
    """
    Wrapper of plot_lineage to plot a range of idx values.

    Parameters
    ----------
    irange: int or list of int
        Indices of dataloader ids to plot.
    dl: dataloader instance
        Contains data to plot; dl.ids and dl.df are used.
    signals, show, figsize, plot_budding_pts, shade_times, shade_colour:
        Passed unchanged to plot_lineage.

    Example
    -------
    >>> plot_lineages(arange(1,20), dl, ["flavin"])
    """
    # a single index may be a Python int or a numpy integer scalar
    if isinstance(irange, (int, np.integer)):
        irange = [irange]
    for i in irange:
        plot_lineage(
            dl.ids[i],
            dl.df,
            signals,
            show,
            figsize,
            plot_budding_pts,
            shade_times,
            shade_colour,
        )


def plot_lineage(
    idx,
    df,
    signals=["volume", "growth_rate"],
    show=True,
    figsize=(10, 5),
    plot_budding_pts=True,
    shade_times=None,
    shade_colour="gold",
):
    """
    Plot the signals for one cell lineage.

    If df has a "bud_" version of a signal, such as for "volume" or
    "growth_rate", plots the signal for the mother and the different
    buds.

    Arguments
    ---------
    idx: integer
        One of df.ids
    df: dataframe
        Typically dl.df or a sub dataframe.
    signals: string or list of strings
        Signals to plot.
    show: boolean
        If True, display figure.
    figsize: tuple of two floats
        The size of figure, eg (10, 5)
    plot_budding_pts: boolean
        If True, plot births as black dots.
    shade_times: list of tuples of two floats, optional
        Add vertical shading between each pair of times for all subplots.
    shade_colour: string, optional
        Colour for vertical shading.

    Raises
    ------
    ValueError
        If idx is not one of the ids in df.

    Examples
    --------
    >>> from wela.plotting import plot_lineage
    >>> plot_lineage(dl.df.id[3], dl.df)
    >>> plot_lineage(dl.df.id[3], dl.df, "median_GFP")

    and for both mother and bud volumes

    >>> plot_lineage(dl.df.id[3], dl.df, "volume")

    as well as

    >>> plot_lineage(dl.ids[23], dl.df, ["flavin", "total_mCherry",
    ...     "bud_volume"], shade_times=[(7, 15)])
    """
    if idx not in df.id.unique():
        raise ValueError("idx not part of dataframe")
    if isinstance(signals, str):
        signals = [signals]
    nosubplots = len(signals)
    # data for this lineage only
    cell = df[df.id == idx]
    has_buddings = "buddings" in df.columns
    # show buddings if possible
    if has_buddings:
        buddings = cell["buddings"].to_numpy()
        b_pts = np.where(buddings)[0]
        if len(b_pts) == 1:
            # add the final time point so that the single bud has an end
            nb_pts = np.concatenate((b_pts, [len(buddings) - 1]))
        else:
            nb_pts = b_pts
    else:
        # without budding times no bud segments can be delimited
        b_pts = np.array([], dtype=int)
        nb_pts = b_pts
    # find time
    t = cell["time"].to_numpy()
    # generate figure
    fig = plt.figure(figsize=figsize)
    plt.suptitle(idx)
    # plot signals, one subplot each
    for splt, signal in enumerate(signals, start=1):
        if "bud_" + signal in df.columns:
            bud_sig = cell["bud_" + signal].to_numpy()
            if signal == "growth_rate" and "growth_rate" not in df.columns:
                signal = "mother_" + signal
            m_sig = cell[signal].to_numpy()
            plt.subplot(nosubplots, 1, splt)
            for start, end in zip(nb_pts, nb_pts[1:]):
                # plot bud signal
                plt.plot(t[start:end], bud_sig[start:end], ".-")
            # plot mother signal
            plt.plot(t, m_sig, "k-")
        else:
            m_sig = cell[signal].to_numpy()
            plt.subplot(nosubplots, 1, splt)
            plt.plot(t, m_sig)
        plt.ylabel(signal)
        add_shading(shade_times, shade_colour)
        plt.grid()
        # plot buddings
        if (
            plot_budding_pts
            and has_buddings
            and signal not in ["volume", "growth_rate"]
        ):
            for bpt in b_pts:
                plt.plot(t[bpt], m_sig[bpt], "k.")
    plt.xlabel("time (hours)")
    # share the x axis; joining the Grouper returned by get_shared_x_axes
    # is no longer supported by matplotlib
    ax_list = fig.axes
    for ax in ax_list[1:]:
        ax.sharex(ax_list[0])
    if show:
        plt.show(block=False)


def add_shading(shade_times, shade_colour):
    """
    Shade vertically between each pair of shade_times.

    Parameters
    ----------
    shade_times: list of tuples of two floats, or None
        Start and end of each region to shade on the current axes.
        Nothing is drawn if None.
    shade_colour: string
        Colour for the shading.
    """
    if shade_times is None:
        return
    for tstart, tend in shade_times:
        plt.axvspan(tstart, tend, facecolor=shade_colour, alpha=0.1)


def plot_replicate_array(
    data,
    t=None,
    plotmean=True,
    xlabel=None,
    ylabel=None,
    title=None,
    grid=True,
    showdots=False,
    show=True,
):
    """
    Plot summary statistics versus axis 1 (time) for an array of replicates.

    Parameters
    ----------
    data: array
        An array of signal values, with each row a replicate measurement
        and each column a time point.
    t : array (optional)
        An array of time points. If None, the column index is used.
    plotmean: boolean
        If True, plot the mean over replicates versus time with the
        standard error. If False, plot the median and the interquartile
        range. NaNs are ignored.
    xlabel: string
        Label for x-axis.
    ylabel: string
        Label for y-axis.
    title: string
        Title for plot.
    grid: boolean
        If True, draw grid on plot.
    showdots: boolean
        If True, show individual data points.
    show: boolean
        If True, create a new figure and display it immediately.
        If False, draw on the current figure.
    """
    # number of time points
    n = data.shape[1]
    # only substitute a default when no time points are given
    if t is None:
        t = np.arange(n)
    plt_type = "b.-" if showdots else "b-"
    if show:
        plt.figure()
    if plotmean:
        # mean and standard error
        mean = np.nanmean(data, axis=0)
        # the mean and std ignore NaNs and so the standard error must use
        # the number of non-NaN replicates at each time point
        n_valid = np.sum(~np.isnan(data), axis=0)
        stderr = np.nanstd(data, axis=0) / np.sqrt(n_valid)
        plt.plot(t, mean, plt_type)
        plt.fill_between(
            t,
            mean + stderr,
            mean - stderr,
            color="b",
            alpha=0.2,
        )
    else:
        # median and interquartile range
        plt.plot(t, np.nanmedian(data, axis=0), plt_type)
        plt.fill_between(
            t,
            np.nanquantile(data, 0.25, axis=0),
            np.nanquantile(data, 0.75, axis=0),
            color="b",
            alpha=0.2,
        )
    if xlabel:
        plt.xlabel(xlabel)
    if ylabel:
        plt.ylabel(ylabel)
    if title:
        plt.title(title)
    if grid:
        plt.grid()
    if show:
        plt.show(block=False)


def plot2Dhist(
    x,
    y,
    bins=[20, 20],
    title=None,
    figsize=None,
    xlabel="time",
    ylabel=None,
    xmax=None,
    ymax=None,
    cmap=None,
    **kwargs,
):
    """
    Plot two dimensional histograms.

    Typically, time is on the x-axis; the support of the distribution is
    on the y-axis; and shading represents the height of the distribution
    at each y value.

    Parameters
    ----------
    x: 1D or 2D array
        If 1D, we assume an array of time points.
    y: 2D array
        Each row contains time-series data from a single cell.
    bins: list of arrays (optional)
        Specifies the bins, either explicitly or as a number of bins, and
        can be used to plot different data sets on axes with the same
        range.
    title: str (optional)
        Title for the plot.
    figsize: tuple (optional)
        Sets the width and height of the figure.
    xlabel: str (optional)
        Label for the x-axis.
    ylabel: str (optional)
        Label for the y-axis.
    xmax: float (optional)
        The maximal value on the x-axis.
    ymax: float (optional)
        The maximal value on the y-axis.
    cmap: matplotlib.colormap (optional)
        Color map for the shading.
    **kwargs:
        Passed to plt.pcolormesh.

    Returns
    -------
    edges: tuple of 1D arrays
        The x and y bin edges.
    h: 2D array
        The number of data points in each bin.

    Examples
    --------
    Load data:

    >>> from wela.dataloader import dataloader
    >>> from wela.plotting import plot2Dhist
    >>> dl = dataloader()
    >>> dl.load("1334_2023_03_28_pyrTo2Gal_01glc_00", use_tsv=True)
    >>> dlc = dataloader()
    >>> dlc.load("1005_2023_03_09_exp_00_pyrToGalInduction", use_tsv=True)

    Get time-series data:

    >>> t, d = dl.get_time_series("median_GFP")
    >>> tc, dc = dlc.get_time_series("median_GFP")

    Plot both data sets using axes with the same range:

    >>> bins = plot2Dhist(tc, dc, title="2% Gal", figsize=(4, 3))[0]
    >>> plot2Dhist(t, d, title="2% Gal and 0.1% Glu", bins=bins,
    ...     figsize=(4, 3))
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.ndim == 1:
        # make into a 2D array with one copy of x for each row of y
        xa = np.tile(x, (y.shape[0], 1))
    else:
        xa = x
    # find real data
    select = ~np.isnan(xa) & ~np.isnan(y)
    # boolean indexing already returns 1D arrays
    xn = xa[select]
    yn = y[select]
    # bin
    h, xedges, yedges = np.histogram2d(xn, yn, bins=bins)
    # copy the colour map so that set_bad does not alter the original
    cmap = copy(plt.cm.viridis) if cmap is None else copy(cmap)
    cmap.set_bad(cmap(0))
    # plot using pcolormesh
    if figsize is not None:
        plt.figure(figsize=figsize)
    else:
        plt.figure()
    pcm = plt.pcolormesh(
        xedges,
        yedges,
        h.T,
        cmap=cmap,
        rasterized=True,
        **kwargs,
    )
    plt.colorbar(pcm, label="number of cells", pad=0.02)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if ymax is not None:
        plt.ylim(top=ymax)
    if xmax is not None:
        plt.xlim(right=xmax)
    if title:
        plt.title(title)
    plt.tight_layout()
    plt.show(block=False)
    return (xedges, yedges), h


def plot_cuml_divisions_per_cell(t, buddings, nboots=30, col="b", label=None):
    """
    Plot the mean cumulative number of divisions per cell over time.

    The value at each time point counts the budding events strictly
    before that time point. The shaded band is twice the standard
    deviation over bootstrap resamples of the cells.

    No figure is automatically generated. See example.

    Parameters
    ----------
    t: 1D array
        An array of time points.
    buddings: 2D array
        A binary array where a budding event is denoted by a one and each
        row is a time series of budding events for one cell.
    nboots: integer (optional)
        The number of bootstraps to use to estimate errors.
    col: str (optional)
        Color of the line to plot.
    label: str (optional)
        Label for the legend.

    Example
    -------
    >>> import matplotlib.pylab as plt
    >>> from wela.plotting import plot_cuml_divisions_per_cell
    >>> plt.figure()
    >>> plot_cuml_divisions_per_cell(t, b, label="Gal Glu")
    >>> plt.legend()
    >>> plt.show(block=False)
    """
    buddings = np.asarray(buddings)
    no_cells = buddings.shape[0]

    def find_cuml(b):
        # number of buddings over all cells at each time point
        per_time = np.nansum(b, axis=0)
        # exclusive cumulative sum: buddings before each time point
        return (np.cumsum(per_time) - per_time) / b.shape[0]

    def sample(b):
        # resample the cells (rows) with replacement
        return b[np.random.randint(0, b.shape[0], b.shape[0])]

    cuml = find_cuml(buddings)
    err_cuml = 2 * np.std(
        [find_cuml(sample(buddings)) for _ in range(nboots)], axis=0
    )
    plt.plot(t, cuml, color=col, label=label)
    plt.fill_between(t, cuml - err_cuml, cuml + err_cuml, color=col, alpha=0.2)