Skip to content

Commit db87d77

Browse files
pichuancopybara-github
authored andcommitted
Fix binarize. (Reported in ⁠#286)
PiperOrigin-RevId: 303875906
1 parent 1ad453c commit db87d77

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

deepvariant/modeling.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,19 @@ class UnsupportedImageDimensionsError(Exception):
117117
def binarize(labels, target_class):
118118
"""Binarize labels and predictions.
119119
120-
The labels that are not equal to target_class parameter are set to zero.
120+
The labels that are equal to target_class parameter are set to 0, else
121+
set to 1.
121122
122123
Args:
123124
labels: the ground-truth labels for the examples.
124-
target_class: index of the class that is left as non-zero.
125+
target_class: index of the class that is left as zero.
125126
126127
Returns:
127128
Tensor of the same shape as labels.
128129
"""
129130
labels_binary = tf.compat.v1.where(
130131
tf.equal(labels, tf.constant(target_class, dtype=tf.int64)),
131-
tf.zeros_like(labels), labels)
132+
tf.zeros_like(labels), tf.ones_like(labels))
132133
return labels_binary
133134

134135

deepvariant/modeling_test.py

+11
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,17 @@ def _run(tensor_to_run):
9898
modeling.is_encoded_variant_type(
9999
tensor, tf_utils.EncodedVariantType.INDEL)), [False, True] * 4)
100100

101+
@parameterized.parameters(
102+
dict(labels=[0, 2, 1, 0], target_class=0, expected=[0, 1, 1, 0]),
103+
dict(labels=[0, 2, 1, 0], target_class=1, expected=[1, 1, 0, 1]),
104+
dict(labels=[0, 2, 1, 0], target_class=2, expected=[1, 0, 1, 1]),
105+
)
106+
def test_binarize(self, labels, target_class, expected):
107+
with self.test_session() as sess:
108+
result = sess.run(
109+
modeling.binarize(np.array(labels), np.array(target_class)))
110+
self.assertListEqual(result.tolist(), expected)
111+
101112
@parameterized.parameters([True, False])
102113
def test_eval_metric_fn(self, include_variant_types):
103114
labels = tf.constant([1, 0], dtype=tf.int64)

0 commit comments

Comments
 (0)