diff --git a/src/postprocessor/compiler.py b/src/postprocessor/compiler.py deleted file mode 100644 index d6c3275a6ead1edd7effe2f11f8495e45117c4e1..0000000000000000000000000000000000000000 --- a/src/postprocessor/compiler.py +++ /dev/null @@ -1,906 +0,0 @@ -""" -Script in development -""" - -# /usr/bin/env python3 -import re -import warnings -from abc import abstractmethod -from collections import Counter -from pathlib import Path -from typing import Dict, Iterable, Tuple, Union - -import h5py -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import seaborn as sns -from agora.abc import ProcessABC -from matplotlib.backends.backend_pdf import PdfPages -from numpy import ndarray -from scipy.signal import find_peaks - -from postprocessor.grouper import NameGrouper - -sns.set_style("darkgrid") - - -# Main dataframe structure - -# | position | group | ntraps |robustness index | initial_ncells | final_ncells -# dir = "/home/alan/Documents/dev/skeletons/data/2021_06_15_pypipeline_unit_test_00/2021_06_15_pypipeline_unit_test_00/" -# dir = "/home/alan/Documents/dev/libs/aliby/data/2021_08_24_2Raf_00/2021_08_24_2Raf_00/" -# dirs = [ -# "16543_2019_07_16_aggregates_CTP_switch_2_0glu_0_0glu_URA7young_URA8young_URA8old_01", -# "16545_2019_07_16_aggregates_CTP_switch_2_0glu_0_0glu_URA7young_URA8young_URA8old_secondRun_01", -# "18069_2019_12_05_aggregates_updownshift_2_0_2_URA8_URA7H360A_URA7H360R_00", -# "18616_2020_02_20_protAgg_downUpShift_2_0_2_Ura8_Ura8HA_Ura8HR_01", -# "18617_2020_02_21_protAgg_downUpShift_2_0_2_pHluorin_Ura7HA_Ura7HR_00", -# "19129_2020_09_06_DownUpshift_2_0_2_glu_ura_mig1msn2_phluorin_00", -# "19144_2020_09_07_DownUpshift_2_0_2_glu_ura_mig1msn2_phluorin_secondRound_00", -# "19169_2020_09_09_downUpshift_2_0_2_glu_ura8_phl_mig1_phl_msn2_03", -# "19199_2020_09_29_downUpshift_2_0_2_glu_ura8_ura8h360a_ura8h360r_00", -# "19203_2020_09_30_downUpshift_twice_2_0_2_glu_ura8_ura8h360a_ura8h360r_00", -# "19207_2020_10_01_exp_00", -# "19232_2020_10_02_downUpshift_twice_2_0_2_glu_ura8_phluorinMsn2_phluorinMig1_01", -# "19307_2020_10_22_downUpshift_2_01_2_glucose_dual_pH__dot6_nrg1_tod6__00", -# "19310_2020_10_22_downUpshift_2_0_2_glu_dual_phluorin__glt1_psa1_ura7__thrice_00", -# "19311_2020_10_23_downUpshift_2_0_2_glu_dual_phluorin__glt1_psa1_ura7__twice__04", -# "19328_2020_10_31_downUpshift_four_2_0_2_glu_dual_phl__glt1_ura8_ura8__00", -# "19329_2020_11_01_exp_00", -# "19333_2020_11_02_downUpshift_2_0_2_glu_ura7_ura7ha_ura7hr_00", -# "19334_2020_11_02_downUpshift_2_0_2_glu_ura8_ura8ha_ura8hr_00", -# "19447_2020_11_18_downUpshift_2_0_2_glu_gcd2_gcd6_gcd7__02", -# "19810_2021_02_21_ToxicityTest_00", -# "19993_2021_06_15_pypipeline_unit_test_00", -# "19996_2021_06_27_ph_calibration_dual_phl_ura8_5_04_5_83_7_69_7_13_6_59__01", -# "20419_2021_11_02_dose_response_raf_05_075_2_glu_005_2_constantMedia_00", -# ] -# outdir = "/home/alan/Documents/dev/skeletons/data" -# dirs = Path(outdir).glob("*ph*") - - -# from abc import abstractclassmethod, abstractmethod - - -# group_pos_trap_ncells = ( -# concat.dropna().groupby(["group", "position", "trap"]).apply(len) -# ) -# group_pos_trapswcell = ( -# group_pos_trap_ncells.dropna().groupby(["group", "position"]).apply(len) -# ) - - -class Meta: - """Convenience class to fetch data from hdf5 file.""" - - def __init__(self, filename): - self.filename = filename - - @property - def ntimepoints(self): - with h5py.File(self.filename, "r") as f: - return f.attrs["time_settings/ntimepoints"][0] - - -class Compiler(ProcessABC): - # def __init__(self, parameters): - # super().__init__(parameters) - - @abstractmethod - def load_data(self): - """Abstract function that must be reimplemented.""" - pass - - @abstractmethod - def run(): - pass - - -class ExperimentCompiler(Compiler): - def __init__(self, CompilerParameters, exp_path: Path): - super().__init__(CompilerParameters) - self.load_data(exp_path) - - def run(self): - return { - method: getattr(self, "compile_" + method)() - for method in ( - "slice", - "slices", - "delta_traps", - "pertrap_metric", - "ncells", - "last_valid_tp", - "stages_dmetric", - "fluorescence", - ) - } - - def load_data(self, path: Path): - self.grouper = NameGrouper(path) - self.meta = Meta(self.grouper.files[0]) - - @property - def ntraps(self) -> dict: - """Get the number of traps in each position. - - Returns ------- dict str -> int Examples -------- FIXME: Add - docs. - """ - return { - pos: coords.shape[0] - for pos, coords in self.grouper.traplocs().items() - } - - def concat_signal(self, sigloc=None, mode=None, **kwargs) -> pd.DataFrame: - if sigloc is None: - sigloc = "extraction/general/None/volume" - self.sigloc = sigloc - - if mode is None: - mode = "retained" - - if not hasattr(self, "_concat") or self.sigloc != sigloc: - self._concat = self.grouper.concat_signal( - self.sigloc, mode=mode, **kwargs - ) - - return self._concat - - def get_tp(self, sigloc=None, tp=None, mode=None, **kwargs) -> pd.Series: - if tp is None: - tp = 0 - - if mode is None: - mode = True - - return self.concat_signal(sigloc=sigloc, mode=mode, **kwargs).iloc[ - :, tp - ] - - def count_cells( - self, - signal="extraction/general/None/volume", - mode="raw", - **kwargs, - ): - df = self.grouper.concat_signal(signal, mode=mode, **kwargs) - df = df.groupby(["group", "position", "trap"]).count() - df[df == 0] = np.nan - return df - - def compile_dmetrics(self, stages=None): - """Generate dataframe with dVol metrics without major cell picking.""" - names_signals = { - "dvol": "postprocessing/dsignal/postprocessing_savgol_extraction_general_None_volume", - "bud_dvol": "postprocessing/bud_metric/postprocessing_dsignal_postprocessing_savgol_extraction_general_None_volume", - } - names_signals = { - "dvol": "postprocessing/dsignal/postprocessing_savgol_extraction_general_None_volume", - "bud_dvol": "postprocessing/bud_metric/postprocessing_dsignal_postprocessing_savgol_extraction_general_None_volume", - "buddings": "postprocessing/buddings/extraction_general_None_volume", - } - operations = { - "dvol": ("dvol", "max"), - "bud_dvol": ("bud_dvol", "max"), - "buddings": ("buddings", "sum"), - "buddings_mean": ("buddings", "mean"), - } - - input_signals = { - k: self.grouper.concat_signal(v) for k, v in names_signals.items() - } - - ids = input_signals["buddings"].index - for v in input_signals.values(): - ids = ids.intersection(v.index) - - if stages: - - def process_dfs(dfs, rng): - return pd.DataFrame( - { - k: getattr(dfs[sig].loc(axis=1)[rng].loc[ids], op)( - axis=1 - ) - if isinstance(op, str) - else dfs[sig].loc[ids].apply(op, axis=1) - for k, (sig, op) in operations.items() - } - ) - - # Note that all input_signals columns must be the same - col_vals = list(input_signals.values())[0].columns - stages_dfs = {"Full": process_dfs(input_signals, col_vals)} - for k, rng in stages: - stage_df = process_dfs(input_signals, col_vals[rng]) - stages_dfs[k] = stage_df - - concat = pd.concat([x.reset_index() for x in stages_dfs.values()]) - concat["stage"] = np.array( - [ - np.repeat(x, len(concat) // len(stages_dfs)) - for x in stages_dfs.keys() - ] - ).flatten() - - return ( - concat.set_index(["group", "position", "trap", "cell_label"]) - .melt("stage", ignore_index=False, var_name="growth_metric") - .reset_index() - ) - - def compile_stages_dmetric(self): - stages = self.get_stages() - return self.compile_dmetrics(stages=stages) - - def get_stages(self): - """Use the metadata to give a prediction of the media being pumped at - each time point. Works for traditional metadata (pre-fluigent). - - Returns: ------ A list of tuples where in each the first value - is the active pump's contents and the second its associated - range of time points - """ - fpath = list(self.grouper.signals.values())[0].filename - with h5py.File(fpath, "r") as f: - tinterval = f.attrs.get("time_settings/timeinterval", None)[0] - tnorm = tinterval / 60 - switch_times = f.attrs.get("switchtimes", None) / tnorm - last_tp = ( - f.attrs.get("time_settings/totaltime", None)[0] / tinterval - ) - pump_contents = f.attrs.get("pumpinit/contents", None) - init_frate = f.attrs.get("pumpinit/flowrate", None) - prate = f.attrs.get("pumprate", None) - main_pump = np.array((init_frate.argmax(), *prate.argmax(axis=0))) - - intervals = np.array((0, *switch_times, last_tp), dtype=int) - - extracted_tps = self.grouper.ntimepoints - stages = [ # Only add intervals with length larger than zero - ( - ": ".join((str(i + 1), pump_contents[p_id])), - range(intervals[i], min(intervals[i + 1], extracted_tps)), - ) - for i, p_id in enumerate(main_pump) - if (intervals[i + 1] > intervals[i]) - ] - return stages - - def compile_growth_metrics( - self, - min_nbuddings: int = 2, - ): - """Filter mothers with n number of buddings and get their metrics. - - Select cells with at least two recorded buddings - """ - names_signals = { - "dvol": "postprocessing/dsignal/postprocessing_savgol_extraction_general_None_volume", - "bud_dvol": "postprocessing/bud_metric/postprocessing_dsignal_postprocessing_savgol_extraction_general_None_volume", - "buddings": "postprocessing/buddings/extraction_general_None_volume", - } - operations = { - "dvol": ("dvol", "max"), - "bud_dvol": ("bud_dvol", "max"), - "buddings": ("buddings", "sum"), - "cycle_length_mean": ( - "buddings", - lambda x: bn.nanmean(np.diff(np.where(x)[0])), - ), - "cycle_length_min": ( - "buddings", - lambda x: bn.nanmin(np.diff(np.where(x)[0])), - ), - "cycle_length_median": ( - "buddings", - lambda x: np.nanmedian(np.diff(np.where(x)[0])), - ), - } - - input_signals = { - k: self.grouper.concat_signal(v) for k, v in names_signals.items() - } - ids = self.get_shared_ids(input_signals, min_nbuddings=min_nbuddings) - - compiled_df = pd.DataFrame( - { - k: getattr(input_signals[sig].loc[ids], op)(axis=1) - if isinstance(op, str) - else input_signals[sig].loc[ids].apply(op, axis=1) - for k, (sig, op) in operations.items() - } - ) - return compiled_df - - def get_shared_ids( - self, input_signals: Dict[str, pd.DataFrame], min_nbuddings: int = None - ): - """Get the intersection id of multiple signals. - - "buddings" must be one the keys in input_signals to use the - argument min_nbuddings. - """ - ids = list(input_signals.values())[0].index - if min_nbuddings is not None: - ids = ( - input_signals["buddings"] - .loc[input_signals["buddings"].sum(axis=1) >= min_nbuddings] - .index - ) - for v in input_signals.values(): - ids = ids.intersection(v.index) - - return ids - - def compile_ncells(self): - df = self.count_cells() - df = df.melt(ignore_index=False) - df.columns = ["timepoint", "ncells_pertrap"] - - return df - - def compile_last_valid_tp(self) -> pd.Series: - """Last valid timepoint per position.""" - df = self.count_cells() - df = df.apply(lambda x: x.last_valid_index(), axis=1) - df = df.groupby(["group", "position"]).max() - - return df - - def compile_slices(self, nslices=2, **kwargs): - tps = [ - min( - i * (self.grouper.ntimepoints // nslices), - self.grouper.ntimepoints - 1, - ) - for i in range(nslices + 1) - ] - slices = [self.compile_slice(tp=tp, **kwargs) for tp in tps] - slices_df = pd.concat(slices) - - slices_df["timepoint"] = np.concatenate( - [np.repeat(tp, len(slice_df)) for tp, slice_df in zip(tps, slices)] - ) - - return slices_df - - def compile_slice_end(self, **kwargs): - return self.compile_slice(tp=-1, **kwargs) - - def guess_metrics(self, metrics: Dict[str, Tuple[str]] = None): - """First approach at autoselecting certain signals for automated - analysis.""" - - if metrics is None: - metrics = { - "GFP": ("median", "max5"), - "mCherry": ("median", "max5"), - # "general": ("eccentricity",), - "Flavin": ("median",), - "postprocessing/savgol": ("volume",), - "dsignal/postprocessing_savgol": ("volume",), - "bud_metric.*dsignal.*savgol": ("volume",), - "ph_ratio": ("median",), - } - - sigs = self.grouper.siglist - selection = { - ".".join((ch, metric)): sig - for sig in sigs - for ch, metric_set in metrics.items() - for metric in metric_set - if re.search("(?!.*bgsub).*".join((ch, metric)) + "$", sig) - } - return selection - - def compile_fluorescence( - self, - metrics: Dict[str, Tuple[str]] = None, - norm: tuple = None, - **kwargs, - ): - """Get a single signal per.""" - if norm is None: - norm = ( - "GFP", - "GFPFast", - "ph_ratio", - "Flavin", - "Citrine", - "mCherry", - ) - - selection = self.guess_metrics(metrics) - - input_signals = { - k: self.grouper.concat_signal(v, **kwargs) - for k, v in selection.items() - } - - # ids = self.get_shared_ids(input_signals) - - to_concat = [] - - def format_df(df): - return df.melt( - ignore_index=False, var_name="timepoint" - ).reset_index() - - for k, v in input_signals.items(): - tmp_formatted = format_df(v) - tmp_formatted["signal"] = k - to_concat.append(tmp_formatted) - if norm and k.split(".")[0] in norm: - norm_v = v.subtract(v.min(axis=1), axis=0).div( - v.max(axis=1) - v.min(axis=1), axis=0 - ) - # norm_v = v.groupby(["position", "trap", "cell_label"]).transform( - # # lambda x: x - x.min() / (x.max() - x.min()) - # lambda x: (x - x.min()) - # / (x.max() - x.min()) - # ) - formatted = format_df(norm_v) - formatted["signal"] = "norm_" + k - to_concat.append(formatted) - - concated = pd.concat(to_concat, axis=0) - - return concated - - def compile_slice( - self, sigloc=None, tp=None, metrics=None, mode=None, **kwargs - ) -> pd.DataFrame: - if sigloc is None: - self.sigloc = "extraction/general/None/volume" - - if tp is None: - tp = 0 - - if metrics is None: - metrics = ("max", "mean", "median", "count", "std", "sem") - - if mode is None: - mode = True - - df = pd.concat( - [ - getattr( - self.get_tp(sigloc=sigloc, tp=tp, mode=mode, **kwargs) - .groupby(["group", "position", "trap"]) - .max() - .groupby(["group", "position"]), - met, - )() - for met in metrics - ], - axis=1, - ) - - df.columns = metrics - - merged = self.add_column(df, self.ntraps, name="ntraps") - - return merged - - @staticmethod - def add_column(df: pd.DataFrame, new_values_d: dict, name="new_col"): - if name in df.columns: - warnings.warn( - "ExpCompiler: Replacing existing column in compilation" - ) - df[name] = [ - new_values_d[pos] for pos in df.index.get_level_values("position") - ] - - return df - - @staticmethod - def traploc_diffs(traplocs: ndarray) -> list: - """Obtain metrics for trap localisation. - - Parameters ---------- traplocs : ndarray (x,2) 2-dimensional - array with the x,y coordinates of traps in each column - Examples -------- FIXME: Add docs. - """ - signal = np.zeros((traplocs.max(), 2)) - for i in range(2): - counts = Counter(traplocs[:, i]) - for j, v in counts.items(): - signal[j - 1, i] = v - - diffs = [ - np.diff(x) - for x in np.apply_along_axis(find_peaks, 0, signal, distance=10)[0] - ] - return diffs - - def compile_delta_traps(self): - group_names = self.grouper.group_names - tups = [ - (group_names[pos], pos, axis, val) - for pos, coords in self.grouper.traplocs().items() - for axis, vals in zip(("x", "y"), self.traploc_diffs(coords)) - for val in vals - ] - - return pd.DataFrame( - tups, columns=["group", "position", "axis", "value"] - ) - - def compile_pertrap_metric( - self, - ranges: Iterable[Iterable[int]] = [ - [0, -1], - ], - metric: str = "count", - ): - """Get the number of cells per trap present during the given ranges.""" - sig = self.concat_signal() - - for i, rngs in enumerate(ranges): - for j, edge in enumerate(rngs): - if edge < 0: - ranges[i][j] = sig.shape[1] - i + 1 - df = pd.concat( - [ - self.get_filled_trapcounts( - sig.loc(axis=1)[slice(*rng)], metric=metric - ) - for rng in ranges - ], - axis=1, - ) - return df.astype(str) - - def get_filled_trapcounts( - self, signal: pd.DataFrame, metric: str - ) -> pd.Series: - present = signal.apply( - lambda x: (not x.first_valid_index()) - & (x.last_valid_index() == len(x) - 1), - axis=1, - ) - results = getattr( - signal.loc[present] - .iloc[:, 0] - .groupby(["group", "position", "trap"]), - metric, - )() - filled = self.fill_trapcount(results) - return filled - - def fill_trapcount( - self, srs: pd.Series, fill_value: Union[int, float] = 0 - ) -> pd.Series: - """Fill the last level of a MultiIndex in a pd.Series. - - Use self to get the max number of traps per position and use - this information to add rows with empty values (with plottings - of distributions in mind) Parameters ---------- srs : pd.Series - Series with a pd.MultiIndex index self : ExperimentSelf - class with 'ntraps' information that returns a dictionary with - position -> ntraps. fill_value : Union[int, float] Value - used to fill new rows. Returns ------- pd.Series Series - with no numbers skipped on the last level. Examples -------- - FIXME: Add docs. - """ - - all_sets = set( - [ - (pos, i) - for pos, ntraps in self.ntraps.items() - for i in range(ntraps) - ] - ) - dif = all_sets.difference( - set( - zip( - *[ - srs.index.get_level_values(i) - for i in ("position", "trap") - ] - ) - ).difference() - ) - new_indices = pd.MultiIndex.from_tuples( - [ - (self.grouper.group_names[idx[0]], idx[0], np.uint(idx[1])) - for idx in dif - ] - ) - new_indices = new_indices.set_levels( - new_indices.levels[-1].astype(np.uint), level=-1 - ) - empty = pd.Series(fill_value, index=new_indices, name="ncells") - return pd.concat((srs, empty)) - - -class Reporter(object): - """Manages Multiple pages to generate a report.""" - - def __init__( - self, - data: Dict[str, pd.DataFrame], - pages: dict = None, - path: str = None, - ): - self.data = data - - if pages is None: - pages = { - "qa": self.gen_page_qa(), - "growth": self.gen_page_growth(), - "fluorescence": self.gen_page_fluorescence(), - } - self.pages = pages - - if path is not None: - self.path = path - - self.porgs = {k: PageOrganiser(data, v) for k, v in pages.items()} - - @property - def pdf(self): - return self._pdf - - @pdf.setter - def pdf(self, path: str): - self._pdf = PdfPages(path) - - def plot_report(self, path: str = None): - if path is None: - path = self.path - - with PdfPages(path) as pdf: - for page_org in list(self.porgs.values())[::-1]: - page_org.plot_page() - pdf.savefig(page_org.fig) - # pdf.savefig() - plt.close() - - @staticmethod - def gen_page_qa(): - page_qc = ( - { - "data": "slice", - "func": "barplot", - "args": ("ntraps", "position"), - "kwargs": {"hue": "group", "palette": "muted"}, - "loc": (0, 0), - }, - { - "data": "delta_traps", - "func": "barplot", - "args": ("axis", "value"), - "kwargs": { - "hue": "group", - }, - "loc": (0, 1), - }, - { - "data": "slices", - "func": "violinplot", - "args": ("group", "median"), - "kwargs": { - "hue": "timepoint", - }, - "loc": (2, 1), - }, - { - "data": "pertrap_metric", - "func": "histplot", - "args": (0, None), - "kwargs": { - "hue": "group", - "multiple": "dodge", - "discrete": True, - }, - "loc": (2, 0), - }, - { - "data": "ncells", - "func": "lineplot", - "args": ("timepoint", "ncells_pertrap"), - "kwargs": { - "hue": "group", - }, - "loc": (1, 1), - }, - { - "data": "last_valid_tp", - "func": "stripplot", - "args": (0, "position"), - "kwargs": { - "hue": "group", - }, - "loc": (1, 0), - }, - ) - return page_qc - - @staticmethod - def gen_page_fluorescence(): - return ( - { - "data": "fluorescence", - "func": "relplot", - "args": ("timepoint", "value"), - "kwargs": { - "col": "signal", - "col_wrap": 2, - "hue": "group", - "facet_kws": {"sharey": False, "sharex": True}, - "kind": "line", - }, - }, - ) - - def gen_page_cell_cell_corr(): - pass - - @staticmethod - def gen_page_growth(): - return ( - { - "data": "stages_dmetric", - "func": "catplot", - "args": ("stage", "value"), - "kwargs": { - "hue": "group", - "col": "growth_metric", - "col_wrap": 2, - "kind": "box", - "sharey": False, - }, - }, - ) - - def gen_all_instructions(self): - qa = self.gen_page_qa() - growth = self.gen_page_growth() - - return (qa, growth) - - -class PageOrganiser(object): - """Add multiple plots to a single page, wither using seaborn multiplots or - manual GridSpec.""" - - def __init__( - self, - data: Dict[str, pd.DataFrame], - instruction_set: Iterable = None, - grid_spec: tuple = None, - fig_kws: dict = None, - ): - self.instruction_set = instruction_set - self.data = {k: df for k, df in data.items()} - - self.single_fig = True - if len(instruction_set) > 1: - self.single_fig = False - - if not self.single_fig: # Select grid_spec with location info - if grid_spec is None: - locs = np.array( - [x.get("loc", (0, 0)) for x in instruction_set] - ) - grid_spec = locs.max(axis=0) + 1 - - if fig_kws is None: - self.fig = plt.figure(dpi=300) - self.fig.set_size_inches(8.27, 11.69, forward=True) - plt.figtext(0.02, 0.99, "", fontsize="small") - self.gs = plt.GridSpec(*grid_spec, wspace=0.3, hspace=0.3) - - self.axes = {} - reset_index = ( - lambda df: df.reset_index().sort_values("position") - if isinstance(df.index, pd.core.indexes.multi.MultiIndex) - else df.sort_values("position") - ) - self.data = {k: reset_index(df) for k, df in self.data.items()} - - def place_plot(self, func, xloc=None, yloc=None, **kwargs): - if xloc is None: - xloc = 0 - if yloc is None: - yloc = 0 - - if ( - self.single_fig - ): # If plotting using a figure method using seaborn cols/rows - self.g = func(**kwargs) - self.axes = { - ax.title.get_text().split("=")[-1][1:]: ax - for ax in self.g.axes.flat - } - self.fig = self.g.fig - else: - self.axes[(xloc, yloc)] = self.fig.add_subplot(self.gs[xloc, yloc]) - func( - ax=self.axes[(xloc, yloc)], - **kwargs, - ) - - # Eye candy - if np.any( # If there is a long label, rotate them all - [ - len(lbl.get_text()) > 8 - for ax in self.axes.values() - for lbl in ax.get_xticklabels() - ] - ) and hasattr(self, "g"): - for axes in self.g.axes.flat: - _ = axes.set_xticklabels( - axes.get_xticklabels(), - rotation=15, - horizontalalignment="right", - ) - - def plot_page( - self, instructions: Iterable[Dict[str, Union[str, Iterable]]] = None - ): - if instructions is None: - instructions = self.instruction_set - if isinstance(instructions, dict): - how = (instructions,) - - for how in instructions: - self.place_plot( - self.gen_sns_wrapper(how), - *how.get("loc", (None, None)), - ) - - def gen_sns_wrapper(self, how): - def sns_wrapper(ax=None): - kwargs = how.get("kwargs", {}) - if ax: - kwargs["ax"] = ax - elif "height" not in kwargs: - ncols = kwargs.get("col_wrap", 1) - if "col" in kwargs: - nrows = np.ceil( - len(np.unique(self.data[how["data"]][kwargs["col"]])) - / ncols - ) - else: - nrows = len( - np.unique(self.data[how["data"]][kwargs["row"]]) - ) - - kwargs["height"] = 11.7 - # kwargs["aspect"] = 8.27 / (11.7 / kwargs["col_wrap"]) - kwargs["aspect"] = (8.27 / ncols) / (kwargs["height"] / nrows) - return getattr(sns, how["func"])( - data=self.data[how["data"]], - x=how["args"][0], - y=how["args"][1], - **kwargs, - ) - - return sns_wrapper - - -# fpath = "/home/alan/Documents/dev/skeletons/scripts/aggregates_exploration/18616_2020_02_20_protAgg_downUpShift_2_0_2_Ura8_Ura8HA_Ura8HR_01" -# # compiler = ExperimentCompiler(None, base_dir / dir) -# compiler = ExperimentCompiler(None, fpath) -# dfs = compiler.run() -# rep = Reporter(data=dfs, path=Path(fpath) / "report.pdf") -# rep.plot_report("./report.pdf") -# base_dir = Path("/home/alan/Documents/dev/skeletons/scripts/data/") -# for dir in dirs: -# try: -# compiler = ExperimentCompiler(None, base_dir / dir) -# dfs = compiler.run() -# rep = Reporter(data=dfs, path=base_dir / (dir + "/report.pdf")) -# from time import time - -# rep.plot_report(base_dir / (dir + "/report.pdf")) -# except Exception as e: -# print("LOG:ERROR:", e) -# with open("errors.log", "a") as f: -# f.write(e) diff --git a/src/postprocessor/routines/__init__.py b/src/postprocessor/routines/__init__.py deleted file mode 100644 index 298e3378860c408b896f3b1ec70273338053fe26..0000000000000000000000000000000000000000 --- a/src/postprocessor/routines/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Routines for analysing post-processed data that don't follow the parameters-processes structure. - -Routines for analysing post-processed data that don't follow the -parameters-processes structure. - -Currently, these consist of plotting routines. There is one module for each -plotting routine. Each module consists of two components and is structured as -follows: -1. An internal class. - The class defines the parameters and defines additional class attributes to - help with plotting. The class also has one method (`plot`) that takes a - `matplotlib.Axes` object as an argument. This method draws the plot on the - `Axes` object. -2. A plotting function. - The user accesses this function. This function defines the default - parameters in its arguments. Within the function, a 'plotter' object is - defined using the internal class and then the function draws the plot on a - specified `matplotlib.Axes` object. - -This structure follows that of plotting routines in `seaborn` -(https://github.com/mwaskom/seaborn), a Python visualisation library that is -based on `matplotlib`. -""" diff --git a/src/postprocessor/routines/boxplot.py b/src/postprocessor/routines/boxplot.py deleted file mode 100644 index ea5e9ba17bbb66bd2c2f403a3a5392a4c8813d6b..0000000000000000000000000000000000000000 --- a/src/postprocessor/routines/boxplot.py +++ /dev/null @@ -1,113 +0,0 @@ -#!/usr/bin/env python3 - -import matplotlib.pyplot as plt -import matplotlib.ticker as ticker -import seaborn as sns - -from postprocessor.routines.plottingabc import BasePlotter - - -class _BoxplotPlotter(BasePlotter): - """Draw boxplots over time""" - - def __init__( - self, - trace_df, - trace_name, - unit_scaling, - box_color, - xtick_step, - xlabel, - plot_title, - ): - super().__init__(trace_name, unit_scaling, xlabel, plot_title) - # Define attributes from arguments - self.trace_df = trace_df - self.box_color = box_color - self.xtick_step = xtick_step - - # Define some labels - self.ylabel = "Normalised " + self.trace_name + " fluorescence (AU)" - - # Define horizontal axis ticks and labels - # hacky! -- redefine column names - trace_df.columns = trace_df.columns * self.unit_scaling - self.fmt = ticker.FuncFormatter( - lambda x, pos: "{0:g}".format( - x / (self.xtick_step / self.unit_scaling) - ) - ) - - def plot(self, ax): - """Draw the heatmap on the provided Axes.""" - super().plot(ax) - ax.xaxis.set_major_formatter(self.fmt) - sns.boxplot( - data=self.trace_df, - color=self.box_color, - linewidth=1, - ax=ax, - ) - ax.xaxis.set_major_locator( - ticker.MultipleLocator(self.xtick_step / self.unit_scaling) - ) - - -def boxplot( - trace_df, - trace_name, - unit_scaling=1, - box_color="b", - xtick_step=60, - xlabel="Time (min)", - plot_title="", - ax=None, -): - """Draw series of boxplots from an array of time series of traces - - Draw series of boxplots from an array of time series of traces, showing the - distribution of values at each time point over time. - - Parameters - ---------- - trace_df : pandas.DataFrame - Time series of traces (rows = cells, columns = time points). - trace_name : string - Name of trace being plotted, e.g. 'flavin'. - unit_scaling : int or float - Unit scaling factor, e.g. 1/60 to convert minutes to hours. - box_color : string - matplolib colour string, specifies colour of boxes in boxplot - xtick_step : int or float - Interval length, in unit time, to draw x axis ticks. - xlabel : string - x axis label. - plot_title : string - Plot title. - ax : matplotlib Axes - Axes in which to draw the plot, otherwise use the currently active Axes. - - Returns - ------- - ax : matplotlib Axes - Axes object with the heatmap. - - Examples - -------- - FIXME: Add docs. - - """ - - plotter = _BoxplotPlotter( - trace_df, - trace_name, - unit_scaling, - box_color, - xtick_step, - xlabel, - plot_title, - ) - if ax is None: - ax = plt.gca() - plotter.plot(ax) - return ax diff --git a/src/postprocessor/routines/heatmap.py b/src/postprocessor/routines/heatmap.py deleted file mode 100644 index 3dd1af967a32c5572a1ae2742c9a3952284738a7..0000000000000000000000000000000000000000 --- a/src/postprocessor/routines/heatmap.py +++ /dev/null @@ -1,183 +0,0 @@ -#!/usr/bin/env python3 - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib import cm, ticker - -from postprocessor.core.processes.standardscaler import standardscaler -from postprocessor.routines.plottingabc import BasePlotter - - -class _HeatmapPlotter(BasePlotter): - """Draw heatmap""" - - def __init__( - self, - trace_df, - trace_name, - buddings_df, - cmap, - unit_scaling, - xtick_step, - scale, - robust, - xlabel, - ylabel, - cbarlabel, - plot_title, - ): - super().__init__(trace_name, unit_scaling, xlabel, plot_title) - # Define attributes from arguments - self.trace_df = trace_df - self.buddings_df = buddings_df - self.cmap = cmap - self.xtick_step = xtick_step - self.scale = scale - self.robust = robust - - # Define some labels - self.cbarlabel = cbarlabel - self.ylabel = ylabel - - # Scale - if self.scale: - self.trace_scaled = standardscaler.as_function(self.trace_df) - else: - self.trace_scaled = self.trace_df - - # If robust, redefine colormap scale to remove outliers - if self.robust: - self.vmin = np.nanpercentile(self.trace_scaled, 2) - self.vmax = np.nanpercentile(self.trace_scaled, 98) - # Make axes even - if self.scale: - if np.abs(self.vmin) > np.abs(self.vmax): - self.vmax = -self.vmin - else: - self.vmin = -self.vmax - else: - self.vmin = None - self.vmax = None - - # Define horizontal axis ticks and labels - # hacky! -- redefine column names - trace_df.columns = trace_df.columns * self.unit_scaling - self.fmt = ticker.FuncFormatter( - lambda x, pos: "{0:g}".format(x * self.unit_scaling) - ) - - def plot(self, ax, cax): - """Draw the heatmap on the provided Axes.""" - super().plot(ax) - ax.xaxis.set_major_formatter(self.fmt) - # Draw trace heatmap - trace_heatmap = ax.imshow( - self.trace_scaled, - cmap=self.cmap, - interpolation="none", - vmin=self.vmin, - vmax=self.vmax, - ) - # Horizontal axis labels as multiples of xtick_step, taking - # into account unit scaling - ax.xaxis.set_major_locator( - ticker.MultipleLocator(self.xtick_step / self.unit_scaling) - ) - # Overlay buddings, if present - if self.buddings_df is not None: - # Must be masked array for transparency - buddings_array = self.buddings_df.to_numpy() - buddings_heatmap_mask = np.ma.masked_where( - buddings_array == 0, buddings_array - ) - # Overlay - ax.imshow( - buddings_heatmap_mask, - interpolation="none", - ) - # Draw colour bar - ax.figure.colorbar( - mappable=trace_heatmap, cax=cax, ax=ax, label=self.cbarlabel - ) - - -def heatmap( - trace_df, - trace_name, - buddings_df=None, - cmap=cm.RdBu, - unit_scaling=1, - xtick_step=60, - scale=True, - robust=True, - xlabel="Time (min)", - ylabel="Cell", - cbarlabel="Normalised fluorescence (AU)", - plot_title="", - ax=None, - cbar_ax=None, -): - """Draw heatmap from an array of time series of traces - - Parameters - ---------- - trace_df : pandas.DataFrame - Time series of traces (rows = cells, columns = time points). - trace_name : string - Name of trace being plotted, e.g. 'flavin'. - buddings_df : pandas.DataFrame - Birth mask (rows = cells, columns = time points). Elements should be - 0 or 1. - cmap : matplotlib ColorMap - Colour map for heatmap. - unit_scaling : int or float - Unit scaling factor, e.g. 1/60 to convert minutes to hours. - xtick_step : int or float - Interval length, in unit time, to draw x axis ticks. - scale : bool - Whether to use standard scaler to scale the trace time series. - robust : bool - If True, the colour map range is computed with robust quantiles instead - of the extreme values. - xlabel : string - x axis label. - ylabel : string - y axis label. - cbarlabel : string - Colour bar label. - plot_title : string - Plot title. - ax : matplotlib Axes - Axes in which to draw the plot, otherwise use the currently active Axes. - cbar_ax : matplotlib Axes - Axes in which to draw the colour bar, otherwise take space from the main - Axes. - - Returns - ------- - ax : matplotlib Axes - Axes object with the heatmap. - - Examples - -------- - FIXME: Add docs. - - """ - plotter = _HeatmapPlotter( - trace_df, - trace_name, - buddings_df, - cmap, - unit_scaling, - xtick_step, - scale, - robust, - xlabel, - ylabel, - cbarlabel, - plot_title, - ) - if ax is None: - ax = plt.gca() - plotter.plot(ax, cbar_ax) - return ax diff --git a/src/postprocessor/routines/histogram.py b/src/postprocessor/routines/histogram.py deleted file mode 100644 index eff7f8d720b707c694ec655da515ab53f3e0d539..0000000000000000000000000000000000000000 --- a/src/postprocessor/routines/histogram.py +++ /dev/null @@ -1,137 +0,0 @@ -#!/usr/bin/env python3 - -import matplotlib.pyplot as plt -import numpy as np - - -class _HistogramPlotter: - """Draw histogram""" - - def __init__( - self, - values, - label, - color, - binsize, - lognormal, - lognormal_base, - xlabel, - ylabel, - plot_title, - ): - # Define attributes from arguments - self.values = values - self.label = label - self.color = color - self.binsize = binsize - self.lognormal = lognormal - self.lognormal_base = lognormal_base - self.xlabel = xlabel - self.ylabel = ylabel - self.plot_title = plot_title - - # Define median - self.median = np.median(self.values) - - # Define bins - if self.lognormal: - self.bins = np.logspace( - 0, - np.ceil( - np.log(np.nanmax(values)) / np.log(self.lognormal_base) - ), - base=self.lognormal_base, - ) # number of bins will be 50 by default, as it's the default in np.logspace - else: - self.bins = np.arange( - self.binsize * (np.nanmin(values) // self.binsize - 2), - self.binsize * (np.nanmax(values) // self.binsize + 2), - self.binsize, - ) - - def plot(self, ax): - """Plot histogram onto specified Axes.""" - ax.set_ylabel(self.ylabel) - ax.set_xlabel(self.xlabel) - ax.set_title(self.plot_title) - - if self.lognormal: - ax.set_xscale("log") - ax.hist( - self.values, - self.bins, - alpha=0.5, - color=self.color, - label=self.label, - ) - ax.axvline( - self.median, - color=self.color, - alpha=0.75, - label="median " + self.label, - ) - ax.legend(loc="upper right") - - -def histogram( - values, - label, - color="b", - binsize=5, - lognormal=False, - lognormal_base=10, - xlabel="Time (min)", - ylabel="Number of occurences", - plot_title="Distribution", - ax=None, -): - """Draw histogram with median indicated - - Parameters - ---------- - values : array_like - Input values for histogram - label : string - Name of value being plotting, e.g. cell division cycle length. - color : string - Colour of bars. - binsize : float - Bin size. - lognormal : bool - Whether to use a log scale for the horizontal axis. - lognormal_base : float - Base of the log scale, if lognormal is True. - xlabel : string - x axis label. - ylabel : string - y axis label. - plot_title : string - Plot title. - ax : matplotlib Axes - Axes in which to draw the plot, otherwise use the currently active Axes. - - Returns - ------- - ax : matplotlib Axes - Axes object with the histogram. - - Examples - -------- - FIXME: Add docs. - - """ - plotter = _HistogramPlotter( - values, - label, - color, - binsize, - lognormal, - lognormal_base, - xlabel, - ylabel, - plot_title, - ) - if ax is None: - ax = plt.gca() - plotter.plot(ax) - return ax diff --git a/src/postprocessor/routines/mean_plot.py b/src/postprocessor/routines/mean_plot.py deleted file mode 100644 index 5a704065eda96c08835c77c563dc4425cd475c52..0000000000000000000000000000000000000000 --- a/src/postprocessor/routines/mean_plot.py +++ /dev/null @@ -1,132 +0,0 @@ -#!/usr/bin/env python3 - -import matplotlib.pyplot as plt -import numpy as np - -from postprocessor.routines.plottingabc import BasePlotter - - -class _MeanPlotter(BasePlotter): - """Draw mean time series plus standard error.""" - - def __init__( - self, - trace_df, - trace_name, - unit_scaling, - label, - mean_color, - error_color, - mean_linestyle, - mean_marker, - xlabel, - ylabel, - plot_title, - ): - super().__init__(trace_name, unit_scaling, xlabel, plot_title) - # Define attributes from arguments - self.trace_df = trace_df - self.label = label - self.mean_color = mean_color - self.error_color = error_color - self.mean_linestyle = mean_linestyle - self.mean_marker = mean_marker - - # Define some labels - self.ylabel = ylabel - - # Mean and standard error - self.trace_time = ( - np.array(self.trace_df.columns, dtype=float) * self.unit_scaling - ) - self.mean_ts = self.trace_df.mean(axis=0) - self.stderr = self.trace_df.std(axis=0) / np.sqrt(len(self.trace_df)) - - def plot(self, ax): - """Draw lines and shading on provided Axes.""" - super().plot(ax) - ax.plot( - self.trace_time, - self.mean_ts, - color=self.mean_color, - alpha=0.75, - linestyle=self.mean_linestyle, - marker=self.mean_marker, - label="Mean, " + self.label, - ) - ax.fill_between( - self.trace_time, - self.mean_ts - self.stderr, - self.mean_ts + self.stderr, - color=self.error_color, - alpha=0.5, - label="Standard error, " + self.label, - ) - ax.legend(loc="upper right") - - -def mean_plot( - trace_df, - trace_name="flavin", - unit_scaling=1, - label="wild type", - mean_color="b", - error_color="lightblue", - mean_linestyle="-", - mean_marker="", - xlabel="Time (min)", - ylabel="Normalised flavin fluorescence (AU)", - plot_title="", - ax=None, -): - """Plot mean time series of a DataFrame, with standard error shading. - - Parameters - ---------- - trace_df : pandas.DataFrame - Time series of traces (rows = cells, columns = time points). - trace_name : string - Name of trace being plotted, e.g. 'flavin'. - unit_scaling : int or float - Unit scaling factor, e.g. 1/60 to convert minutes to hours. - label : string - Name of group being plotted, e.g. a strain name. - mean_color : string - matplotlib colour string for the mean trace. - error_color : string - matplotlib colour string for the standard error shading. - mean_linestyle : string - matplotlib linestyle argument for the mean trace. - mean_marker : string - matplotlib marker argument for the mean trace. - xlabel : string - x axis label. - ylabel : string - y axis label. - plot_title : string - Plot title. - ax : matplotlib Axes - Axes in which to draw the plot, otherwise use the currently active Axes. - - Examples - -------- - FIXME: Add docs. - - """ - plotter = _MeanPlotter( - trace_df, - trace_name, - unit_scaling, - label, - mean_color, - error_color, - mean_linestyle, - mean_marker, - xlabel, - ylabel, - plot_title, - ) - if ax is None: - ax = plt.gca() - plotter.plot(ax) - return ax diff --git a/src/postprocessor/routines/median_plot.py b/src/postprocessor/routines/median_plot.py deleted file mode 100644 index b8cda943d490ec2a45753d005d714e34e43565dd..0000000000000000000000000000000000000000 --- a/src/postprocessor/routines/median_plot.py +++ /dev/null @@ -1,133 +0,0 @@ -#!/usr/bin/env python3 - -import matplotlib.pyplot as plt -import numpy as np - -from postprocessor.routines.plottingabc import BasePlotter - - -class _MedianPlotter(BasePlotter): - """Draw median time series plus interquartile range.""" - - def __init__( - self, - trace_df, - trace_name, - unit_scaling, - label, - median_color, - error_color, - median_linestyle, - median_marker, - xlabel, - ylabel, - plot_title, - ): - super().__init__(trace_name, unit_scaling, xlabel, plot_title) - # Define attributes from arguments - self.trace_df = trace_df - self.label = label - self.median_color = median_color - self.error_color = error_color - self.median_linestyle = median_linestyle - self.median_marker = median_marker - - # Define some labels - self.ylabel = ylabel - - # Median and interquartile range - self.trace_time = ( - np.array(self.trace_df.columns, dtype=float) * self.unit_scaling - ) - self.median_ts = self.trace_df.median(axis=0) - self.quartile1_ts = self.trace_df.quantile(0.25) - self.quartile3_ts = self.trace_df.quantile(0.75) - - def plot(self, ax): - """Draw lines and shading on provided Axes.""" - super().plot(ax) - ax.plot( - self.trace_time, - self.median_ts, - color=self.median_color, - alpha=0.75, - linestyle=self.median_linestyle, - marker=self.median_marker, - label="Median, " + self.label, - ) - ax.fill_between( - self.trace_time, - self.quartile1_ts, - self.quartile3_ts, - color=self.error_color, - alpha=0.5, - label="Interquartile range, " + self.label, - ) - ax.legend(loc="upper right") - - -def median_plot( - trace_df, - trace_name="flavin", - unit_scaling=1, - label="wild type", - median_color="b", - error_color="lightblue", - median_linestyle="-", - median_marker="", - xlabel="Time (min)", - ylabel="Normalised flavin fluorescence (AU)", - plot_title="", - ax=None, -): - """Plot median time series of a DataFrame, with interquartile range - shading. - - Parameters - ---------- - trace_df : pandas.DataFrame - Time series of traces (rows = cells, columns = time points). - trace_name : string - Name of trace being plotted, e.g. 'flavin'. - unit_scaling : int or float - Unit scaling factor, e.g. 1/60 to convert minutes to hours. - label : string - Name of group being plotted, e.g. a strain name. - median_color : string - matplotlib colour string for the median trace. - error_color : string - matplotlib colour string for the interquartile range shading. - median_linestyle : string - matplotlib linestyle argument for the median trace. - median_marker : string - matplotlib marker argument for the median trace. - xlabel : string - x axis label. - ylabel : string - y axis label. - plot_title : string - Plot title. - ax : matplotlib Axes - Axes in which to draw the plot, otherwise use the currently active Axes. - - Examples - -------- - FIXME: Add docs. - """ - plotter = _MedianPlotter( - trace_df, - trace_name, - unit_scaling, - label, - median_color, - error_color, - median_linestyle, - median_marker, - xlabel, - ylabel, - plot_title, - ) - if ax is None: - ax = plt.gca() - plotter.plot(ax) - return ax diff --git a/src/postprocessor/routines/plot_utils.py b/src/postprocessor/routines/plot_utils.py deleted file mode 100644 index cb33009995d1235805c6780b6db1f6183bb5eb87..0000000000000000000000000000000000000000 --- a/src/postprocessor/routines/plot_utils.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 - -import numpy as np -from matplotlib import cm, colors - - -def generate_palette_map(df): - """Create a palette map based on the strains in a dataframe""" - strain_list = np.unique(df.index.get_level_values("strain")) - palette_cm = cm.get_cmap("Set1", len(strain_list) + 1) - palette_rgb = [ - colors.rgb2hex(palette_cm(index / len(strain_list))[:3]) - for index, _ in enumerate(strain_list) - ] - palette_map = dict(zip(strain_list, palette_rgb)) - return palette_map diff --git a/src/postprocessor/routines/plottingabc.py b/src/postprocessor/routines/plottingabc.py deleted file mode 100644 index 97b89aa7e95e01e5b937003233f8ba3d37da7f52..0000000000000000000000000000000000000000 --- a/src/postprocessor/routines/plottingabc.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python3 - -from abc import ABC - - -class BasePlotter(ABC): - """Base class for plotting handler classes""" - - def __init__(self, trace_name, unit_scaling, xlabel, plot_title): - """Common attributes""" - self.trace_name = trace_name - self.unit_scaling = unit_scaling - - self.xlabel = xlabel - self.ylabel = None - self.plot_title = plot_title - - def plot(self, ax): - """Template for drawing on provided Axes""" - ax.set_ylabel(self.ylabel) - ax.set_xlabel(self.xlabel) - ax.set_title(self.plot_title) - # Derived classes extends this with plotting functions - - -# TODO: something about the plotting functions at the end of the modules. -# Decorator? diff --git a/src/postprocessor/routines/single_birth_plot.py b/src/postprocessor/routines/single_birth_plot.py deleted file mode 100644 index 362cda0583336b15ec6a07532c78e4d023c469c4..0000000000000000000000000000000000000000 --- a/src/postprocessor/routines/single_birth_plot.py +++ /dev/null @@ -1,142 +0,0 @@ -#!/usr/bin/env python3 - -import matplotlib.pyplot as plt - -from postprocessor.routines.single_plot import _SinglePlotter - - -class _SingleBirthPlotter(_SinglePlotter): - """Draw a line plot of a single time series, but with buddings overlaid""" - - def __init__( - self, - trace_timepoints, - trace_values, - trace_name, - birth_mask, - unit_scaling, - trace_color, - birth_color, - trace_linestyle, - birth_linestyle, - xlabel, - ylabel, - birth_label, - plot_title, - ): - # Define attributes from arguments - super().__init__( - trace_timepoints, - trace_values, - trace_name, - unit_scaling, - trace_color, - trace_linestyle, - xlabel, - ylabel, - plot_title, - ) - # Add some more attributes useful for buddings - self.birth_mask = birth_mask - self.birth_color = birth_color - self.birth_linestyle = birth_linestyle - self.birth_label = birth_label - - def plot(self, ax): - """Draw the line plots on the provided Axes.""" - trace_time = self.trace_timepoints * self.unit_scaling - super().plot(ax) - birth_mask_bool = self.birth_mask.astype(bool) - for occurence, birth_time in enumerate(trace_time[birth_mask_bool]): - if occurence == 0: - label = self.birth_label - else: - label = None - ax.axvline( - birth_time, - color=self.birth_color, - linestyle=self.birth_linestyle, - label=label, - ) - ax.legend() - - -def single_birth_plot( - trace_timepoints, - trace_values, - trace_name="flavin", - birth_mask=None, - unit_scaling=1, - trace_color="b", - birth_color="k", - trace_linestyle="-", - birth_linestyle="--", - xlabel="Time (min)", - ylabel="Normalised flavin fluorescence (AU)", - birth_label="budding event", - plot_title="", - ax=None, -): - """Plot time series of trace, overlaid with buddings - - Parameters - ---------- - trace_timepoints : array_like - Time points (as opposed to the actual times in time units) - trace_values : array_like - Trace to plot - trace_name : string - Name of trace being plotted, e.g. 'flavin'. - birth_mask : array_like - Mask to indicate where buddings are. Expect values of '0' and '1' or - 'False' and 'True' in the elements. - unit_scaling : int or float - Unit scaling factor, e.g. 1/60 to convert minutes to hours. - trace_color : string - matplotlib colour string for the trace - birth_color : string - matplotlib colour string for the vertical lines indicating buddings - trace_linestyle : string - matplotlib linestyle argument for the trace - birth_linestyle : string - matplotlib linestyle argument for the vertical lines indicating buddings - xlabel : string - x axis label. - ylabel : string - y axis label. - birth_label : string - label for budding event, 'budding event' by default. - plot_title : string - Plot title. - ax : matplotlib Axes - Axes in which to draw the plot, otherwise use the currently active Axes. - - Returns - ------- - ax : matplotlib Axes - Axes object with the plot. - - Examples - -------- - FIXME: Add docs. - - """ - plotter = _SingleBirthPlotter( - trace_timepoints, - trace_values, - trace_name, - birth_mask, - unit_scaling, - trace_color, - birth_color, - trace_linestyle, - birth_linestyle, - xlabel, - ylabel, - birth_label, - plot_title, - ) - if ax is None: - ax = plt.gca() - plotter.plot(ax) - return ax diff --git a/src/postprocessor/routines/single_plot.py b/src/postprocessor/routines/single_plot.py deleted file mode 100644 index 111fb4bc6ea3f33392d0e967b316df8384808300..0000000000000000000000000000000000000000 --- a/src/postprocessor/routines/single_plot.py +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env python3 - -import matplotlib.pyplot as plt - -from postprocessor.routines.plottingabc import BasePlotter - - -class _SinglePlotter(BasePlotter): - """Draw a line plot of a single time series.""" - - def __init__( - self, - trace_timepoints, - trace_values, - trace_name, - unit_scaling, - trace_color, - trace_linestyle, - xlabel, - ylabel, - plot_title, - ): - super().__init__(trace_name, unit_scaling, xlabel, plot_title) - # Define attributes from arguments - self.trace_timepoints = trace_timepoints - self.trace_values = trace_values - self.trace_color = trace_color - self.trace_linestyle = trace_linestyle - - # Define some labels - self.ylabel = ylabel - - def plot(self, ax): - """Draw the line plot on the provided Axes.""" - super().plot(ax) - ax.plot( - self.trace_timepoints * self.unit_scaling, - self.trace_values, - color=self.trace_color, - linestyle=self.trace_linestyle, - label=self.trace_name, - ) - - -def single_plot( - trace_timepoints, - trace_values, - trace_name="flavin", - unit_scaling=1, - trace_color="b", - trace_linestyle="-", - xlabel="Time (min)", - ylabel="Normalised flavin fluorescence (AU)", - plot_title="", - ax=None, -): - """Plot time series of trace. - - Parameters - ---------- - trace_timepoints : array_like - Time points (as opposed to the actual times in time units). - trace_values : array_like - Trace to plot. - trace_name : string - Name of trace being plotted, e.g. 'flavin'. - unit_scaling : int or float - Unit scaling factor, e.g. 1/60 to convert minutes to hours. - trace_color : string - matplotlib colour string, specifies colour of line plot. - trace_linestyle : string - matplotlib linestyle argument. - xlabel : string - x axis label. - ylabel : string - y axis label. - plot_title : string - Plot title. - ax : matplotlib Axes - Axes in which to draw the plot, otherwise use the currently active Axes. - - Returns - ------- - ax : matplotlib Axes - Axes object with the plot. - - Examples - -------- - FIXME: Add docs. - - """ - plotter = _SinglePlotter( - trace_timepoints, - trace_values, - trace_name, - unit_scaling, - trace_color, - trace_linestyle, - xlabel, - ylabel, - plot_title, - ) - if ax is None: - ax = plt.gca() - plotter.plot(ax) - return ax