Skip to content

Commit

Permalink
Implement correct usage of crop parameter in CellTracker (#108)
Browse files Browse the repository at this point in the history
* Remove deprecated post processing functions from the tracker

* Add crop mode params to the CellTracker and use when retrieving features

* Missing comma

* Pin numpy to <1.24
  • Loading branch information
msschwartz21 authored Dec 20, 2022
1 parent 8d2b443 commit 199dde2
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 230 deletions.
235 changes: 12 additions & 223 deletions deepcell_tracking/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ class CellTracker(object): # pylint: disable=useless-object-inheritance
dtype (str): data type for features, can be 'float32', 'float16', etc.
data_format (str): determines the order of the channel axis,
one of 'channels_first' and 'channels_last'.
crop_mode (str): Whether to do a fixed crop or to crop and resize
to create the appearance features
norm (bool): Whether to remove non cell features and normalize the
foreground pixels by zero-meaning and dividing by the standard
deviation. Applies to fixed crop mode only.
"""

def __init__(self,
Expand All @@ -91,6 +96,8 @@ def __init__(self,
division=0.9,
track_length=5,
embedding_axis=0,
crop_mode='resize',
norm=True,
dtype='float32',
data_format='channels_last'):

Expand Down Expand Up @@ -123,6 +130,8 @@ def __init__(self,
self.dtype = dtype
self.track_length = track_length
self.embedding_axis = embedding_axis
self.crop_mode = crop_mode
self.norm = norm

self.a_matrix = []
self.c_matrix = []
Expand Down Expand Up @@ -211,7 +220,9 @@ def _est_feats(self):

frame_features = get_image_features(
self.X[frame], self.y[frame],
appearance_dim=self.appearance_dim)
appearance_dim=self.appearance_dim,
crop_mode=self.crop_mode,
norm=self.norm)

for cell_idx, cell_id in enumerate(frame_features['labels']):
self.id_to_idx[cell_id] = cell_idx
Expand Down Expand Up @@ -692,47 +703,6 @@ def dataframe(self, **kwargs):

return dataframe

def postprocess(self, filename=None, time_excl=9):
"""Use graph postprocessing to eliminate false positive division errors
using a graph-based detection method. False positive errors are when a
cell is noted as a daughter of itself before the actual division occurs.
If a filename is passed, save the state of the cell tracker to a .trk
('track') file. time_excl is the minimum number of frames expected to
exist between legitimate divisions
"""

# Load data
track_review_dict = self._track_review_dict()

# Prep data
tracked = track_review_dict['y_tracked'].astype('uint16')
lineage = track_review_dict['tracks']

# Identify false positives (FPs)
G = self._track_to_graph(lineage)
FPs = self._flag_false_pos(G, time_excl)
FPs_candidates = sorted(FPs.items(), key=lambda v: int(v[0].split('_')[1]))
FPs_sorted = self._review_candidate_nodes(FPs_candidates)

# If FPs exist, use the results to correct
while len(FPs_sorted) != 0:

lineage, tracked = self._remove_false_pos(lineage, tracked, FPs_sorted[0])
G = self._track_to_graph(lineage)
FPs = self._flag_false_pos(G, time_excl)
FPs_candidates = sorted(FPs.items(), key=lambda v: int(v[0].split('_')[1]))
FPs_sorted = self._review_candidate_nodes(FPs_candidates)

# Make sure the assignment is correct
track_review_dict['y_tracked'] = tracked
track_review_dict['tracks'] = lineage

# Save information to a track file file if requested
if filename is not None:
self.dump(filename, track_review_dict)

return track_review_dict

def dump(self, filename, track_review_dict=None):
"""Writes the state of the cell tracker to a .trk ('track') file.
Includes raw & tracked images, and a lineage.json for parent/daughter
Expand All @@ -752,184 +722,3 @@ def dump(self, filename, track_review_dict=None):
lineage=track_review_dict['tracks'],
raw=track_review_dict['X'],
tracked=track_review_dict['y_tracked'])

