diff --git a/examples/create_tsv.py b/examples/create_tsv.py deleted file mode 100644 index 74a1d5198aedb53db28957b0dffd19b5e8d09193..0000000000000000000000000000000000000000 --- a/examples/create_tsv.py +++ /dev/null @@ -1,10 +0,0 @@ -from wela.dataloader import dataloader - -datasets = ["Pdr5_3_11_22", "Pdr5_flc_10ugml_1676"] -for dataname in datasets: - dl = dataloader( - h5dir="/Users/pswain/ecdf/swainlab/aliby_datasets/ivan", - wdir="/Users/pswain/wip/tsv_data", - ) - dl.load(dataname, key_index="median_GFP") - dl.save() diff --git a/examples/run_imageviewer.py b/examples/run_imageviewer.py deleted file mode 100644 index 40ccbdcf2561d3cd63968298acae7078bac3ea8d..0000000000000000000000000000000000000000 --- a/examples/run_imageviewer.py +++ /dev/null @@ -1,27 +0,0 @@ -from wela.imageviewer import ImageViewer - -aliby_output = "/Users/pswain/wip/aliby_output/" -omero_name = "2104_2024_03_29_Aggregation_to_0pc_glc_sorb_00" -position = "ch11_Gcd6_001" - -# ADD details -server_info = { - "host": "", - "username": "", - "password": "", -} - -h5file = f"{aliby_output}{omero_name}/{position}.h5" -iv = ImageViewer.remote(h5file, server_info, 2104) -tpt_end = 10 -no_cells = 4 -# for information only -iv.print_traps_with_cells(tpt_end) -# use Napari to view cells -# python -m pip install "napari[all]" --upgrade -iv.view( - trap_ids=iv.sample_traps_with_cells(tpt_end=tpt_end, no_cells=no_cells), - tpt_end=tpt_end, - channels_to_skip=["cy5"], - no_vertical_tiles=1, -) diff --git a/examples/run_local_imageviewer.py b/examples/run_local_imageviewer.py index 89b84b1e1968ae78c306cb3f8bdebaf93a721a2d..1a27c73d035a49d71574817bf35ffb601ecacdba 100644 --- a/examples/run_local_imageviewer.py +++ b/examples/run_local_imageviewer.py @@ -1,4 +1,4 @@ -from wela.imageviewer import ImageViewer, get_files +from wela.imageviewer import ImageViewer, get_files_zarr aliby_input = "/Users/pswain/wip/aliby_input/" @@ -8,7 +8,7 @@ position = "htb2mCherry_018" # for zarr files iv = ImageViewer.local( - **get_files( + **get_files_zarr( aliby_input, aliby_output, omero_name, diff --git a/examples/run_wela.py b/examples/run_wela.py new file mode 100644 index 0000000000000000000000000000000000000000..9e0c85c061338c138d10c6e8426f6b8334a5d510 --- /dev/null +++ b/examples/run_wela.py @@ -0,0 +1,69 @@ +import sys + +import matplotlib.pylab as plt +import seaborn as sns +from wela.dataloader import dataloader +from wela.imageviewer import ImageViewer, get_h5files +from wela.plotting import kymograph +from wela.sorting import sort_by_budding + +h5dir = "/Users/pswain/wip/aliby_output/ribosomes/" +omids = [ + "1999_2024_01_14_Rpl3_glc_raf_00", + "2003_2024_01_29_2pc_switch_00", + "2008_2024_02_12_Rpl3_glc_raf_switching_00", +] +# FILL IN +server_info = { + "host": "staffa.bio.ed.ac.uk", + "username": "", + "password": "", +} +view = True +# pick the experiment to analyse +omid = 2008 + +# 1. Run with view=True to check visually that aliby has worked correctly. +# 2. Set key_index, the signal you are most interested in. +# 3. Run with view=False to run dataloader and save a tsv file. + +if view: + omero_name = [om for om in omids if str(omid) in om][0] + h5files = get_h5files(h5dir, omero_name) + position = h5files[0] + h5file = f"{h5dir}{omero_name}/{position}" + iv = ImageViewer.remote(h5file, server_info, omid) + tpt_end = 10 + no_cells = 6 + iv.view( + trap_ids=iv.sample_traps_with_cells( + tpt_end=tpt_end, no_cells=no_cells + ), + tpt_end=tpt_end, + channels_to_skip=["cy5"], + no_rows=2, + ) + sys.exit(0) + +# run dataloader +key_index = "median_GFP" +dl = dataloader(h5dir, ".") +expt = [omid_full for omid_full in omids if str(omid) in omid_full][0] +dl.load(expt, key_index=key_index, cutoff=0.9) +dl.save() + +# plot kymographs +groups = dl.df.group.unique() +for group in groups: + _, buddings = dl.get_time_series("buddings", group=group) + sort_order = sort_by_budding(buddings) + kymograph( + dl.df[dl.df.group == group], + hue=key_index, + title=group, + sort_order=sort_order, + ) + +# plot means +sns.relplot(data=dl.df, x="time", y=key_index, kind="line", hue="group") +plt.show() diff --git a/src/wela/dataloader.py b/src/wela/dataloader.py index 2dee3425f344982ab76b4031d0427e557514842a..85a6679c4ad1d3e0a64ff42cf4f31f65e301540b 100644 --- a/src/wela/dataloader.py +++ b/src/wela/dataloader.py @@ -189,7 +189,7 @@ class dataloader: self.a2g_dict = {v: k for (k, v) in self.g2a_dict.items()} def include_bud_fluorescence(self, grouper, dataname): - """Add mean and median bud fluorescence to the h5 files.""" + """Add statistics for bud fluorescence to the h5 files.""" # find fluorescence channels channels = list(grouper.channels) channels.remove("Brightfield") @@ -199,6 +199,8 @@ class dataloader: [ f"/extraction/{channel}/max/median", f"/extraction/{channel}/max/mean", + f"/extraction/{channel}/max/total", + f"/extraction/{channel}/max/total_squared", ] for channel in channels ] @@ -220,6 +222,7 @@ class dataloader: hours=True, bud_fluorescence=False, tmax_in_mins_dict=None, + get_all_signals=False, ): """ Load either an experiment from h5 files or a tsv data set. @@ -261,6 +264,9 @@ class dataloader: values. For example: { "PDR5_GFP_001": 6 * 60}. Data will only be include up to this time point, a way to avoid errors in assigning lineages because of clogging. + get_all_signals: boolean + If True, check all h5 files for signals because some positions + may have different signals from others. Returns ------- @@ -293,10 +299,14 @@ class dataloader: # call postprocessor to add bud fluorescence to h5 files self.include_bud_fluorescence(grouper, dataname) print("Signals available:") - for signal in grouper.available: + if get_all_signals: + signals_available = grouper.all_available + else: + signals_available = grouper.available + for signal in signals_available: print(" ", signal) print() - if self.a2g_dict[key_index] not in grouper.available: + if self.a2g_dict[key_index] not in signals_available: raise Exception(f"The key index {key_index} is unavailable.") # get indices for all buds bud_indices = self.get_bud_indices(grouper, key_index) @@ -311,6 +321,7 @@ class dataloader: # add data for other signals to data for key_index r_df = self.load_h5( grouper=grouper, + signals_available=signals_available, key_index=key_index, bud_indices=bud_indices, r_df=r_df, @@ -338,6 +349,7 @@ class dataloader: def load_h5( self, grouper, + signals_available, key_index, bud_indices, r_df, @@ -349,7 +361,7 @@ class dataloader: print("\nGetting data for other signals...") for i, sigpath in enumerate(self.g2a_dict): if ( - sigpath in grouper.available + sigpath in signals_available and not ("buddings" in sigpath or "bud_metric" in sigpath) and sigpath != self.a2g_dict[key_index] ): diff --git a/src/wela/imageviewer.py b/src/wela/imageviewer.py index 629d69ea4485a3bcf624a9f260549e2e0a750b6f..bed51f6c63b394774a2b62a7330d3078e01c4647 100644 --- a/src/wela/imageviewer.py +++ b/src/wela/imageviewer.py @@ -9,7 +9,10 @@ except ModuleNotFoundError: "Napari cannot be imported.\nRun", ' python -m pip install "napari[all]"', ) +from typing import Any, Dict, List + import numpy as np +import numpy.typing as npt from agora.io.cells import Cells from aliby.io.image import dispatch_image from aliby.io.omero import Dataset @@ -47,7 +50,7 @@ class ImageViewer: print(f" Trouble loading {image_file}.") @classmethod - def remote(cls, h5file: str, server_info: dict, omero_id: int): + def remote(cls, h5file: str, server_info: Dict, omero_id: int): """View images from OMERO.""" iv = cls(h5file) with h5py.File(iv.h5file_path, "r") as f: @@ -69,7 +72,9 @@ class ImageViewer: iv.cells = Cells.from_source(iv.h5file_path) return iv - def get_all_traps_with_cells(self, tpt_end, tpt_start=0, display=True): + def get_all_traps_with_cells( + self, tpt_end: int, tpt_start: int = 0, display: bool = True + ): """List traps with cells.""" cells = self.cells tpts = range(tpt_start, tpt_end) @@ -79,24 +84,31 @@ class ImageViewer: if tpt in cells.nonempty_tp_in_trap(trap_id): traps_with_cells.append(trap_id) break + traps_with_cells = np.unique(traps_with_cells) if display: - print(f"Traps with cells {traps_with_cells}") + print(f"Traps with cells {list(traps_with_cells)}") return traps_with_cells - def sample_traps_with_cells(self, no_cells, tpt_end, tpt_start=0): + def sample_traps_with_cells( + self, no_cells: int, tpt_end: int, tpt_start: int = 0 + ): """Sample some traps that have cells.""" traps_with_cells = self.get_all_traps_with_cells( tpt_end, tpt_start, display=False ) - rng = np.random.default_rng() - samples = rng.integers( - low=0, - high=len(traps_with_cells), - size=np.min([no_cells, len(traps_with_cells)]), + samples = np.random.choice( + traps_with_cells, + size=np.min([no_cells, traps_with_cells.size]), + replace=False, ) return samples - def get_tiles(self, trap_id, tps, channels_to_skip=None, cell_only=True): + def get_tiles( + self, + trap_id: int, + tps: List[int], + channels_to_skip: List[str] = None, + ): """Get dict of tiles with channel indices as keys.""" tiles_dict = {} if channels_to_skip is None: @@ -105,7 +117,7 @@ class ImageViewer: channels = [ ch for ch in self.tiler.channels if ch not in channels_to_skip ] - channel_indices = [channels.index(ch) for ch in channels] + channel_indices = [self.tiler.channels.index(ch) for ch in channels] for ch_index, ch in zip(channel_indices, channels): tile_dict_for_ch = self.get_all_tiles(tps, ch_index) tiles = [x[trap_id] for x in tile_dict_for_ch.values()] @@ -132,7 +144,7 @@ class ImageViewer: tiles_dict[ch] = new_tiles return tiles_dict - def get_outlines(self, trap_id, tps): + def get_outlines(self, trap_id: int, tps: List[int]): """Get uniquely labelled outlines for each cell time point.""" # get outlines for each time point outlines = [ @@ -161,14 +173,14 @@ class ImageViewer: def get_all_tiles( self, - tps, - channel_index, - z=0, + tps: List[int], + channel_index: str, + z: int = 0, ): """ Get dict with time points as keys and all available tiles as values. - We assume only a single channel. + Assume only a single channel. """ z = z or self.tiler.ref_z ch_tps = [(channel_index, tp) for tp in tps] @@ -181,7 +193,9 @@ class ImageViewer: tile_dict = {tp: self.full[(ch, tp)] for ch, tp in ch_tps} return tile_dict - def get_data_for_viewing(self, trap_id, tps, channels_to_skip): + def get_data_for_viewing( + self, trap_id: int, tps: List[int], channels_to_skip: List[str] + ): """ Get images and outlines as multidimensional arrays for Napari. @@ -217,11 +231,11 @@ class ImageViewer: ts_labels[tp_index, 0, ...] = outlines[tp_index] return ts_images, ts_labels, channels - def concat(self, arrangement, image_dict, axis): + def concat(self, arrangement: List[int], image_dict: Dict, axis: int): """ Concat dict of images into one image array. - Following the vertical layout in arrangment. + Follow the vertical layout in arrangment. """ # concatenate vertically into a list images_v = [ @@ -239,23 +253,22 @@ class ImageViewer: images = images_v[0] return images - def combine_tiles(self, ts_images_dict, ts_labels_dict, no_vertical_tiles): - """Combine tiles into one image first vertically then horizontally.""" + def combine_tiles( + self, ts_images_dict: Dict, ts_labels_dict: Dict, no_rows: int + ): + """Combine tiles into one image first into rows then columns.""" no_tiles = len(ts_images_dict) trap_ids = list(ts_images_dict.keys()) # find how tiles will be arranged in the concatenated image - if no_tiles < no_vertical_tiles: + if no_tiles < no_rows: arrangement = [trap_ids] else: arrangement = [ - trap_ids[i : min(i + no_vertical_tiles, no_tiles)] + trap_ids[i : min(i + no_rows, no_tiles)] for i in range( 0, - int( - np.floor(no_tiles / no_vertical_tiles) - * no_vertical_tiles - ), - no_vertical_tiles, + int(np.floor(no_tiles / no_rows) * no_rows), + no_rows, ) ] if no_tiles > np.array(arrangement).size: @@ -271,26 +284,30 @@ class ImageViewer: def view( self, - trap_ids, - tpt_end=10, - tpt_start=0, - channels_to_skip=None, - no_vertical_tiles=3, + trap_ids: List[int], + tpt_start: int = 0, + tpt_end: int = 10, + channels_to_skip: List[str] = None, + no_rows: int = 2, ): """ - Use Napari to view all channels and outlines for a particular trap. + Use Napari to view all channels and outlines for particular traps. Fluorescence channels will not be immediately visible. + Concatenating traps into one image can become slow for multiple traps. Parameters ---------- - trap_id: int - The trap to be viewed. - tps: int or array of ints - Either the last time point to be viewed or a rage of time points - to view. - If None, all time points will be viewed, but gathering the images - will be slow. + trap_ids: list of int + The traps to be viewed. + tpt_start: int + The index for the initial time point to view. + tpt_end: int + The index for the final time point. + channels_to_skip: list of str + Channels to ignore, such as "cy5". + no_rows: int + The number of rows of traps in the final concatenated image. """ tps = np.arange(tpt_start, tpt_end + 1) if isinstance(trap_ids, int): @@ -302,13 +319,18 @@ class ImageViewer: ) # combine tiles ts_images, ts_labels = self.combine_tiles( - ts_images_dict, ts_labels_dict, no_vertical_tiles + ts_images_dict, ts_labels_dict, no_rows ) # launch napari self.launch_napari(ts_images, ts_labels, channels) - def launch_napari(self, ts_images, ts_labels, channels): - """Use Napari to see the images and outlines.""" + def launch_napari( + self, + ts_images: npt.NDArray[Any], + ts_labels: npt.NDArray[Any], + channels: List[str], + ): + """Call Napari viewer.""" viewer = napari.Viewer() viewer.add_image( ts_images[:, channels.index("Brightfield"), ...], @@ -328,7 +350,9 @@ class ImageViewer: #### -def colormap(channel): + + +def colormap(channel: str): """Find default colormap.""" if "GFP" in channel: colormap = "green" @@ -339,7 +363,7 @@ def colormap(channel): return colormap -def get_files( +def get_files_zarr( aliby_input: str, aliby_output: str, omero_name: str, @@ -351,3 +375,12 @@ def get_files( image_file_name = h5file.split("/")[-1].split(".")[0] + ".zarr" image_file = str(Path(aliby_input) / omero_name / image_file_name) return {"h5file": h5file, "image_file": image_file} + + +def get_h5files(aliby_output: str, omero_name: str, print=False): + """List all positions for a particular experiment.""" + h5files = [f.name for f in (Path(aliby_output) / omero_name).glob("*.h5")] + if print: + for file in h5files: + print(f"\t{file}") + return h5files diff --git a/src/wela/plotting.py b/src/wela/plotting.py index 89fb108b006c9c55b8f9ba542d7df492164bb254..64d04c8f9f6f39a68e0eec796e5d092820a71ebd 100644 --- a/src/wela/plotting.py +++ b/src/wela/plotting.py @@ -82,7 +82,7 @@ def kymograph( >>> kymograph(dl.df, hue="buddings", group="Msn2") >>> >>> from wela.sorting import sort_by_budding - >>> _, buddings = dl.get_time_series(buddings) + >>> _, buddings = dl.get_time_series("buddings") >>> sort_order = sort_by_budding(buddings) >>> kymograph(dl.df, hue="flavin", sort_order=sort_order) """