Swain Lab / aliby / alibylite · Commits

Commit 17b9d3e5, authored 3 years ago by Alán Muñoz

    add agora and cells

Parent: c92f6cd9
No related branches, tags or merge requests found.

Changes: 3 changed files, 1131 additions and 0 deletions
    io/cells.py    +318 −0
    tile/tiler.py  +333 −0
    tile/traps.py  +480 −0
io/cells.py (new file, mode 100644, +318 −0)

import logging
from pathlib import Path, PosixPath
from time import perf_counter
from typing import Union
from itertools import groupby
from collections.abc import Iterable

from utils_find_1st import find_1st, cmp_equal
import h5py
import numpy as np

from scipy import ndimage
from scipy.sparse.base import isdense

from agora.io.writer import load_complex


def cell_factory(store, type="hdf5"):
    if type == "hdf5":
        return CellsHDF(store)
    else:
        raise TypeError(
            "Could not get cells for type {}: "
            "valid types are matlab and hdf5".format(type)
        )


class Cells:
    """An object that gathers information about all the cells in a given
    trap.

    This is the abstract object, used for type testing.
    """

    def __init__(self):
        pass

    @staticmethod
    def from_source(source: Union[PosixPath, str], kind: str = None):
        if isinstance(source, str):
            source = Path(source)
        if kind is None:  # Infer kind from filename
            kind = "matlab" if source.suffix == ".mat" else "hdf5"
        return cell_factory(source, kind)

    @staticmethod
    def _asdense(array):
        if not isdense(array):
            array = array.todense()
        return array

    @staticmethod
    def _astype(array, kind):
        # Convert sparse arrays if needed; if kind is 'mask' fill the outline
        array = Cells._asdense(array)
        if kind == "mask":
            array = ndimage.binary_fill_holes(array).astype(int)
        return array

    @classmethod
    def hdf(cls, fpath):
        return CellsHDF(fpath)

    @classmethod
    def mat(cls, path):
        # matObject is expected to come from the matlab-reading helpers;
        # it is not imported in this module.
        return CellsMat(matObject(path))


class CellsHDF(Cells):
    def __init__(self, filename, path="cell_info"):
        self.filename = filename
        self.cinfo_path = path
        self._edgem_indices = None
        self._edgemasks = None
        self._tile_size = None

    def __getitem__(self, item):
        if item == "edgemasks":
            return self.edgemasks
        _item = "_" + item
        if not hasattr(self, _item):
            setattr(self, _item, self._fetch(item))
        return getattr(self, _item)

    def _get_idx(self, cell_id, trap_id):
        return (self["cell_label"] == cell_id) & (self["trap"] == trap_id)

    def _fetch(self, path):
        with h5py.File(self.filename, mode="r") as f:
            return f[self.cinfo_path][path][()]

    @property
    def ntraps(self):
        with h5py.File(self.filename, mode="r") as f:
            return len(f["/trap_info/trap_locations"][()])

    @property
    def traps(self):
        return list(set(self["trap"]))

    @property
    def tile_size(self):  # TODO read from metadata
        if self._tile_size is None:
            with h5py.File(self.filename, mode="r") as f:
                self._tile_size = f["trap_info/tile_size"][0]
        return self._tile_size

    @property
    def edgem_indices(self):
        if self._edgem_indices is None:
            edgem_path = "edgemasks/indices"
            self._edgem_indices = load_complex(self._fetch(edgem_path))
        return self._edgem_indices

    @property
    def edgemasks(self):
        if self._edgemasks is None:
            edgem_path = "edgemasks/values"
            self._edgemasks = self._fetch(edgem_path)
        return self._edgemasks

    def _edgem_where(self, cell_id, trap_id):
        ix = trap_id + 1j * cell_id
        return find_1st(self.edgem_indices == ix, True, cmp_equal)

    @property
    def labels(self):
        """Return all cell labels in the object.

        We use mother_assign to list traps because it is the only property
        that appears even when no cells are found.
        """
        return [self.labels_in_trap(trap) for trap in self.traps]

    def where(self, cell_id, trap_id):
        """
        Parameters
        ----------
        cell_id: int
            Cell index
        trap_id: int
            Trap index

        Returns
        ----------
        indices int array
        boolean mask array
        edge_ix int array
        """
        indices = self._get_idx(cell_id, trap_id)
        edgem_ix = self._edgem_where(cell_id, trap_id)
        return (
            self["timepoint"][indices],
            indices,
            edgem_ix,
        )  # FIXME edgem_ix makes output different to matlab's Cell

    def outline(self, cell_id, trap_id):
        times, indices, cell_ix = self.where(cell_id, trap_id)
        return times, self["edgemasks"][cell_ix, times]

    def mask(self, cell_id, trap_id):
        times, outlines = self.outline(cell_id, trap_id)
        return times, np.array(
            [ndimage.morphology.binary_fill_holes(o) for o in outlines]
        )

    def at_time(self, timepoint, kind="mask"):
        ix = self["timepoint"] == timepoint
        cell_ix = self["cell_label"][ix]
        traps = self["trap"][ix]
        indices = traps + 1j * cell_ix
        choose = np.in1d(self.edgem_indices, indices)
        edgemasks = self["edgemasks"][choose, timepoint]
        masks = [
            self._astype(edgemask, kind)
            for edgemask in edgemasks
            if edgemask.any()
        ]
        return self.group_by_traps(traps, masks)

    def group_by_traps(self, traps, data):
        # returns a dict with traps as keys and labels as value
        iterator = groupby(zip(traps, data), lambda x: x[0])
        d = {key: [x[1] for x in group] for key, group in iterator}
        d = {i: d.get(i, []) for i in self.traps}
        return d

    def labels_in_trap(self, trap_id):
        # Return set of cell ids in a trap.
        return set((self["cell_label"][self["trap"] == trap_id]))

    def labels_at_time(self, timepoint):
        labels = self["cell_label"][self["timepoint"] == timepoint]
        traps = self["trap"][self["timepoint"] == timepoint]
        return self.group_by_traps(traps, labels)
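
# A minimal usage sketch of the HDF5-backed cells. The file name and the import
# path are assumptions for illustration; the HDF5 layout is the one read above
# (a "cell_info" group plus /trap_info/trap_locations).
#
#     from agora.io.cells import Cells
#
#     cells = Cells.from_source("position_001.h5")     # suffix is not ".mat", so the hdf5 backend is used
#     print(cells.ntraps)                              # number of traps recorded in the file
#     times, masks = cells.mask(cell_id=1, trap_id=0)  # time points and filled masks for one cell
#     per_trap = cells.at_time(5)                      # {trap_id: [filled masks]} at timepoint 5
#     labels = cells.labels_at_time(5)                 # {trap_id: [cell labels]} at the same timepoint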


class CellsMat(Cells):
    def __init__(self, mat_object):
        super(CellsMat, self).__init__()
        # TODO add __contains__ to the matObject
        timelapse_traps = mat_object.get(
            "timelapseTrapsOmero", mat_object.get("timelapseTraps", None)
        )
        if timelapse_traps is None:
            raise NotImplementedError(
                "Could not find a timelapseTraps or "
                "timelapseTrapsOmero object. Cells "
                "from cellResults not implemented"
            )
        else:
            self.trap_info = timelapse_traps["cTimepoint"]["trapInfo"]

            if isinstance(self.trap_info, list):
                self.trap_info = {
                    k: list([res.get(k, []) for res in self.trap_info])
                    for k in self.trap_info[0].keys()
                }

    def where(self, cell_id, trap_id):
        times, indices = zip(
            *[
                (tp, np.where(cell_id == x)[0][0])
                for tp, x in enumerate(
                    self.trap_info["cellLabel"][:, trap_id].tolist()
                )
                if np.any(cell_id == x)
            ]
        )
        return times, indices

    def outline(self, cell_id, trap_id):
        times, indices = self.where(cell_id, trap_id)
        info = self.trap_info["cell"][times, trap_id]

        def get_segmented(cell, index):
            if cell["segmented"].ndim == 0:
                return cell["segmented"][()].todense()
            else:
                return cell["segmented"][index].todense()

        segmentation_outline = [
            get_segmented(cell, idx) for idx, cell in zip(indices, info)
        ]
        return times, np.array(segmentation_outline)

    def mask(self, cell_id, trap_id):
        times, outlines = self.outline(cell_id, trap_id)
        return times, np.array(
            [ndimage.morphology.binary_fill_holes(o) for o in outlines]
        )

    def at_time(self, timepoint, kind="outline"):
        """Returns the segmentations for all the cells at a given timepoint.

        FIXME: this is extremely hacky and accounts for differently saved
        results in the matlab object. Deprecate ASAP.
        """
        # Case 1: only one cell per trap: trap_info['cell'][timepoint] is a
        # structured array
        if isinstance(self.trap_info["cell"][timepoint], dict):
            segmentations = [
                self._astype(x, "outline")
                for x in self.trap_info["cell"][timepoint]["segmented"]
            ]
        # Case 2: Multiple cells per trap: it becomes a list of arrays or
        # dictionaries, one for each trap
        # Case 2.1 : it's a dictionary
        elif isinstance(self.trap_info["cell"][timepoint][0], dict):
            segmentations = []
            for x in self.trap_info["cell"][timepoint]:
                seg = x["segmented"]
                if not isinstance(seg, np.ndarray):
                    seg = [seg]
                segmentations.append(
                    [self._astype(y, "outline") for y in seg]
                )
        # Case 2.2 : it's an array
        else:
            segmentations = [
                [self._astype(y, kind) for y in x["segmented"]] if x.ndim != 0 else []
                for x in self.trap_info["cell"][timepoint]
            ]
        # Return dict for compatibility with hdf5 output
        return {i: v for i, v in enumerate(segmentations)}

    def labels_at_time(self, tp):
        labels = self.trap_info["cellLabel"]
        labels = [_aslist(x) for x in labels[tp]]
        labels = {i: [lbl for lbl in lblset] for i, lblset in enumerate(labels)}
        return labels

    @property
    def ntraps(self):
        return len(self.trap_info["cellLabel"][0])

    @property
    def tile_size(self):
        pass


class ExtractionRunner:
    """An object to run extraction of fluorescence, and general data out of
    segmented data.

    Configure with what extraction we want to run.
    Cell selection criteria.
    Filtering criteria.
    """

    def __init__(self, tiler, cells):
        pass

    def run(self, keys, store, **kwargs):
        pass


def _aslist(x):
    if isinstance(x, Iterable):
        if hasattr(x, "tolist"):
            x = x.tolist()
    else:
        x = [x]
    return x
tile/tiler.py (new file, mode 100644, +333 −0)
"""
Segment/segmented pipelines.
Includes splitting the image into traps/parts,
cell segmentation, nucleus segmentation.
"""
import
warnings
from
functools
import
lru_cache
import
h5py
import
numpy
as
np
from
pathlib
import
Path
,
PosixPath
from
skimage.registration
import
phase_cross_correlation
from
agora.abc
import
ParametersABC
,
ProcessABC
from
aliby.traps
import
segment_traps
from
agora.io.writer
import
load_attributes
trap_template_directory
=
Path
(
__file__
).
parent
/
"
trap_templates
"
# TODO do we need multiple templates, one for each setup?
trap_template
=
np
.
array
([])
# np.load(trap_template_directory / "trap_prime.npy")
def
get_tile_shapes
(
x
,
tile_size
,
max_shape
):
half_size
=
tile_size
//
2
xmin
=
int
(
x
[
0
]
-
half_size
)
ymin
=
max
(
0
,
int
(
x
[
1
]
-
half_size
))
if
xmin
+
tile_size
>
max_shape
[
0
]:
xmin
=
max_shape
[
0
]
-
tile_size
if
ymin
+
tile_size
>
max_shape
[
1
]:
ymin
=
max_shape
[
1
]
-
tile_size
return
xmin
,
xmin
+
tile_size
,
ymin
,
ymin
+
tile_size


###################### Dask versions ########################


class Trap:
    def __init__(self, centre, parent, size, max_size):
        self.centre = centre
        self.parent = parent  # Used to access drifts
        self.size = size
        self.half_size = size // 2
        self.max_size = max_size

    def padding_required(self, tp):
        """Check if we need to pad the trap image for this time point."""
        try:
            assert all(self.at_time(tp) - self.half_size >= 0)
            assert all(self.at_time(tp) + self.half_size <= self.max_size)
        except AssertionError:
            return True
        return False

    def at_time(self, tp):
        """Return trap centre at time tp"""
        drifts = self.parent.drifts
        return self.centre - np.sum(drifts[:tp], axis=0)

    def as_tile(self, tp):
        """Return trap in the OMERO tile format of x, y, w, h.

        Also returns the padding necessary for this tile.
        """
        x, y = self.at_time(tp)
        # tile bottom corner
        x = int(x - self.half_size)
        y = int(y - self.half_size)
        return x, y, self.size, self.size

    def as_range(self, tp):
        """Return trap in a range format: two slice objects that can be used in Arrays."""
        x, y, w, h = self.as_tile(tp)
        return slice(x, x + w), slice(y, y + h)
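
# A small self-contained sketch of the Trap geometry, using a stand-in parent so it
# runs on its own (in the pipeline the parent is a TrapLocations holding the drifts).
class _FakeParent:
    drifts = [np.array([2.0, -1.0])]  # one registered drift of (+2, -1) pixels


_trap = Trap(centre=np.array([50, 60]), parent=_FakeParent(), size=20, max_size=1200)
print(_trap.at_time(0))   # [50. 60.]: no drift accumulated before timepoint 0
print(_trap.at_time(1))   # [48. 61.]: centre corrected by the accumulated drift
print(_trap.as_tile(1))   # (38, 51, 20, 20): tile corner x, y plus width and height
print(_trap.as_range(1))  # (slice(38, 58), slice(51, 71)): ready to index an array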


class TrapLocations:
    def __init__(self, initial_location, tile_size, max_size=1200, drifts=[]):
        self.tile_size = tile_size
        self.max_size = max_size
        self.initial_location = initial_location
        self.traps = [
            Trap(centre, self, tile_size, max_size) for centre in initial_location
        ]
        self.drifts = drifts

    @classmethod
    def from_source(cls, fpath: str):
        with h5py.File(fpath, "r") as f:
            # TODO read tile size from file metadata
            drifts = f["trap_info/drifts"][()]
            tlocs = cls(
                f["trap_info/trap_locations"][()], tile_size=96, drifts=drifts
            )
        return tlocs

    @property
    def shape(self):
        return len(self.traps), len(self.drifts)

    def __len__(self):
        return len(self.traps)

    def __iter__(self):
        yield from self.traps

    def padding_required(self, tp):
        return any([trap.padding_required(tp) for trap in self.traps])

    def to_dict(self, tp):
        res = dict()
        if tp == 0:
            res["trap_locations"] = self.initial_location
            res["attrs/tile_size"] = self.tile_size
            res["attrs/max_size"] = self.max_size
        res["drifts"] = np.expand_dims(self.drifts[tp], axis=0)
        # res['processed_timepoints'] = tp
        return res

    @classmethod
    def read_hdf5(cls, file):
        with h5py.File(file, "r") as hfile:
            trap_info = hfile["trap_info"]
            initial_locations = trap_info["trap_locations"][()]
            drifts = trap_info["drifts"][()]
            max_size = trap_info.attrs["max_size"]
            tile_size = trap_info.attrs["tile_size"]
        trap_locs = cls(initial_locations, tile_size, max_size=max_size)
        trap_locs.drifts = drifts
        return trap_locs
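
# A minimal sketch of how TrapLocations feeds the writer: the first timepoint emits
# the static attributes, later timepoints only emit that timepoint's drift.
_locs = TrapLocations(
    initial_location=np.array([[100, 120], [400, 380]]),
    tile_size=96,
    drifts=[],  # pass drifts explicitly: the default [] is a shared mutable
)
_locs.drifts.append(np.array([0.0, 0.0]))   # drift found for timepoint 0
_locs.drifts.append(np.array([1.5, -0.5]))  # drift found for timepoint 1

print(len(_locs), _locs.shape)  # 2 traps, (2, 2) = (n_traps, n_drifts)
print(_locs.to_dict(0))         # trap_locations, attrs/tile_size, attrs/max_size and the tp-0 drift
print(_locs.to_dict(1))         # only {'drifts': array([[ 1.5, -0.5]])}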


class TilerParameters(ParametersABC):
    def __init__(
        self,
        tile_size: int,
        ref_channel: str,
        ref_z: int,
        template_name: str = None,
    ):
        self.tile_size = tile_size
        self.ref_channel = ref_channel
        self.ref_z = ref_z
        self.template_name = template_name

    @classmethod
    def from_template(cls, template_name: str, ref_channel: str, ref_z: int):
        # NOTE: `template` is not defined in this module and __init__ takes
        # `template_name`, not `template_path`; this constructor looks unfinished.
        return cls(template.shape[0], ref_channel, ref_z, template_path=template_name)

    @classmethod
    def default(cls):
        return cls(96, "Brightfield", 0)


class Tiler(ProcessABC):
    """A dummy TimelapseTiler object for a Dask demo.

    Does trap finding and image registration.
    """

    def __init__(
        self,
        image,
        metadata,
        parameters: TilerParameters,
    ):
        super().__init__(parameters)
        self.image = image
        self.channels = metadata["channels"]
        self.ref_channel = self.get_channel_index(parameters.ref_channel)

    @classmethod
    def from_image(cls, image, parameters: TilerParameters):
        return cls(image.data, image.metadata, parameters)

    @classmethod
    def from_hdf5(cls, image, filepath, tile_size=None):
        trap_locs = TrapLocations.read_hdf5(filepath)
        metadata = load_attributes(filepath)
        metadata["channels"] = metadata["channels/channel"].tolist()
        if tile_size is None:
            tile_size = trap_locs.tile_size
        return Tiler(
            image=image,
            metadata=metadata,
            template=None,
            tile_size=tile_size,
            trap_locs=trap_locs,
        )

    @lru_cache(maxsize=2)
    def get_tc(self, t, c):
        # Get image
        full = self.image[t, c].compute()  # FORCE THE CACHE
        return full

    @property
    def shape(self):
        c, t, z, y, x = self.image.shape
        return (c, t, x, y, z)

    @property
    def n_processed(self):
        if not hasattr(self, "_n_processed"):
            self._n_processed = 0
        return self._n_processed

    @n_processed.setter
    def n_processed(self, value):
        self._n_processed = value

    @property
    def n_traps(self):
        return len(self.trap_locs)

    @property
    def finished(self):
        return self.n_processed == self.image.shape[0]

    def _initialise_traps(self, tile_size):
        """Find initial trap positions.

        Removes all those that are too close to the edge so no padding is necessary.
        """
        half_tile = tile_size // 2
        max_size = min(self.image.shape[-2:])
        initial_image = self.image[
            0, self.ref_channel, self.ref_z
        ]  # First time point, first channel, first z-position
        trap_locs = segment_traps(initial_image, tile_size)
        trap_locs = [
            [x, y]
            for x, y in trap_locs
            if half_tile < x < max_size - half_tile
            and half_tile < y < max_size - half_tile
        ]
        self.trap_locs = TrapLocations(trap_locs, tile_size)

    def find_drift(self, tp):
        # TODO check that the drift doesn't move any tiles out of the image;
        # remove them from the list if so
        prev_tp = max(0, tp - 1)
        drift, error, _ = phase_cross_correlation(
            self.image[prev_tp, self.ref_channel, self.ref_z],
            self.image[tp, self.ref_channel, self.ref_z],
        )
        self.trap_locs.drifts.append(drift)

    def get_tp_data(self, tp, c):
        traps = []
        full = self.get_tc(tp, c)
        # if self.trap_locs.padding_required(tp):
        for trap in self.trap_locs:
            ndtrap = self.ifoob_pad(full, trap.as_range(tp))
            traps.append(ndtrap)
        return np.stack(traps)

    def get_trap_data(self, trap_id, tp, c):
        full = self.get_tc(tp, c)
        trap = self.trap_locs.traps[trap_id]
        ndtrap = self.ifoob_pad(full, trap.as_range(tp))
        return ndtrap

    @staticmethod
    def ifoob_pad(full, slices):
        """Return the tile defined by `slices`, padded if it is out of bounds.

        Parameters
        ----------
        full: (zstacks, max_size, max_size) ndarray
            Entire position with zstacks as first axis
        slices: tuple of two slices
            Each slice indicates an axis to index

        Returns
        -------
        Trap for the given slices, padded with the median if needed, or np.nan
        if the padding is too much
        """
        max_size = full.shape[-1]

        y, x = [slice(max(0, s.start), min(max_size, s.stop)) for s in slices]
        trap = full[:, y, x]

        padding = np.array(
            [(-min(0, s.start), -min(0, max_size - s.stop)) for s in slices]
        )
        if padding.any():
            tile_size = slices[0].stop - slices[0].start
            if (padding > tile_size / 4).any():
                trap = np.full((full.shape[0], tile_size, tile_size), np.nan)
            else:
                trap = np.pad(trap, [[0, 0]] + padding.tolist(), "median")

        return trap
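
    # Worked example of ifoob_pad on a toy stack:
    #     full = np.arange(2 * 10 * 10, dtype=float).reshape(2, 10, 10)   # (z, y, x) toy position
    #     Tiler.ifoob_pad(full, (slice(-1, 5), slice(2, 8))).shape
    #     -> (2, 6, 6): one pixel out of bounds, clipped and then median-padded back
    #     np.isnan(Tiler.ifoob_pad(full, (slice(-5, 1), slice(2, 8)))).all()
    #     -> True: the required padding exceeds tile_size / 4, so the tile becomes NaNs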

    def run_tp(self, tp):
        assert tp >= self.n_processed, "Time point already processed"
        # TODO check contiguity?
        if self.n_processed == 0:
            self._initialise_traps(self.tile_size)
        self.find_drift(tp)  # Get drift
        # update n_processed
        self.n_processed += 1
        # Return result for writer
        return self.trap_locs.to_dict(tp)

    def run(self, tp):
        if self.n_processed == 0:
            self._initialise_traps(self.tile_size)
        self.find_drift(tp)  # Get drift
        # update n_processed
        self.n_processed += 1
        # Return result for writer
        return self.trap_locs.to_dict(tp)

    # The next set of functions are necessary for the extraction object
    def get_traps_timepoint(self, tp, tile_size=None, channels=None, z=None):
        # FIXME we currently ignore the tile size
        # FIXME can we ignore z (always give)
        res = []
        for c in channels:
            val = self.get_tp_data(tp, c)[:, z]  # Only return requested z positions
            # Starts at traps, z, y, x
            # Turn to Trap, C, T, X, Y, Z order
            val = val.swapaxes(1, 3).swapaxes(1, 2)
            val = np.expand_dims(val, axis=1)
            res.append(val)
        return np.stack(res, axis=1)

    def get_channel_index(self, item):
        for i, ch in enumerate(self.channels):
            if item in ch:
                return i

    def get_position_annotation(self):
        # TODO required for matlab support
        return None
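
# Sketch of the intended per-timepoint loop (commented out: `image`, `metadata` and
# the writer are placeholders; in the pipeline `image` is a lazy dask-backed array,
# `metadata` carries the channel names, and the dict returned by run_tp is handed
# to an hdf5 writer).
#
#     params = TilerParameters.default()      # tile_size=96, ref_channel="Brightfield", ref_z=0
#     tiler = Tiler(image, metadata, params)
#     for tp in range(image.shape[0]):
#         result = tiler.run_tp(tp)           # finds traps on tp 0, then registers drift per tp
#         writer.write(result, ...)           # trap_locations/attrs on tp 0, drifts on every tp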
tile/traps.py (new file, mode 100644, +480 −0)
"""
A set of utilities for dealing with ALCATRAS traps
"""
import
numpy
as
np
from
tqdm
import
tqdm
from
skimage
import
transform
,
feature
from
skimage.filters.rank
import
entropy
from
skimage.filters
import
threshold_otsu
from
skimage.segmentation
import
clear_border
from
skimage.measure
import
label
,
regionprops
from
skimage.morphology
import
disk
,
closing
,
square
def
stretch_image
(
image
):
image
=
((
image
-
image
.
min
())
/
(
image
.
max
()
-
image
.
min
()))
*
255
minval
=
np
.
percentile
(
image
,
2
)
maxval
=
np
.
percentile
(
image
,
98
)
image
=
np
.
clip
(
image
,
minval
,
maxval
)
image
=
(
image
-
minval
)
/
(
maxval
-
minval
)
return
image


def segment_traps(image, tile_size, downscale=0.4):
    # Make image go between 0 and 255
    img = image  # Keep a memory of image in case need to re-run
    # stretched = stretch_image(image)
    # img = stretch_image(image)
    # TODO Optimise the hyperparameters
    disk_radius = int(min([0.01 * x for x in img.shape]))
    min_area = 0.2 * (tile_size ** 2)
    if downscale != 1:
        img = transform.rescale(image, downscale)
    entropy_image = entropy(img, disk(disk_radius))
    if downscale != 1:
        entropy_image = transform.rescale(entropy_image, 1 / downscale)

    # apply threshold
    thresh = threshold_otsu(entropy_image)
    bw = closing(entropy_image > thresh, square(3))

    # remove artifacts connected to image border
    cleared = clear_border(bw)

    # label image regions
    label_image = label(cleared)
    areas = [
        region.area
        for region in regionprops(label_image)
        if region.area > min_area and region.area < tile_size ** 2 * 0.8
    ]
    traps = (
        np.array(
            [
                region.centroid
                for region in regionprops(label_image)
                if region.area > min_area and region.area < tile_size ** 2 * 0.8
            ]
        )
        .round()
        .astype(int)
    )
    ma = (
        np.array(
            [
                region.minor_axis_length
                for region in regionprops(label_image)
                if region.area > min_area and region.area < tile_size ** 2 * 0.8
            ]
        )
        .round()
        .astype(int)
    )
    maskx = (tile_size // 2 < traps[:, 0]) & (
        traps[:, 0] < image.shape[0] - tile_size // 2
    )
    masky = (tile_size // 2 < traps[:, 1]) & (
        traps[:, 1] < image.shape[1] - tile_size // 2
    )

    traps = traps[maskx & masky, :]
    ma = ma[maskx & masky]
    chosen_trap_coords = np.round(traps[ma.argmin()]).astype(int)
    x, y = chosen_trap_coords
    template = image[
        x - tile_size // 2 : x + tile_size // 2,
        y - tile_size // 2 : y + tile_size // 2,
    ]

    traps = identify_trap_locations(image, template)

    if len(traps) < 10 and downscale != 1:
        print("Trying again.")
        return segment_traps(image, tile_size, downscale=1)

    return traps


# def segment_traps(image, tile_size, downscale=0.4):
# # Make image go between 0 and 255
# img = image # Keep a memory of image in case need to re-run
# image = stretch_image(image)
# # TODO Optimise the hyperparameters
# disk_radius = int(min([0.01 * x for x in img.shape]))
# min_area = 0.1 * (tile_size ** 2)
# if downscale != 1:
# img = transform.rescale(image, downscale)
# entropy_image = entropy(img, disk(disk_radius))
# if downscale != 1:
# entropy_image = transform.rescale(entropy_image, 1 / downscale)
# # apply threshold
# thresh = threshold_otsu(entropy_image)
# bw = closing(entropy_image > thresh, square(3))
# # remove artifacts connected to image border
# cleared = clear_border(bw)
# # label image regions
# label_image = label(cleared)
# traps = [
# region.centroid for region in regionprops(label_image) if region.area > min_area
# ]
# if len(traps) < 10 and downscale != 1:
# print("Trying again.")
# return segment_traps(image, tile_size, downscale=1)
# return traps


def identify_trap_locations(
    image, trap_template, optimize_scale=True, downscale=0.35, trap_size=None
):
    """
    Identify the traps in a single image based on a trap template.
    This assumes a trap template that is similar to the image in question
    (same camera, same magnification; ideally same experiment).

    This method speeds up the search by downscaling both the image and
    the trap template before running the template match.
    It also optimizes the scale and the rotation of the trap template.

    :param image:
    :param trap_template:
    :param optimize_scale:
    :param downscale:
    :param trap_size:
    :return:
    """
    trap_size = trap_size if trap_size is not None else trap_template.shape[0]
    # Careful, the image is float16!
    img = transform.rescale(image.astype(float), downscale)
    temp = transform.rescale(trap_template, downscale)

    # TODO random search hyperparameter optimization
    # optimize rotation
    matches = {
        rotation: feature.match_template(
            img,
            transform.rotate(temp, rotation, cval=np.median(img)),
            pad_input=True,
            mode="median",
        )
        ** 2
        for rotation in [0, 90, 180, 270]
    }
    best_rotation = max(matches, key=lambda x: np.percentile(matches[x], 99.9))
    temp = transform.rotate(temp, best_rotation, cval=np.median(img))

    if optimize_scale:
        scales = np.linspace(0.5, 2, 10)
        matches = {
            scale: feature.match_template(
                img, transform.rescale(temp, scale), mode="median", pad_input=True
            )
            ** 2
            for scale in scales
        }
        best_scale = max(matches, key=lambda x: np.percentile(matches[x], 99.9))
        matched = matches[best_scale]
    else:
        matched = feature.match_template(img, temp, pad_input=True, mode="median")

    coordinates = feature.peak_local_max(
        transform.rescale(matched, 1 / downscale),
        min_distance=int(trap_template.shape[0] * 0.70),
        exclude_border=(trap_size // 3),
    )
    return coordinates
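
# Usage sketch (commented out: `image` is a placeholder 2D brightfield frame and the
# chosen centre is hypothetical): crop a template around one trap located by eye and
# let the template match find the rest.
#
#     tile_size = 96
#     x0, y0 = 250, 310   # hypothetical centre of a clearly visible trap
#     template = image[x0 - tile_size // 2 : x0 + tile_size // 2,
#                      y0 - tile_size // 2 : y0 + tile_size // 2]
#     coords = identify_trap_locations(image, template)  # (N, 2) array of candidate trap centres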


def get_tile_shapes(x, tile_size, max_shape):
    half_size = tile_size // 2
    xmin = int(x[0] - half_size)
    ymin = max(0, int(x[1] - half_size))
    # if xmin + tile_size > max_shape[0]:
    #     xmin = max_shape[0] - tile_size
    # if ymin + tile_size > max_shape[1]:
    #     # ymin = max_shape[1] - tile_size
    # return max(xmin, 0), xmin + tile_size, max(ymin, 0), ymin + tile_size
    return xmin, xmin + tile_size, ymin, ymin + tile_size


def in_image(img, xmin, xmax, ymin, ymax, xidx=2, yidx=3):
    if xmin >= 0 and ymin >= 0:
        if xmax < img.shape[xidx] and ymax < img.shape[yidx]:
            return True
    else:
        return False


def get_xy_tile(img, xmin, xmax, ymin, ymax, xidx=2, yidx=3, pad_val=None):
    if pad_val is None:
        pad_val = np.median(img)
    # Get the tile from the image
    idx = [slice(None)] * len(img.shape)
    idx[xidx] = slice(max(0, xmin), min(xmax, img.shape[xidx]))
    idx[yidx] = slice(max(0, ymin), min(ymax, img.shape[yidx]))
    tile = img[tuple(idx)]
    # Check if the tile is in the image
    if in_image(img, xmin, xmax, ymin, ymax, xidx, yidx):
        return tile
    else:
        # Add padding
        pad_shape = [(0, 0)] * len(img.shape)
        pad_shape[xidx] = (max(-xmin, 0), max(xmax - img.shape[xidx], 0))
        pad_shape[yidx] = (max(-ymin, 0), max(ymax - img.shape[yidx], 0))

        tile = np.pad(tile, pad_shape, constant_values=pad_val)
    return tile
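
# Runnable example of get_xy_tile on a toy 4D array (the default xidx=2, yidx=3 match
# a (C, T, X, Y) layout): an out-of-bounds request comes back padded with the median.
_img = np.arange(1 * 1 * 8 * 8, dtype=float).reshape(1, 1, 8, 8)

print(get_xy_tile(_img, xmin=2, xmax=6, ymin=2, ymax=6).shape)   # (1, 1, 4, 4), fully inside
print(get_xy_tile(_img, xmin=-2, xmax=2, ymin=5, ymax=9).shape)  # (1, 1, 4, 4), clipped then padded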


def get_trap_timelapse(
    raw_expt, trap_locations, trap_id, tile_size=117, channels=None, z=None
):
    """
    Get a timelapse for a given trap by specifying the trap_id
    :param trap_id: An integer defining which trap to choose. Counted
    between 0 and Tiler.n_traps - 1
    :param tile_size: The size of the trap tile (centered around the
    trap as much as possible, edge cases exist)
    :param channels: Which channels to fetch, indexed from 0.
    If None, defaults to [0]
    :param z: Which z_stacks to fetch, indexed from 0.
    If None, defaults to [0].
    :return: A numpy array with the timelapse in (C,T,X,Y,Z) order
    """
    # Set the defaults (list is mutable)
    channels = channels if channels is not None else [0]
    z = z if z is not None else [0]
    # Get trap location for that id:
    trap_centers = [trap_locations[i][trap_id] for i in range(len(trap_locations))]

    max_shape = (raw_expt.shape[2], raw_expt.shape[3])
    tiles_shapes = [
        get_tile_shapes((x[0], x[1]), tile_size, max_shape) for x in trap_centers
    ]

    timelapse = [
        get_xy_tile(
            raw_expt[channels, i, :, :, z], xmin, xmax, ymin, ymax, pad_val=None
        )
        for i, (xmin, xmax, ymin, ymax) in enumerate(tiles_shapes)
    ]
    return np.hstack(timelapse)


def get_trap_timelapse_omero(
    raw_expt, trap_locations, trap_id, tile_size=117, channels=None, z=None, t=None
):
    """
    Get a timelapse for a given trap by specifying the trap_id
    :param raw_expt: A Timelapse object from which data is obtained
    :param trap_id: An integer defining which trap to choose. Counted
    between 0 and Tiler.n_traps - 1
    :param tile_size: The size of the trap tile (centered around the
    trap as much as possible, edge cases exist)
    :param channels: Which channels to fetch, indexed from 0.
    If None, defaults to [0]
    :param z: Which z_stacks to fetch, indexed from 0.
    If None, defaults to [0].
    :return: A numpy array with the timelapse in (C,T,X,Y,Z) order
    """
    # Set the defaults (list is mutable)
    channels = channels if channels is not None else [0]
    z_positions = z if z is not None else [0]
    times = (
        t if t is not None else np.arange(raw_expt.shape[1])
    )  # TODO choose sub-set of time points
    shape = (len(channels), len(times), tile_size, tile_size, len(z_positions))
    # Get trap location for that id:
    zct_tiles, slices, trap_ids = all_tiles(
        trap_locations, shape, raw_expt, z_positions, channels, times, [trap_id]
    )

    # TODO Make this an explicit function in TimelapseOMERO
    images = raw_expt.pixels.getTiles(zct_tiles)
    timelapse = np.full(shape, np.nan)
    total = len(zct_tiles)
    for (z, c, t, _), (y, x), image in tqdm(
        zip(zct_tiles, slices, images), total=total
    ):
        ch = channels.index(c)
        tp = times.tolist().index(t)
        z_pos = z_positions.index(z)
        timelapse[ch, tp, x[0] : x[1], y[0] : y[1], z_pos] = image

    # for x in timelapse:  # By channel
    #     np.nan_to_num(x, nan=np.nanmedian(x), copy=False)
    return timelapse


def all_tiles(trap_locations, shape, raw_expt, z_positions, channels, times, traps):
    _, _, x, y, _ = shape
    _, _, MAX_X, MAX_Y, _ = raw_expt.shape

    trap_ids = []
    zct_tiles = []
    slices = []
    for z in z_positions:
        for ch in channels:
            for t in times:
                for trap_id in traps:
                    centre = trap_locations[t][trap_id]
                    xmin, ymin, xmax, ymax, r_xmin, r_ymin, r_xmax, r_ymax = tile_where(
                        centre, x, y, MAX_X, MAX_Y
                    )
                    slices.append(
                        (
                            (r_ymin - ymin, r_ymax - ymin),
                            (r_xmin - xmin, r_xmax - xmin),
                        )
                    )
                    tile = (r_ymin, r_xmin, r_ymax - r_ymin, r_xmax - r_xmin)
                    zct_tiles.append((z, ch, t, tile))
                    trap_ids.append(trap_id)  # So we remember the order!
    return zct_tiles, slices, trap_ids


def tile_where(centre, x, y, MAX_X, MAX_Y):
    # Find the position of the tile
    xmin = int(centre[1] - x // 2)
    ymin = int(centre[0] - y // 2)
    xmax = xmin + x
    ymax = ymin + y
    # What do we actually have available?
    r_xmin = max(0, xmin)
    r_xmax = min(MAX_X, xmax)
    r_ymin = max(0, ymin)
    r_ymax = min(MAX_Y, ymax)
    return xmin, ymin, xmax, ymax, r_xmin, r_ymin, r_xmax, r_ymax
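
# Worked example of tile_where: a requested 96x96 tile whose centre sits near the image
# origin gets clipped to what is actually readable.
_xmin, _ymin, _xmax, _ymax, _r_xmin, _r_ymin, _r_xmax, _r_ymax = tile_where(
    centre=(20, 30), x=96, y=96, MAX_X=1200, MAX_Y=1200
)
print(_xmin, _xmax)      # -18 78: the ideal tile extends past the left edge
print(_r_xmin, _r_xmax)  # 0 78: the readable region starts at 0
print(_ymin, _ymax)      # -28 68
print(_r_ymin, _r_ymax)  # 0 68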


def get_tile(shape, center, raw_expt, ch, t, z):
    """Returns a tile from the raw experiment with a given shape.

    :param shape: The shape of the tile in (C, T, Z, Y, X) order.
    :param center: The x,y position of the centre of the tile
    :param
    """
    _, _, x, y, _ = shape
    _, _, MAX_X, MAX_Y, _ = raw_expt.shape
    tile = np.full(shape, np.nan)

    # Find the position of the tile
    xmin = int(center[1] - x // 2)
    ymin = int(center[0] - y // 2)
    xmax = xmin + x
    ymax = ymin + y
    # What do we actually have available?
    r_xmin = max(0, xmin)
    r_xmax = min(MAX_X, xmax)
    r_ymin = max(0, ymin)
    r_ymax = min(MAX_Y, ymax)

    # Fill values
    tile[
        :, :, (r_xmin - xmin) : (r_xmax - xmin), (r_ymin - ymin) : (r_ymax - ymin), :
    ] = raw_expt[ch, t, r_xmin:r_xmax, r_ymin:r_ymax, z]
    # fill_val = np.nanmedian(tile)
    # np.nan_to_num(tile, nan=fill_val, copy=False)
    return tile


def get_traps_timepoint(
    raw_expt, trap_locations, tp, tile_size=96, channels=None, z=None
):
    """
    Get all the traps from a given time point
    :param raw_expt:
    :param trap_locations:
    :param tp:
    :param tile_size:
    :param channels:
    :param z:
    :return: A numpy array with the traps in the (trap, C, T, X, Y, Z) order
    """

    # Set the defaults (list is mutable)
    channels = channels if channels is not None else [0]
    z_positions = z if z is not None else [0]
    if isinstance(z_positions, slice):
        n_z = z_positions.stop
        z_positions = list(range(n_z))  # slice is not iterable error
    elif isinstance(z_positions, list):
        n_z = len(z_positions)
    else:
        n_z = 1

    n_traps = len(trap_locations[tp])
    trap_ids = list(range(n_traps))

    shape = (len(channels), 1, tile_size, tile_size, n_z)
    # all tiles
    zct_tiles, slices, trap_ids = all_tiles(
        trap_locations, shape, raw_expt, z_positions, channels, [tp], trap_ids
    )
    # TODO Make this an explicit function in TimelapseOMERO
    images = raw_expt.pixels.getTiles(zct_tiles)

    # Initialise empty traps
    traps = np.full((n_traps,) + shape, np.nan)
    for trap_id, (z, c, _, _), (y, x), image in zip(
        trap_ids, zct_tiles, slices, images
    ):
        ch = channels.index(c)
        z_pos = z_positions.index(z)
        traps[trap_id, ch, 0, x[0] : x[1], y[0] : y[1], z_pos] = image
    for x in traps:  # By channel
        np.nan_to_num(x, nan=np.nanmedian(x), copy=False)
    return traps


def centre(img, percentage=0.3):
    y, x = img.shape
    cropx = int(np.ceil(x * percentage))
    cropy = int(np.ceil(y * percentage))
    startx = int(x // 2 - (cropx // 2))
    starty = int(y // 2 - (cropy // 2))
    return img[starty : starty + cropy, startx : startx + cropx]


def align_timelapse_images(
    raw_data, channel=0, reference_reset_time=80, reference_reset_drift=25
):
    """
    Uses image registration to align images in the timelapse.
    Uses the channel with id `channel` to perform the registration.

    Starts with the first timepoint as a reference and changes the
    reference to the current timepoint if either the images have moved
    by half of a trap width or `reference_reset_time` has been reached.

    Sets `self.drift`, a 3D numpy array with shape (t, drift_x, drift_y).
    We assume no drift occurs in the z-direction.

    :param reference_reset_drift: Upper bound on the allowed drift before
    resetting the reference image.
    :param reference_reset_time: Upper bound on number of time points to
    register before resetting the reference image.
    :param channel: index of the channel to use for image registration.
    """
    ref = centre(np.squeeze(raw_data[channel, 0, :, :, 0]))
    size_t = raw_data.shape[1]

    drift = [np.array([0, 0])]
    for i in range(1, size_t):
        img = centre(np.squeeze(raw_data[channel, i, :, :, 0]))

        shifts, _, _ = feature.register_translation(ref, img)
        # If a huge move is detected at a single time point it is taken
        # to be inaccurate and the correction from the previous time point
        # is used.
        # This might be common if there is a focus loss for example.
        if any(
            [abs(x - y) > reference_reset_drift for x, y in zip(shifts, drift[-1])]
        ):
            shifts = drift[-1]

        drift.append(shifts)
        ref = img

        # TODO test necessity for references, description below
        # If the images have drifted too far from the reference or too
        # much time has passed we change the reference and keep track of
        # which images are kept as references
    return np.stack(drift)
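
# Sketch on synthetic data: a single-channel stack whose frames are shifted copies of
# the first one. Note that feature.register_translation is the scikit-image API available
# when this code was written; newer releases expose the same functionality as
# skimage.registration.phase_cross_correlation.
_rng = np.random.default_rng(0)
_base = _rng.random((200, 200))
_frames = [
    _base,
    np.roll(_base, shift=(3, -2), axis=(0, 1)),
    np.roll(_base, shift=(6, -4), axis=(0, 1)),
]
_raw_data = np.stack(_frames)[None, :, :, :, None]  # (channel, t, x, y, z)

print(align_timelapse_images(_raw_data, channel=0))
# approximately [[0, 0], [-3, 2], [-3, 2]]: per-step shifts relative to the previous frame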