def _track_to_graph(self, tracks):
"""Create a graph from the lineage information"""
Dattr = {}
edges = pd.DataFrame()

for L in tracks.values():
# Calculate node ids
cellid = ['{}_{}'.format(L['label'], f) for f in L['frames']]
# Add edges from cell ids
edges = edges.append(pd.DataFrame({'source': cellid[0:-1],
'target': cellid[1:]}))

# Collect any division attributes
if L['frame_div'] is not None:
Dattr['{}_{}'.format(L['label'], L['frame_div'] - 1)] = {'division': True}

# Create any daughter-parent edges
if L['parent'] is not None:
source = '{}_{}'.format(L['parent'], min(L['frames']) - 1)
target = '{}_{}'.format(L['label'], min(L['frames']))
edges = edges.append(pd.DataFrame({'source': [source],
'target': [target]}))

G = nx.from_pandas_edgelist(edges, source='source', target='target')
nx.set_node_attributes(G, Dattr)
return G

def _flag_false_pos(self, G, time_excl):
"""Examine graph for false positive nodes
"""

# TODO: Current implementation may eliminate some divisions at the edge of the frame -
# Further research needed

# Identify false positive nodes
node_fix = []
for g in (G.subgraph(c) for c in nx.connected_components(G)):
div_nodes = [n for n, d in g.nodes(data=True) if d.get('division')]
if len(div_nodes) > 1:
for nd in div_nodes:
if g.degree(nd) == 2:
# Check how close suspected FP is to other known divisions

keep_div = True
for div_nd in div_nodes:
if div_nd != nd:
time_spacing = abs(int(nd.split('_')[1]) -
int(div_nd.split('_')[1]))
# If division is sufficiently far away
# we should exclude it from FP list
if time_spacing > time_excl:
keep_div = False

if keep_div is True:
node_fix.append(nd)

# Add supplementary information for each false positive
D = {}
for node in node_fix:
D[node] = {
'false positive': node,
'neighbors': list(G.neighbors(node)),
'connected lineages': set([int(node.split('_')[0])
for node in nx.node_connected_component(G, node)])
}

return D

def _review_candidate_nodes(self, FPs_candidates):
""" review candidate false positive nodes and remove any errant degree 2 nodes.
"""
FPs_presort = {}
# review candidate false positive nodes and remove any errant degree 2 nodes
for candidate_node in FPs_candidates:
node = candidate_node[0]
node_info = candidate_node[1]

neighbors = [] # structure will be [(neighbor1, frame), (neighbor2,frame)]
for neighbor in node_info['neighbors']:
neighbor_label = int(neighbor.split('_')[0])
neighbor_frame = int(neighbor.split('_')[1])
neighbors.append((neighbor_label, neighbor_frame))

# if this cell only exists in one frame (and then it divides) but its 2 neighbors
# both exist in the same frame it will be a degree 2 node but not be a false positive
if neighbors[0][1] != neighbors[1][1]:
FPs_presort[node] = node_info

FPs_sorted = sorted(FPs_presort.items(), key=lambda v: int(v[0].split('_')[1]))

return FPs_sorted

def _remove_false_pos(self, lineage, tracked, FP_info):
""" Remove nodes that have been identified as false positive divisions.
"""
node = FP_info[0]
node_info = FP_info[1]

fp_label = int(node.split('_')[0])
fp_frame = int(node.split('_')[1])

neighbors = [] # structure will be [(neighbor1, frame), (neighbor2,frame)]
for neighbor in node_info['neighbors']:
neighbor_label = int(neighbor.split('_')[0])
neighbor_frame = int(neighbor.split('_')[1])
neighbors.append((neighbor_label, neighbor_frame))

# Verify that the FP node only 2 neighbors - 1 before it and one after it
if len(neighbors) == 2:
# order the neighbors such that the time (frame order) is respected
if neighbors[0][1] > neighbors[1][1]:
temp = neighbors[0]
neighbors[0] = neighbors[1]
neighbors[1] = temp

# Decide which labels to extend and which to remove

# Neighbor_1 has same label as fp - the actual division hasnt occurred yet
if fp_label == neighbors[0][0]:
# The model mistakenly identified a division before the actual division occurred
label_to_remove = neighbors[1][0]
label_to_extend = neighbors[0][0]

# Give all of the errant divisions information to the correct track
lineage[label_to_extend]['frames'].extend(lineage[label_to_remove]['frames'])
lineage[label_to_extend]['daughters'] = lineage[label_to_remove]['daughters']
lineage[label_to_extend]['frame_div'] = lineage[label_to_remove]['frame_div']

# Adjust the parent information for the actual daughters
daughter_labels = lineage[label_to_remove]['daughters']
for daughter in daughter_labels:
lineage[daughter]['parent'] = lineage[label_to_remove]['parent']

# Remove the errant node from the annotated images
channel = 0 # These images should only have one channel
for frame in lineage[label_to_remove]['frames']:
label_loc = np.where(tracked[frame, :, :, channel] == label_to_remove)
tracked[frame, :, :, channel][label_loc] = label_to_extend

# Remove the errant node from the lineage
del lineage[label_to_remove]

# Neighbor_2 has same label as fp - the actual division ocurred &
# the model mistakenly allowed another
# elif fp_label == neighbors[1][0]:
# The model mistakenly identified a division after
# the actual division occurred
# label_to_remove = fp_label

# Neither neighbor has same label as fp - the actual division
# ocurred & the model mistakenly allowed another
else:
# The model mistakenly identified a division after the actual division occurred
label_to_remove = fp_label
label_to_extend = neighbors[1][0]

# Give all of the errant divisions information to the correct track
lineage[label_to_extend]['frames'] = \
lineage[fp_label]['frames'] + lineage[label_to_extend]['frames']
lineage[label_to_extend]['parent'] = lineage[fp_label]['parent']

# Adjust the parent information for the actual daughter
parent_label = lineage[fp_label]['parent']
for d_idx, daughter in enumerate(lineage[parent_label]['daughters']):
if daughter == fp_label:
lineage[parent_label]['daughters'][d_idx] = label_to_extend

# Remove the errant node from the annotated images
channel = 0 # These images should only have one channel
for frame in lineage[label_to_remove]['frames']:
label_loc = np.where(tracked[frame, :, :, channel] == label_to_remove)
tracked[frame, :, :, channel][label_loc] = label_to_extend

# Remove the errant node
del lineage[label_to_remove]

else:
self.logger.error('Error: More than 2 neighbor nodes')

return lineage, tracked
7 changes: 1 addition & 6 deletions deepcell_tracking/tracking_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,7 @@ def test_track_cells(self, tmpdir):
with pytest.raises(ValueError):
tracker.dataframe(bad_value=-1)

# test tracker.postprocess
tempdir = str(tmpdir)
path = os.path.join(tempdir, 'postprocess.xyz')
tracker.postprocess(filename=path)
post_saved_path = os.path.join(tempdir, 'postprocess.trk')
assert os.path.isfile(post_saved_path)

# test tracker.dump
path = os.path.join(tempdir, 'test.xyz')
Expand All @@ -191,7 +186,7 @@ def test_track_cells(self, tmpdir):
assert os.path.isfile(os.path.join(tempdir, 'all.trks'))

# test load_trks
data = trk_io.load_trks(post_saved_path)
data = trk_io.load_trks(dump_saved_path)
assert isinstance(data['lineages'], list)
assert all(isinstance(d, dict) for d in data['lineages'])
np.testing.assert_equal(data['X'], tracker.X)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
download_url=DOWNLOAD_URL,
license=LICENSE,
install_requires=['networkx>=2.1',
'numpy',
'numpy<1.24',
'pandas',
'scipy',
'scikit-image>=0.14.5',
Expand Down

0 comments on commit 199dde2

Please sign in to comment.