Skip to content
Snippets Groups Projects
Commit 567d31e1 authored by Alán Muñoz's avatar Alán Muñoz
Browse files

add mother_assign_dynamic

parent 2cf2e600
No related branches found
No related tags found
No related merge requests found
......@@ -24,7 +24,7 @@ from core.utils import Cache, accumulate, get_store_path
################### Dask Methods ################################
def format_segmentation(segmentation, tp):
""" Format a single timepoint into a dictionary.
"""Format a single timepoint into a dictionary.
Parameters
------------
......@@ -39,91 +39,102 @@ def format_segmentation(segmentation, tp):
"""
# Segmentation is a list of dictionaries, ordered by trap
# Add trap information
mother_assign = None
# mother_assign = None
for i, x in enumerate(segmentation):
x['trap'] = [i] * len(x['cell_label'])
x["trap"] = [i] * len(x["cell_label"])
# Merge into a dictionary of lists, by column
merged = {k: list(itertools.chain.from_iterable(
res[k] for res in segmentation))
for k in segmentation[0].keys()}
merged = {
k: list(itertools.chain.from_iterable(res[k] for res in segmentation))
for k in segmentation[0].keys()
}
# Special case for mother_assign
if 'mother_assign' in merged:
del merged['mother_assign']
mother_assign = [x['mother_assign'] for x in segmentation]
merged["mother_assign_dynamic"] = merged["mother_assign"]
if "mother_assign" in merged:
del merged["mother_assign"]
mother_assign = [x["mother_assign"] for x in segmentation]
# Check that the lists are all of the same length (in case of errors in
# BABY)
n_cells = min([len(v) for v in merged.values()])
merged = {k: v[:n_cells] for k, v in merged.items()}
merged['timepoint'] = [tp] * n_cells
merged['mother_assign'] = mother_assign
merged["timepoint"] = [tp] * n_cells
merged["mother_assign"] = mother_assign
return merged
def choose_model_from_params(modelset_filter=None, camera='prime95b', channel='brightfield',
zoom='60x', n_stacks='5z', **kwargs):
def choose_model_from_params(
modelset_filter=None,
camera="prime95b",
channel="brightfield",
zoom="60x",
n_stacks="5z",
**kwargs,
):
"""
Define which model to query from the server based on a set of parameters.
Parameters
----------
valid_models: List[str]
The names of the models that are available.
modelset_filter: str
modelset_filter: str
A regex filter to apply on the models to start.
camera: str
The camera used in the experiment (case insensitive).
channel:str
channel:str
The channel used for segmentation (case insensitive).
zoom: str
zoom: str
The zoom on the channel.
n_stacks: str
The number of z_stacks to use in segmentation
Returns
-------
model_name : str
"""
valid_models = list(modelsets().keys())
# Apply modelset filter if specified
if modelset_filter is not None:
msf_regex = re.compile(modelset_filter)
valid_models = filter(msf_regex.search, valid_models)
# Apply parameter filters if specified
params = [str(x) if x is not None else '.+' for x in [camera.lower(),
channel.lower(),
zoom, n_stacks]]
params_re = re.compile('^' + '_'.join(params) + '$')
params = [
str(x) if x is not None else ".+"
for x in [camera.lower(), channel.lower(), zoom, n_stacks]
]
params_re = re.compile("^" + "_".join(params) + "$")
valid_models = list(filter(params_re.search, valid_models))
# Check that there are valid models
if len(valid_models) == 0:
raise KeyError(
"No model sets found matching {}".format(', '.join(params)))
raise KeyError("No model sets found matching {}".format(", ".join(params)))
# Pick the first model
return valid_models[0]
class DummyRunner:
"""A BabyRunner object for cell segmentation.
Does segmentation one time point at a time."""
def __init__(self, tiler, *args, **kwargs):
self.tiler = tiler
self.model_config = modelsets()[choose_model_from_params(**kwargs)]
self.brain = BabyBrain(**self.model_config)
self.crawler = BabyCrawler(self.brain)
self.bf_channel = self.tiler.get_channel_index('Brightfield')
self.bf_channel = self.tiler.get_channel_index("Brightfield")
def get_data(self, tp):
# Swap axes x and z, probably shouldn't swap, just move z
return self.tiler.get_tp_data(tp, self.bf_channel)\
.swapaxes(1, 3).swapaxes(1, 2)
return self.tiler.get_tp_data(tp, self.bf_channel).swapaxes(1, 3).swapaxes(1, 2)
def run_tp(self, tp, with_edgemasks=True, assign_mothers=True, **kwargs):
""" Simulating processing time with sleep"""
"""Simulating processing time with sleep"""
# Access the image
img = self.get_data(tp)
segmentation = self.crawler.step(img, with_edgemasks=with_edgemasks, assign_mothers=assign_mothers, **kwargs)
segmentation = self.crawler.step(
img, with_edgemasks=with_edgemasks, assign_mothers=assign_mothers, **kwargs
)
return format_segmentation(segmentation, tp)
......@@ -134,9 +145,10 @@ class DummyClient:
Does segmentation one time point at a time.
Should work better with the parallelisation.
"""
bf_channel = 0
model_name = 'prime95b_brightfield_60x_5z'
url = 'http://localhost:5101'
model_name = "prime95b_brightfield_60x_5z"
url = "http://localhost:5101"
max_tries = 50
sleep_time = 0.1
......@@ -147,7 +159,7 @@ class DummyClient:
@property
def session(self):
if self._session is None:
r_session = requests.get(self.url + f'/session/{self.model_name}')
r_session = requests.get(self.url + f"/session/{self.model_name}")
r_session.raise_for_status()
self._session = r_session.json()["sessionid"]
return self._session
......@@ -158,16 +170,19 @@ class DummyClient:
def queue_image(self, img, **kwargs):
bit_depth = img.dtype.itemsize * 8 # bit depth = byte_size * 8
data = create_request(img.shape, bit_depth, img, **kwargs)
status = requests.post(self.url + f'/segment?sessionid={self.session}',
data=data,
headers={'Content-Type': data.content_type})
status = requests.post(
self.url + f"/segment?sessionid={self.session}",
data=data,
headers={"Content-Type": data.content_type},
)
status.raise_for_status()
return status
def get_segmentation(self):
try:
seg_response = requests.get(
self.url + f'/segment?sessionid={self.session}', timeout=120)
self.url + f"/segment?sessionid={self.session}", timeout=120
)
seg_response.raise_for_status()
result = seg_response.json()
except Timeout as e:
......@@ -191,6 +206,7 @@ class DummyClient:
continue
return format_segmentation(seg, tp)
################### Old Methods #################################
# class BabyNoMatches(Exception):
# pass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment