Swain Lab / aliby / alibylite · Commits
Commit 0d9f8df5, authored 3 years ago by Alán Muñoz

improve distributed and tiler

Parent: 50d77d41
No related branches, tags or merge requests contain this commit.

Showing 2 changed files with 178 additions and 282 deletions:
scripts/distributed.py    +178 −95
scripts/improve_tiler.py  +0 −187
scripts/distributed.py  (+178 −95)
@@ -7,8 +7,9 @@ from core.experiment import MetaData
 from pathos.multiprocessing import Pool
 from multiprocessing import set_start_method
 import numpy as np
+from postprocessor.core.processor import PostProcessorParameters, PostProcessor
 from extraction.core.functions.defaults import exparams_from_meta
+from core.io.signal import Signal

 # set_start_method("spawn")
@@ -31,10 +32,6 @@ from extraction.core.extractor import Extractor
 from extraction.core.parameters import Parameters
 from extraction.core.functions.defaults import get_params
-import warnings
-# TODO This is for extraction issue #9, remove when fixed
-warnings.simplefilter('ignore', RuntimeWarning)

 def pipeline(image_id, tps=10, tf_version=2):
     name, image_id = image_id
@@ -42,19 +39,19 @@ def pipeline(image_id, tps=10, tf_version=2):
         # Initialise tensorflow
         session = initialise_tf(tf_version)
         with Image(image_id) as image:
-            print(f'Getting data for {image.name}')
+            print(f"Getting data for {image.name}")
             tiler = Tiler(image.data, image.metadata, image.name)
-            writer = TilerWriter(f'../data/test2/{image.name}.h5')
+            writer = TilerWriter(f"../data/test2/{image.name}.h5")
             runner = DummyRunner(tiler)
-            bwriter = BabyWriter(f'../data/test2/{image.name}.h5')
+            bwriter = BabyWriter(f"../data/test2/{image.name}.h5")
             for i in tqdm(range(0, tps), desc=image.name):
                 trap_info = tiler.run_tp(i)
                 writer.write(trap_info, overwrite=[])
                 seg = runner.run_tp(i)
-                bwriter.write(seg, overwrite=['mother_assign'])
+                bwriter.write(seg, overwrite=["mother_assign"])
             return True
     except Exception as e:  # bug in the trap getting
-        print(f'Caught exception in worker thread (x = {name}):')
+        print(f"Caught exception in worker thread (x = {name}):")
         # This prints the type, value, and stack trace of the
         # current exception being handled.
         traceback.print_exc()
@@ -72,50 +69,95 @@ def create_pipeline(image_id, **config):
     general_config = config.get("general", None)
     assert general_config is not None
     session = None
+    earlystop = config.get(
+        "earlystop",
+        {
+            "min_tp": 50,
+            "thresh_pos_clogged": 0.3,
+            "thresh_trap_clogged": 7,
+            "ntps_to_eval": 5,
+        },
+    )
     try:
         directory = general_config.get("directory", "")
         with Image(image_id) as image:
             filename = f"{directory}/{image.name}.h5"
             # Run metadata first
             meta = MetaData(directory, filename)
+            process_from = 0
             meta.run()
-            tiler = Tiler(image.data, image.metadata)
+            if True:  # not Path(filename).exists():
+                tiler = Tiler(
+                    image.data,
+                    image.metadata,
+                    tile_size=general_config.get("tile_size", 117),
+                )
+            else:
+                tiler = Tiler.from_hdf5(image.data, filename)
+                s = Signal(filename)
+                process_from = s["/general/None/extraction/volume"].columns[-1]
+                if process_from > 2:
+                    process_from = process_from - 3
+                    tiler.n_processed = process_from
             writer = TilerWriter(filename)
             baby_config = config.get("baby", None)
             assert baby_config is not None  # TODO add defaults
-            tf_version = baby_config.get("tf_version", 1)
+            tf_version = baby_config.get("tf_version", 2)
             session = initialise_tf(tf_version)
             runner = DummyRunner(tiler)
             bwriter = BabyWriter(filename)
-            # FIXME testing here the extraction
-            params = Parameters(**exparams_from_meta(filename))
+            meta = load_attributes(filename)
+            namebuild = [meta["microscope"].lower(), "fast"]
+            if "mCherry" in meta["channels/channel"]:
+                namebuild.insert(1, "dual")
+            params = Parameters(**get_params("_".join(namebuild)))
             ext = Extractor.from_tiler(params, store=filename, tiler=tiler)

             # RUN
             tps = general_config.get("tps", 0)
-            for i in tqdm(range(0, tps), desc=image.name):
-                t = perf_counter()
-                trap_info = tiler.run_tp(i)
-                logging.debug(f"Timing:Trap: {perf_counter() - t}s")
-                t = perf_counter()
-                writer.write(trap_info, overwrite=[])
-                logging.debug(f"Timing:Writing-trap: {perf_counter() - t}s")
-                t = perf_counter()
-                seg = runner.run_tp(i)
-                logging.debug(f"Timing:Segmentation: {perf_counter() - t}s")
-                t = perf_counter()
-                bwriter.write(seg, overwrite=["mother_assign"])
-                logging.debug(f"Timing:Writing-baby: {perf_counter() - t}s")
-                t = perf_counter()
-                ext.extract_pos(tps=[i])
-                logging.debug(f"Timing:Extraction: {perf_counter() - t}s")
+            frac_clogged_traps = 0
+            for i in tqdm(
+                range(process_from, tps), desc=image.name, initial=process_from
+            ):
+                if frac_clogged_traps < earlystop["thresh_pos_clogged"]:
+                    t = perf_counter()
+                    trap_info = tiler.run_tp(i)
+                    logging.debug(f"Timing:Trap: {perf_counter() - t}s")
+                    t = perf_counter()
+                    writer.write(trap_info, overwrite=[])
+                    logging.debug(f"Timing:Writing-trap: {perf_counter() - t}s")
+                    t = perf_counter()
+                    seg = runner.run_tp(i)
+                    logging.debug(f"Timing:Segmentation: {perf_counter() - t}s")
+                    t = perf_counter()
+                    bwriter.write(seg, overwrite=["mother_assign"])
+                    logging.debug(f"Timing:Writing-baby: {perf_counter() - t}s")
+                    t = perf_counter()
+                    ext.extract_pos(tps=[i])
+                    logging.debug(f"Timing:Extraction: {perf_counter() - t}s")
+                else:  # Stop if more than 10% traps are clogged
+                    logging.debug(
+                        f"EarlyStop: {earlystop['thresh_pos_clogged'] * 100}% traps clogged at time point {i}"
+                    )
+                    print(
+                        f"Breaking experiment at time {i} with {frac_clogged_traps} clogged traps"
+                    )
+                    break
+                if i > earlystop["min_tp"]:  # Calculate the fraction of clogged traps
+                    s = Signal(filename)
+                    df = s["/extraction/general/None/area"]
+                    frac_clogged_traps = (
+                        df[df.columns[i - earlystop["ntps_to_eval"] : i]]
+                        .dropna(how="all")
+                        .notna()
+                        .groupby("trap")
+                        .apply(sum)
+                        .apply(np.nanmean, axis=1)
+                        > earlystop["thresh_trap_clogged"]
+                    ).mean()
+                    logging.debug(f"Quality:Clogged_traps: {frac_clogged_traps}")
+                    print("Frac clogged traps: ", frac_clogged_traps)

             # Run post processing
-            post_proc_params = PostProcessorParameters.default()
-            post_process(filename, post_proc_params)
+            # post_proc_params = PostProcessorParameters.default()
+            # post_process(filename, post_proc_params)
             return True
     except Exception as e:  # bug in the trap getting
         print(f"Caught exception in worker thread (x = {name}):")
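Note on the early-stop logic added in the hunk above: a trap is considered clogged when, over the last ntps_to_eval time points, its mean number of cells with an area reading exceeds thresh_trap_clogged; once the fraction of clogged traps passes thresh_pos_clogged, the loop breaks. The sketch below is illustrative only and not part of the commit: it reproduces the chained pandas expression on toy data, using the plainer .sum() / .mean(axis=1) equivalents of .apply(sum) / .apply(np.nanmean, axis=1), with made-up index levels, thresholds and values rather than aliby's real HDF5 schema.

    import numpy as np
    import pandas as pd

    # Toy "area" signal: rows are (trap, cell) pairs, columns are time points.
    index = pd.MultiIndex.from_tuples(
        [(0, 1), (0, 2), (1, 1), (1, 2), (1, 3), (2, 1)],
        names=["trap", "cell_label"],
    )
    df = pd.DataFrame(np.random.rand(6, 10), index=index, columns=range(10))
    df.iloc[5, :] = np.nan  # a trap whose only cell never has data

    earlystop = {"ntps_to_eval": 5, "thresh_trap_clogged": 2}  # assumed values
    i = 10  # current time point

    frac_clogged_traps = (
        df[df.columns[i - earlystop["ntps_to_eval"] : i]]
        .dropna(how="all")  # drop cells with no data in the window
        .notna()            # True wherever a cell has an area value
        .groupby("trap")
        .sum()              # cells with data, per trap and time point
        .mean(axis=1)       # mean count per trap over the window
        > earlystop["thresh_trap_clogged"]
    ).mean()                # fraction of traps above the threshold
    print(frac_clogged_traps)  # 0.5 for this toy table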
@@ -128,74 +170,77 @@ def create_pipeline(image_id, **config):
     if session:
         session.close()


-@timed('Post-processing')
+@timed("Post-processing")
 def post_process(filepath, params):
     pp = PostProcessor(filepath, params)
     tmp = pp.run()
     return tmp


-@timed('Pipeline')  # instantiating the decorator
+@timed("Pipeline")
 def run_config(config):
     # Config holds the general information, use in main
     # Steps holds the description of tasks with their parameters
     # Steps: all holds general tasks
     # steps: strain_name holds task for a given strain
-    expt_id = config['general'].get('id')
-    distributed = config['general'].get('distributed', 0)
-    strain_filter = config['general'].get('strain', '')
-    root_dir = config['general'].get('directory', 'output')
+    expt_id = config["general"].get("id")
+    distributed = config["general"].get("distributed", 0)
+    strain_filter = config["general"].get("strain", "")
+    root_dir = config["general"].get("directory", "output")
     root_dir = Path(root_dir)

-    print('Searching OMERO')
+    print("Searching OMERO")
     # Do all initialisation
     with Dataset(int(expt_id)) as conn:
         image_ids = conn.get_images()
-        directory = root_dir / conn.name
+        directory = root_dir / conn.unique_name
         if not directory.exists():
             directory.mkdir(parents=True)
         # Download logs to use for metadata
         conn.cache_logs(directory)
         # Modify to the configuration
-        config['general']['directory'] = directory
+        config["general"]["directory"] = directory
         # Filter
         image_ids = {k: v for k, v in image_ids.items() if k.startswith(
             strain_filter)}
         if distributed != 0:  # Gives the number of simultaneous processes
             with Pool(distributed) as p:
                 results = p.map(lambda x: create_pipeline(x, **config), image_ids.items())
+                p.terminate()
             return results
         else:  # Sequential
             results = []
             for k, v in image_ids.items():
                 r = create_pipeline((k, v), **config)
                 results.append(r)


 def initialise_logging(log_file: str):
     logging.basicConfig(filename=log_file, level=logging.DEBUG)
     for v in logging.Logger.manager.loggerDict.values():
         try:
-            if not v.name.startswith(['extraction', 'core.io']):
+            if not v.name.startswith(["extraction", "core.io"]):
                 v.disabled = True
         except:
             pass


 def parse_timing(log_file):
     timings = dict()
     # Open the log file
-    with open(log_file, 'r') as f:
+    with open(log_file, "r") as f:
         # Line by line read
         for line in f.read().splitlines():
-            if not line.startswith('DEBUG:root'):
+            if not line.startswith("DEBUG:root"):
                 continue
-            words = line.split(':')
+            words = line.split(":")
             # Only keep lines that include "Timing"
-            if 'Timing' in words:
+            if "Timing" in words:
                 # Split the last two into key, value
                 k, v = words[-2:]
                 # Dict[key].append(value)
                 if k not in timings:
                     timings[k] = []
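An aside on the dispatch in run_config above (illustrative, not part of the commit): pathos pools serialise with dill, which is why a lambda that closes over **config can be mapped across worker processes, something the standard-library pickler refuses to do. Below is a minimal, self-contained sketch of the same pattern with a stand-in worker; fake_pipeline, image_ids and config are invented names for illustration, not aliby's.

    from pathos.multiprocessing import Pool

    def fake_pipeline(item, **config):
        # Stand-in for create_pipeline: each item is a (name, image_id) pair.
        name, image_id = item
        return f"{name}:{image_id}:tps={config.get('tps', 0)}"

    image_ids = {"pos001": 101, "pos002": 102}
    config = {"tps": 5}

    if __name__ == "__main__":
        with Pool(2) as p:
            # dill (used by pathos) pickles the lambda and the captured config
            results = p.map(lambda x: fake_pipeline(x, **config), image_ids.items())
        print(results)  # ['pos001:101:tps=5', 'pos002:102:tps=5']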
@@ -205,43 +250,81 @@ def parse_timing(log_file):
 def visualise_timing(timings: dict, save_file: str):
     plt.figure().clear()
-    plot_data = {x: timings[x] for x in timings if x.startswith(('Trap', 'Writing', 'Segmentation', 'Extraction'))}
-    sorted_keys, fixed_data = zip(*sorted(plot_data.items(), key=operator.itemgetter(1)))
-    #Set up the graph parameters
-    sns.set(style='whitegrid')
-    #Plot the graph
-    #sns.stripplot(data=fixed_data, size=1)
-    ax = sns.boxplot(data=fixed_data, whis=np.inf, width=.05)
+    plot_data = {
+        x: timings[x]
+        for x in timings
+        if x.startswith(("Trap", "Writing", "Segmentation", "Extraction"))
+    }
+    sorted_keys, fixed_data = zip(
+        *sorted(plot_data.items(), key=operator.itemgetter(1))
+    )
+    # Set up the graph parameters
+    sns.set(style="whitegrid")
+    # Plot the graph
+    # sns.stripplot(data=fixed_data, size=1)
+    ax = sns.boxplot(data=fixed_data, whis=np.inf, width=0.05)
     ax.set(xlabel="Stage", ylabel="Time (s)", xticklabels=sorted_keys)
-    ax.tick_params(axis='x', rotation=90);
-    ax.figure.savefig(save_file, bbox_inches='tight', transparent=True)
+    ax.tick_params(axis="x", rotation=90)
+    ax.figure.savefig(save_file, bbox_inches="tight", transparent=True)
     return


-# if __name__ == "__main__":
-strain = 'Vph1'
-tps = 390
-config = dict(
-    general=dict(
-        id=19303,
-        distributed=5,
-        tps=tps,
-        strain=strain,
-        directory='../data/'
-    ),
-    tiler=dict(),
-    baby=dict(tf_version=2)
-)
-log_file = '../data/2tozero_Hxts_02/issues.log'
-initialise_logging(log_file)
-save_timings = f"../data/2tozero_Hxts_02/timings_{strain}_{tps}.pdf"
-timings_file = f"../data/2tozero_Hxts_02/timings_{strain}_{tps}.json"
-# Run
-#run_config(config)
-# Get timing results
-timing = parse_timing(log_file)
-# Visualise timings and save
-visualise_timing(timing, save_timings)
-# Dump the rest to json
-with open(timings_file, 'w') as fd:
-    json.dump(timing, fd)
+if __name__ == "__main__":
+    strain = "YST_1512"
+    # exp = 18616
+    # exp = 19232
+    # exp = 19995
+    # exp = 19993
+    exp = 20191
+    # exp = 19831
+    with Dataset(exp) as conn:
+        imgs = conn.get_images()
+        exp_name = conn.unique_name
+    with Image(list(imgs.values())[0]) as im:
+        meta = im.metadata
+    tps = int(meta["size_t"])
+    # tps = meta["size_t"]
+    config = dict(
+        general=dict(
+            id=exp,
+            distributed=4,
+            tps=tps,
+            directory="../data/",
+            strain=strain,
+            tile_size=117,
+        ),
+        # general=dict(id=19303, distributed=0, tps=tps, strain=strain, directory="../data/"),
+        tiler=dict(),
+        baby=dict(tf_version=2),
+        earlystop=dict(
+            min_tp=50,
+            thresh_pos_clogged=0.3,
+            thresh_trap_clogged=7,
+            ntps_to_eval=5,
+        ),
+    )
+    # log_file = f"../data/{exp_name}/issues.log"
+    log_file = "/shared_libs/pydask/pipeline-core/data/2020_02_20_protAgg_downUpShift_2_0_2_Ura8_Ura8HA_Ura8HR_01/issues.log"
+    # initialise_logging(log_file)
+    save_timings = f"../data/{exp_name}/timings_{strain}_{tps}.pdf"
+    timings_file = f"../data/{exp_name}/timings_{strain}_{tps}.json"
+    # Run
+    run_config(config)
+    # Get timing results
+    # timing = parse_timing(log_file)
+    # # Visualise timings and save
+    # visualise_timing(timing, save_timings)
+    # # Dump the rest to json
+    # with open(timings_file, "w") as fd:
+    #     json.dump(timing, fd)
+    # filename = "/shared_libs/pydask/pipeline-core/data/2020_02_20_protAgg_downUpShift_2_0_2_Ura8_Ura8HA_Ura8HR_01/Ura8H360R030.h5"
+    # import h5py
+    # with h5py.File(filename, "r") as f:
+    #     plt.imshow(f["cell_info/edgemasks/values"][0][-1])
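For context on parse_timing (illustrative, not part of the commit): initialise_logging relies on logging.basicConfig's default "%(levelname)s:%(name)s:%(message)s" layout, so the logging.debug(f"Timing:...") calls in create_pipeline land in the log as lines such as "DEBUG:root:Timing:Trap: 0.51s", which parse_timing splits on ":". The hunk above stops at timings[k] = [], so the append step in this sketch is an assumption, as are the example numbers.

    example_log = [
        "DEBUG:root:Timing:Trap: 0.51s",
        "DEBUG:root:Timing:Segmentation: 2.30s",
        "INFO:root:unrelated message",
    ]

    timings = {}
    for line in example_log:
        if not line.startswith("DEBUG:root"):
            continue
        words = line.split(":")
        if "Timing" in words:
            k, v = words[-2:]  # e.g. "Trap" and " 0.51s"
            timings.setdefault(k, []).append(v.strip())

    print(timings)  # {'Trap': ['0.51s'], 'Segmentation': ['2.30s']}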
scripts/improve_tiler.py  (deleted, 100644 → 0: +0 −187)

The entire file was removed in this commit; its previous contents were:
#!/usr/bin/env python3

expts = [18616, 19232, 19995, 19993, 20191, 19831]

# fetch images
test_imgs = []
for e in expts:
    with Dataset(int(e)) as conn:
        image_ids = conn.get_images()
        for im_id in image_ids.values():
            with Image(im_id) as image:
                dimg = image.data
                print("computing")
                img = dimg[
                    0,
                    image.metadata["channels"].index("Brightfield"),
                    2,
                    ...,
                ].compute()
                test_imgs.append(img)

from numpy import save, load

# save
for i, nd in enumerate(test_imgs):
    save("raw_" + str(i) + ".png", nd)

# load


def stretch_image(image):
    image = ((image - image.min()) / (image.max() - image.min())) * 255
    minval = np.percentile(image, 2)
    maxval = np.percentile(image, 98)
    image = np.clip(image, minval, maxval)
    image = (image - minval) / (maxval - minval)
    return image


def segment_traps(image, tile_size, downscale=0.4):
    # Make image go between 0 and 255
    img = image  # Keep a memory of image in case need to re-run
    stretched = stretch_image(image)
    img = stretch_image(image)
    # TODO Optimise the hyperparameters
    disk_radius = int(min([0.01 * x for x in img.shape]))
    min_area = 0.2 * (tile_size ** 2)
    if downscale != 1:
        img = transform.rescale(image, downscale)
    entropy_image = entropy(img, disk(disk_radius))
    if downscale != 1:
        entropy_image = transform.rescale(entropy_image, 1 / downscale)
    # apply threshold
    thresh = threshold_otsu(entropy_image)
    bw = closing(entropy_image > thresh, square(3))
    # remove artifacts connected to image border
    cleared = clear_border(bw)
    # label image regions
    label_image = label(cleared)
    areas = [
        region.area
        for region in regionprops(label_image)
        if region.area > min_area and region.area < tile_size ** 2 * 0.8
    ]
    traps = (
        np.array(
            [
                region.centroid
                for region in regionprops(label_image)
                if region.area > min_area and region.area < tile_size ** 2 * 0.8
            ]
        )
        .round()
        .astype(int)
    )
    rprops = regionprops_table(
        label_image,
        properties=[
            "area",
            "eccentricity",
            "convex_area",
            "feret_diameter_max",
            "orientation",
            "solidity",
            "minor_axis_length",
        ],
    )
    trapmask = (rprops["area"] > min_area) & (rprops["area"] < tile_size ** 2 * 0.8)
    candidates = [
        stretched[
            x - tile_size // 2 : x + tile_size // 2,
            y - tile_size // 2 : y + tile_size // 2,
        ]
        for x, y in np.array(traps).round().astype(int)
    ]
    # valleys = [find_valley(c) for c in candidates]
    from copy import copy

    bak = copy(candidates)
    candidates = [bak[x] for x in np.argsort(rprops["minor_axis_length"][trapmask])]
    return candidates[:5]

    # fig, axes = plt.subplots(5, 8)
    # indices = np.concatenate((np.arange(20), -np.arange(1, 21)[::-1]))
    # for i in range(5):
    #     for j in range(8):
    #         if i * 8 + j < len(candidates):
    #             # axes[i, j].imshow(candidates[i * 8 + j])
    #             axes[i, j].imshow(candidates[indices[i * 8 + j]])
    # plt.show()

    # chosen_trap_coords = np.round(traps[np.argsort(area)[len(area) // 2]]).astype(int)
    # chosen_trap_coords = np.round(traps[np.argsort(ma)[len(ma) // 2]]).astype(int)
    x, y = chosen_trap_coords
    template = image[
        x - tile_size // 2 : x + tile_size // 2,
        y - tile_size // 2 : y + tile_size // 2,
    ]
    return template

    new_coords = identify_trap_locations(image, template)

    # def get_tile(tile_size=117):
    #     tile = np.ones((tile_size, tile_size))
    #     tile[1:-1, 1:-1] = False
    #     return tile

    # tile = get_tile(tile_size)

    # # tmp
    # mask = np.zeros_like(image, dtype="bool")
    # # for x, y in np.array(traps).round().astype(int):
    # for x, y in new_coords:
    #     dist = int(tile_size / 2)
    #     size_okay = (
    #         np.array(mask[x - dist : x + dist + 1, y - dist : y + dist + 1].shape)
    #         == np.array(tile.shape)
    #     ).all()
    #     if size_okay:
    #         maxes = np.maximum.reduce(
    #             (mask[x - dist : x + dist + 1, y - dist : y + dist + 1], tile)
    #         )
    #         mask[x - dist : x + dist + 1, y - dist : y + dist + 1] = maxes

    # from skimage.color import label2rgb

    # traps_img = label2rgb(mask, image=stretched, bg_label=0, alpha=0.5)

    if len(traps) < 10 and downscale != 1:
        print("Trying again.")
        return segment_traps(image, tile_size, downscale=1)
    # return traps
    return traps_img


ncols = 10
rands = np.random.randint(0, 138, ncols)
top_cands = [segment_traps(test_imgs[r], tile_size=117) for r in rands]
fig, axes = plt.subplots(5, ncols)
for i in range(ncols):
    for j in range(5):
        axes[j, i].imshow(top_cands[i][j])
plt.show()

# res = [segment_traps(im, tile_size=117) for im in test_imgs[rands]]

from scipy.signal import find_peaks


def find_valley(template):
    template = ((template - template.min()) / (template.max() - template.min())) * 255
    summed = template.sum(axis=1)
    norm = summed / summed.max()
    find_peaks(norm[20:-20])
    max1, max2 = np.argsort(norm[peaks[0]])[:2]
    if max2 < max1:
        tmp = max2
        max2 = max1
        max1 = tmp
    return norm[max1:max2].min()


for i, im in enumerate(res):
    plt.imshow(im)
    plt.axis("off")
    plt.savefig("tiles" + str(i), dpi=400)
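The deleted prototype above locates candidate traps by contrast-stretching the brightfield image, running a local-entropy filter, thresholding with Otsu, and keeping mid-sized labelled regions. The sketch below is a self-contained illustration of that core sequence on a synthetic image, not a replacement for the script: the synthetic patch, footprint radius and area cut-off are assumptions, and the real workflow runs on OMERO data and uses the extra region properties listed above.

    import numpy as np
    from skimage.filters import threshold_otsu
    from skimage.filters.rank import entropy
    from skimage.measure import label, regionprops
    from skimage.morphology import closing, disk, square
    from skimage.segmentation import clear_border
    from skimage.util import img_as_ubyte

    # Synthetic stand-in for a brightfield field of view: one textured patch
    # (high local entropy) on a flat background.
    rng = np.random.default_rng(0)
    image = np.zeros((256, 256))
    image[64:128, 64:128] = rng.random((64, 64))

    entropy_image = entropy(img_as_ubyte(image), disk(3))
    bw = closing(entropy_image > threshold_otsu(entropy_image), square(3))
    label_image = label(clear_border(bw))
    centroids = [r.centroid for r in regionprops(label_image) if r.area > 100]
    print(len(centroids), "candidate region(s)")  # 1 for this synthetic image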