Commit ab99d9fb authored by pswain
unfinished docs on tracks

parent dfe826be
""" """
Functions to process, filter and merge tracks. Functions to process, filter, and merge tracks.
"""
We call two tracks contiguous if they are adjacent in time: the
maximal time point of one is one time point less than the
minimal time point of the other.
# from collections import Counter A right track can have multiple potential left tracks. We must
pick the best.
"""
import typing as t import typing as t
from copy import copy from copy import copy
@@ -17,6 +22,76 @@ from utils_find_1st import cmp_larger, find_1st
from postprocessor.core.processes.savgol import non_uniform_savgol
def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> list:
    """
    Get pairs of tracks (without repeats) that have an error smaller
    than the tolerance. If a track can be assigned to two or more
    other tracks, choose the pair with the lowest error.

    Parameters
    ----------
    tracks: (m x n) Signal
        A Signal, usually area, where rows are cell tracks and
        columns are time points.
    smooth: boolean
        If True, smooth the tracks with a Savitzky-Golay filter and
        keep only long-duration tracks with growing cells.
    tol: float or int
        Threshold on the average prediction error required to
        consider two tracks the same. If a float, the threshold is a
        fraction of the first track's edge value; if an int, it is in
        absolute units.
    window: int
        Window length passed to the Savitzky-Golay filter.
    degree: int
        Polynomial degree passed to the Savitzky-Golay filter.
    """
    # only consider time series with more than two non-NaN data points
    tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]
    # get contiguous tracks
    if smooth:
        # specialise to tracks with growing cells and of long duration
        clean = clean_tracks(tracks, min_duration=window + 1, min_gr=0.9)
        contigs = clean.groupby(["trap"]).apply(get_contiguous_pairs)
    else:
        contigs = tracks.groupby(["trap"]).apply(get_contiguous_pairs)
    # remove traps with no contiguous tracks
    contigs = contigs.loc[contigs.apply(len) > 0]
    # flatten to (trap, cell_id) pairs
    flat = set([k for v in contigs.values for i in v for j in i for k in j])
    # make a data frame of contiguous tracks with the tracks as arrays
    if smooth:
        smoothed_tracks = clean.loc[flat].apply(
            lambda x: non_uniform_savgol(x.index, x.values, window, degree),
            axis=1,
        )
    else:
        smoothed_tracks = tracks.loc[flat].apply(
            lambda x: np.array(x.values), axis=1
        )
    # get the Signal values for neighbouring end points of contiguous tracks
    actual_edges = contigs.apply(lambda x: get_edge_values(x, smoothed_tracks))
    # get the predicted values: a prediction for each left track and a
    # mean for each right track
    predicted_edges = contigs.apply(
        lambda x: get_predicted_edge_values(x, smoothed_tracks, window)
    )
    # get the costs of matching each left track with each right track
    prediction_costs = predicted_edges.apply(get_dMetric_wrap, tol=tol)
    # find the best assignments with costs within the tolerance
    solutions = [
        solve_matrices_wrap(cost, edges, tol=tol)
        for (trap_id, cost), edges in zip(
            prediction_costs.items(), actual_edges
        )
    ]
    closest_pairs = pd.Series(
        solutions,
        index=prediction_costs.index,
    )
    # match local with global ids
    joinable_ids = [
        localid_to_idx(closest_pairs.loc[i], contigs.loc[i])
        for i in closest_pairs.index
    ]
    return [pair for pairset in joinable_ids for pair in pairset]
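
# A minimal usage sketch (the variable `area_signal` is hypothetical
# and not defined in this module): rows of the Signal are
# (trap, cell_id) tracks and columns are time points.
#
# >>> joinable = get_joinable(area_signal, smooth=True, tol=0.1)
# >>> # joinable: a list of pairs of track indices to merge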
def load_test_dset():
    """Load development dataset to test functions."""
    return pd.DataFrame(
@@ -45,46 +120,21 @@ def max_nonstop_ntps(track: pd.Series) -> int:
    return max(consecutive_nonas_grouped)
def get_avg_gr(track: pd.Series) -> float:
    """Get average growth rate for a track."""
    ntps = max_ntps(track)
    vals = track.dropna().values
    gr = (vals[-1] - vals[0]) / ntps
    return gr
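
# An illustrative calculation, assuming max_ntps returns the track's
# duration in time points: a track with first value 2.0, last value
# 4.4, and 12 time points has gr = (4.4 - 2.0) / 12 = 0.2.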
def clean_tracks(
    tracks, min_duration: int = 15, min_gr: float = 1.0
) -> pd.DataFrame:
    """Remove small non-growing tracks and return the reduced data frame."""
    ntps = tracks.apply(max_ntps, axis=1)
    grs = tracks.apply(get_avg_gr, axis=1)
    growing_long_tracks = tracks.loc[(ntps >= min_duration) & (grs > min_gr)]
    return growing_long_tracks
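
# For example, with min_duration = 15 and min_gr = 1.0, a
# 20-time-point track with an average growth rate of 1.2 is kept,
# whereas shorter tracks or tracks growing at a rate of 1.0 or below
# are dropped.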
@@ -191,125 +241,78 @@ def join_track_pair(target, source):
    return tgt_copy
def get_edge_values(contigs_ids, smoothed_tracks):
    """Get Signal values for adjacent end points for each contiguous track."""
    values = [
        (
            [get_value(smoothed_tracks.loc[pre_id], -1) for pre_id in pre_ids],
            [
                get_value(smoothed_tracks.loc[post_id], 0)
                for post_id in post_ids
            ],
        )
        for pre_ids, post_ids in contigs_ids
    ]
    return values
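
# An illustrative example: for a left track whose smoothed values end
# [..., 2.4, 2.5] and a right track whose values begin [2.6, ...],
# get_edge_values pairs the left track's last non-NaN value, 2.5,
# with the right track's first non-NaN value, 2.6.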
def get_predicted_edge_values(contigs_ids, smoothed_tracks, window):
    """
    Find neighbouring values of two contiguous tracks.

    Predict the next value for the leftmost track using window values
    and find the mean of the initial window values of the rightmost
    track.
    """
    result = []
    for pre_ids, post_ids in contigs_ids:
        pre_res = []
        # left contiguous tracks
        for pre_id in pre_ids:
            # get last window values of a track
            y = get_values_i(smoothed_tracks.loc[pre_id], -window)
            # predict next value
            pre_res.append(
                np.poly1d(np.polyfit(range(len(y)), y, 1))(len(y) + 1),
            )
        # right contiguous tracks
        pos_res = [
            # mean value of initial window values of a track
            get_mean_value_i(smoothed_tracks.loc[post_id], window)
            for post_id in post_ids
        ]
        result.append([pre_res, pos_res])
    return result
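
# A worked example with illustrative numbers: for window = 5 and last
# values y = [2.0, 2.2, 2.4, 2.6, 2.8], np.polyfit gives the line
# 0.2 * x + 2.0, which evaluated at len(y) + 1 = 6 predicts 3.2.
# Note that len(y) + 1 is one step beyond the next time point (x = 5).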
def get_value(x, n):
    """Get value from an array ignoring NaN."""
    val = x[~np.isnan(x)][n] if len(x[~np.isnan(x)]) else np.nan
    return val
def get_mean_value_i(x, i):
    """Get a track's mean Signal value from values either up to or from an index."""
    if not len(x[~np.isnan(x)]):
        return np.nan
    else:
        if i > 0:
            v = x[~np.isnan(x)][:i]
        else:
            v = x[~np.isnan(x)][i:]
        return np.nanmean(v)
def get_values_i(x, i):
    """Get a track's Signal values either up to or from an index."""
    if not len(x[~np.isnan(x)]):
        return np.nan
    else:
        if i > 0:
            v = x[~np.isnan(x)][:i]
        else:
            v = x[~np.isnan(x)][i:]
        return v
def localid_to_idx(local_ids, contig_trap): def localid_to_idx(local_ids, contig_trap):
@@ -338,57 +341,55 @@ def get_vec_closest_pairs(lst: List, **kwargs):
def get_dMetric_wrap(lst: List, **kwargs):
    """Calculate dMetric on a list."""
    return [get_dMetric(*sublist, **kwargs) for sublist in lst]


def solve_matrices_wrap(dMetric: List, edges: List, **kwargs):
    """Calculate solve_matrices on a list."""
    return [
        solve_matrices(mat, edgeset, **kwargs)
        for mat, edgeset in zip(dMetric, edges)
    ]
def get_dMetric(pre: List[float], post: List[float], tol):
    """
    Calculate a cost matrix based on the difference between two
    Signal values.

    Parameters
    ----------
    pre: list of floats
        Values of the Signal for left contiguous tracks.
    post: list of floats
        Values of the Signal for right contiguous tracks.
    tol: float or int
        Tolerance used to set the cost for NaN values.
    """
    if len(pre) > len(post):
        dMetric = np.abs(np.subtract.outer(post, pre))
    else:
        dMetric = np.abs(np.subtract.outer(pre, post))
    # replace NaNs with maximal cost values
    dMetric[np.isnan(dMetric)] = tol + 1 + np.nanmax(dMetric)
    return dMetric
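
# A worked example with illustrative values: for predicted left-track
# edge values pre = [10.0, 3.0] and actual right-track edge values
# post = [9.5], len(pre) > len(post), so the cost matrix is
# np.abs(np.subtract.outer(post, pre)) = [[0.5, 6.5]]: matching the
# right track with the first left track costs 0.5; with the second, 6.5.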
def solve_matrices(cost: np.ndarray, edges: List, tol: Union[float, int] = 1):
    """
    Solve the distance matrices obtained in get_dMetric and/or merged
    from independent dMetric matrices.
    """
    ids = solve_matrix(cost)
    if len(ids[0]):
        pre, post = edges
        # use a relative tolerance if tol < 1; otherwise absolute
        norm = np.array(pre)[ids[len(pre) > len(post)]] if tol < 1 else 1
        result = cost[ids] / norm
        ids = ids if len(pre) < len(post) else ids[::-1]
        return [idx for idx, res in zip(zip(*ids), result) if res <= tol]
    else:
        return []
def get_closest_pairs(
@@ -412,37 +413,31 @@ def get_closest_pairs(
def solve_matrix(dMetric):
    """Arrange indices to the cost matrix in order of increasing cost."""
    glob_is = []
    glob_js = []
    if (~np.isnan(dMetric)).any():
        lMetric = copy(dMetric)
        sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
        while (~np.isnan(sortedMetric)).any():
            # find indices of the point with minimal cost
            i_s, j_s = np.where(lMetric == sortedMetric[0])
            i = i_s[0]
            j = j_s[0]
            # store this point
            glob_is.append(i)
            glob_js.append(j)
            # remove the point's row and column from lMetric
            lMetric[i, :] += np.nan
            lMetric[:, j] += np.nan
            sortedMetric = sorted(lMetric[~np.isnan(lMetric)])
    indices = (np.array(glob_is), np.array(glob_js))
    return indices
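
# A worked example: for dMetric = np.array([[0.2, 5.0], [4.0, 0.3]]),
# the smallest cost, 0.2, selects (0, 0); row 0 and column 0 are then
# masked with NaN, and the smallest remaining cost, 0.3, selects
# (1, 1), so solve_matrix returns (np.array([0, 1]), np.array([0, 1])).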
def plot_joinable(tracks, joinable_pairs):
    """Convenience plotting function for debugging."""
    nx = 8
    ny = 8
    _, axes = plt.subplots(nx, ny)
@@ -467,53 +462,31 @@ def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
""" """
Get all pair of contiguous track ids from a tracks data frame. Get all pair of contiguous track ids from a tracks data frame.
For two tracks to be contiguous, they must be exactly adjacent.
Parameters Parameters
---------- ----------
tracks: pd.Dataframe tracks: pd.Dataframe
A dataframe for one trap where rows are cell tracks and columns A dataframe where rows are cell tracks and columns are time
are time points. points.
""" """
# time points bounding a tracklet of non-NaN values # TODO add support for skipping time points
# find time points bounding tracks of non-NaN values
mins, maxs = [ mins, maxs = [
tracks.notna().apply(np.where, axis=1).apply(fn) tracks.notna().apply(np.where, axis=1).apply(fn)
for fn in (np.min, np.max) for fn in (np.min, np.max)
] ]
# mins.name = "min_tpt"
# maxs.name = "max_tpt"
# df = pd.merge(mins, maxs, right_index=True, left_index=True)
# df["duration"] = df.max_tpt - df.min_tpt
#
# flip so that time points become the index
mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist()) mins_d = mins.groupby(mins).apply(lambda x: x.index.tolist())
mins_d.index = mins_d.index - 1 # make indices equal
# TODO add support for skipping time points
maxs_d = maxs.groupby(maxs).apply(lambda x: x.index.tolist()) maxs_d = maxs.groupby(maxs).apply(lambda x: x.index.tolist())
# reduce minimal time point to make a right track overlap with a left track
mins_d.index = mins_d.index - 1
# find common end points
common = sorted(set(mins_d.index).intersection(maxs_d.index), reverse=True) common = sorted(set(mins_d.index).intersection(maxs_d.index), reverse=True)
return [(maxs_d[t], mins_d[t]) for t in common] contigs = [(maxs_d[t], mins_d[t]) for t in common]
return contigs
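
# For example, in a trap where track 1 spans time points 0-5 and
# track 2 spans time points 6-10, mins_d gains the index 6 - 1 = 5,
# which matches track 1's maximum, so [([1], [2])] is returned.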