Skip to content
13 changes: 13 additions & 0 deletions tensorflow_addons/image/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ py_library(
"connected_components.py",
"resampler_ops.py",
"compose_ops.py",
"iou_ops.py",
]),
data = [
":sparse_image_warp_test_data",
Expand Down Expand Up @@ -177,3 +178,15 @@ py_test(
":image",
],
)

# Test target that runs iou_ops_test.py; it depends only on the ":image"
# library target declared in this package.
py_test(
    name = "iou_ops_test",
    size = "medium",
    srcs = [
        "iou_ops_test.py",
    ],
    main = "iou_ops_test.py",
    deps = [
        ":image",
    ],
)
4 changes: 4 additions & 0 deletions tensorflow_addons/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,7 @@
from tensorflow_addons.image.translate_ops import translate
from tensorflow_addons.image.translate_ops import translate_xy
from tensorflow_addons.image.compose_ops import blend
from tensorflow_addons.image.iou_ops import iou
from tensorflow_addons.image.iou_ops import ciou
from tensorflow_addons.image.iou_ops import diou
from tensorflow_addons.image.iou_ops import giou
183 changes: 183 additions & 0 deletions tensorflow_addons/image/iou_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements IoUs."""

import tensorflow as tf
import math
import numpy as np
from typing import Union

# Annotation used for the box arguments of the functions in this module.
# NOTE(review): the alias lists tensors and scalar floats, but `_common_iou`
# unstacks 4 coordinates from the last axis, so a bare scalar would not be a
# valid box -- confirm whether array-like inputs should be listed here.
CompatibleFloatTensorLike = Union[tf.Tensor, float, np.float32, np.float64]


def _get_v(
    b1_height: CompatibleFloatTensorLike,
    b1_width: CompatibleFloatTensorLike,
    b2_height: CompatibleFloatTensorLike,
    b2_width: CompatibleFloatTensorLike,
) -> tf.Tensor:
    """Computes the aspect-ratio term `v` used by the CIoU mode.

    The forward value is

        v = 4 * ((atan(b1_width / b1_height) - atan(b2_width / b2_height)) / pi) ** 2

    where both ratios are evaluated with `tf.math.divide_no_nan`, so a zero
    height yields a ratio of 0 instead of inf/NaN.

    The value is wrapped in `tf.custom_gradient`: back-propagation uses the
    hand-written gradient below rather than automatic differentiation. Only
    `b2_height` and `b2_width` are inputs of the wrapped function;
    `b1_height` and `b1_width` are captured from the enclosing scope, and the
    gradient function returns no gradient for them.

    Args:
        b1_height: height of the first box.
        b1_width: width of the first box.
        b2_height: height of the second box.
        b2_width: width of the second box.

    Returns:
        The `v` term as a `Tensor`.
    """
    pi_squared = math.pi ** 2

    @tf.custom_gradient
    def _v_with_manual_grad(h, w):
        b1_angle = tf.atan(tf.math.divide_no_nan(b1_width, b1_height))
        b2_angle = tf.atan(tf.math.divide_no_nan(w, h))
        angle_gap = b1_angle - b2_angle
        value = 4 * ((angle_gap / math.pi) ** 2)

        def _manual_grad(upstream):
            # The hand-written gradient has no 1 / (w**2 + h**2) factor.
            # NOTE(review): presumably follows the simplified gradient from
            # the CIoU paper -- confirm the intended sign convention.
            grad_w = upstream * 8 * angle_gap * h / pi_squared
            grad_h = -upstream * 8 * angle_gap * w / pi_squared
            # One gradient per positional input, in (h, w) order.
            return [grad_h, grad_w]

        return value, _manual_grad

    return _v_with_manual_grad(b2_height, b2_width)


def _common_iou(
    b1: CompatibleFloatTensorLike, b2: CompatibleFloatTensorLike, mode: str = "iou"
) -> tf.Tensor:
    """Computes IoU, GIoU, DIoU or CIoU between two sets of boxes.

    The returned value is the overlap measure itself, not a loss.

    Args:
        b1: bounding box(es). The last axis must hold the 4 coordinates
            encoded as [y_min, x_min, y_max, x_max]. Non-floating inputs are
            cast to float32.
        b2: the other bounding box(es), encoded like `b1`. It is cast to the
            dtype of `b1`. All arithmetic is element-wise, so the leading
            dimensions of `b1` and `b2` only need to broadcast.
        mode: one of ['iou', 'ciou', 'diou', 'giou'], selecting which
            measure is returned.

    Returns:
        A float `Tensor` with the selected measure for every box pair. The
        result is passed through `tf.squeeze`, so every dimension of size 1
        is removed.

    Raises:
        ValueError: if `mode` is not one of the supported values.
    """
    # Fail fast instead of building the whole computation first.
    if mode not in ("iou", "ciou", "diou", "giou"):
        raise ValueError(
            "Value of mode should be one of ['iou','giou','ciou','diou']"
        )

    b1 = tf.convert_to_tensor(b1)
    if not b1.dtype.is_floating:
        b1 = tf.cast(b1, tf.float32)
    b2 = tf.cast(b2, b1.dtype)
    zero = tf.convert_to_tensor(0.0, b1.dtype)

    b1_ymin, b1_xmin, b1_ymax, b1_xmax = tf.unstack(b1, 4, axis=-1)
    b2_ymin, b2_xmin, b2_ymax, b2_xmax = tf.unstack(b2, 4, axis=-1)

    # Widths/heights are clamped at zero so inverted boxes count as empty.
    b1_width = tf.maximum(zero, b1_xmax - b1_xmin)
    b1_height = tf.maximum(zero, b1_ymax - b1_ymin)
    b2_width = tf.maximum(zero, b2_xmax - b2_xmin)
    b2_height = tf.maximum(zero, b2_ymax - b2_ymin)
    b1_area = b1_width * b1_height
    b2_area = b2_width * b2_height

    intersect_ymin = tf.maximum(b1_ymin, b2_ymin)
    intersect_xmin = tf.maximum(b1_xmin, b2_xmin)
    intersect_ymax = tf.minimum(b1_ymax, b2_ymax)
    intersect_xmax = tf.minimum(b1_xmax, b2_xmax)
    intersect_width = tf.maximum(zero, intersect_xmax - intersect_xmin)
    intersect_height = tf.maximum(zero, intersect_ymax - intersect_ymin)
    intersect_area = intersect_width * intersect_height

    union_area = b1_area + b2_area - intersect_area
    iou = tf.math.divide_no_nan(intersect_area, union_area)
    if mode == "iou":
        return tf.squeeze(iou)

    # Smallest axis-aligned box containing both inputs; shared by the
    # remaining modes.
    enclose_ymin = tf.minimum(b1_ymin, b2_ymin)
    enclose_xmin = tf.minimum(b1_xmin, b2_xmin)
    enclose_ymax = tf.maximum(b1_ymax, b2_ymax)
    enclose_xmax = tf.maximum(b1_xmax, b2_xmax)

    if mode == "giou":
        enclose_width = tf.maximum(zero, enclose_xmax - enclose_xmin)
        enclose_height = tf.maximum(zero, enclose_ymax - enclose_ymin)
        enclose_area = enclose_width * enclose_height
        giou = iou - tf.math.divide_no_nan(
            (enclose_area - union_area), enclose_area
        )
        return tf.squeeze(giou)

    # DIoU / CIoU: penalise the squared centre distance, normalised by the
    # squared diagonal of the enclosing box. The squares are computed
    # directly (no norm followed by ** 2) because the gradient of a norm is
    # NaN at a zero vector, e.g. when both box centres coincide.
    center_dy = (b2_ymin + b2_ymax) / 2 - (b1_ymin + b1_ymax) / 2
    center_dx = (b2_xmin + b2_xmax) / 2 - (b1_xmin + b1_xmax) / 2
    center_dist_sq = tf.square(center_dy) + tf.square(center_dx)
    diag_sq = tf.square(enclose_ymax - enclose_ymin) + tf.square(
        enclose_xmax - enclose_xmin
    )
    # divide_no_nan keeps the penalty at 0 when the enclosing box is a
    # single point instead of producing NaN.
    diou = iou - tf.math.divide_no_nan(center_dist_sq, diag_sq)
    if mode == "diou":
        return tf.squeeze(diou)

    # CIoU additionally penalises aspect-ratio mismatch.
    v = _get_v(b1_height, b1_width, b2_height, b2_width)
    alpha = tf.math.divide_no_nan(v, ((1 - iou) + v))
    return tf.squeeze(diou - alpha * v)


def iou(b1: CompatibleFloatTensorLike, b2: CompatibleFloatTensorLike) -> tf.Tensor:
    """Computes the intersection over union (IoU) of two sets of boxes.

    Thin wrapper that calls `_common_iou` in "iou" mode.

    Args:
        b1: bounding box(es), with the last axis encoded as
            [y_min, x_min, y_max, x_max].
        b2: the other bounding box(es), using the same encoding.

    Returns:
        A float `Tensor` with the IoU value of each box pair.
    """
    return _common_iou(b1, b2, mode="iou")


def ciou(b1: CompatibleFloatTensorLike, b2: CompatibleFloatTensorLike) -> tf.Tensor:
    """Computes the complete IoU (CIoU) of two sets of boxes.

    Thin wrapper that calls `_common_iou` in "ciou" mode, which subtracts a
    centre-distance penalty and an aspect-ratio penalty from the plain IoU.

    Args:
        b1: bounding box(es), with the last axis encoded as
            [y_min, x_min, y_max, x_max].
        b2: the other bounding box(es), using the same encoding.

    Returns:
        A float `Tensor` with the CIoU value of each box pair.
    """
    return _common_iou(b1, b2, mode="ciou")


def diou(b1: CompatibleFloatTensorLike, b2: CompatibleFloatTensorLike) -> tf.Tensor:
    """Computes the distance IoU (DIoU) of two sets of boxes.

    Thin wrapper that calls `_common_iou` in "diou" mode, which subtracts a
    normalised centre-distance penalty from the plain IoU.

    Args:
        b1: bounding box(es), with the last axis encoded as
            [y_min, x_min, y_max, x_max].
        b2: the other bounding box(es), using the same encoding.

    Returns:
        A float `Tensor` with the DIoU value of each box pair.
    """
    return _common_iou(b1, b2, mode="diou")


def giou(b1: CompatibleFloatTensorLike, b2: CompatibleFloatTensorLike) -> tf.Tensor:
    """Computes the generalized IoU (GIoU) of two sets of boxes.

    Thin wrapper that calls `_common_iou` in "giou" mode, which subtracts
    from the plain IoU the fraction of the enclosing box that is not covered
    by the union of the two boxes.

    Args:
        b1: bounding box(es), with the last axis encoded as
            [y_min, x_min, y_max, x_max].
        b2: the other bounding box(es), using the same encoding.

    Returns:
        A float `Tensor` with the GIoU value of each box pair.
    """
    return _common_iou(b1, b2, mode="giou")
87 changes: 87 additions & 0 deletions tensorflow_addons/image/iou_ops_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for IoU losses."""

from absl.testing import parameterized

import numpy as np
import tensorflow as tf
from tensorflow_addons.utils import test_utils
from tensorflow_addons.image import iou, ciou, diou, giou


@test_utils.run_all_in_graph_and_eager_modes
class IoUTest(tf.test.TestCase, parameterized.TestCase):
    """Checks `iou`, `ciou`, `diou` and `giou` against hand-computed values.

    Every measure is computed independently per box pair, so the pair
    ([4, 3, 7, 5], [3, 4, 6, 8]) must yield the same numbers in all three
    tests regardless of how the inputs are batched:
    IoU 0.125, CIoU ~0.000687, DIoU 5/41 below IoU (~0.003049), GIoU -0.075.
    """

    @parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
    def test_ious_loss(self, dtype):
        """Two box pairs supplied as matching [2, 4] batches."""
        boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]], dtype=dtype)
        boxes2 = tf.constant(
            [[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]], dtype=dtype
        )
        losses = [iou, ciou, diou, giou]
        expected_results = [
            tf.constant(expected_result, dtype=dtype)
            for expected_result in [
                [0.125, 0.0],
                # DIoU minus the aspect-ratio penalty alpha * v.
                [0.0006869475038519, -0.6415314339487965],
                # IoU - d^2 / c^2: 0.125 - 5/41 and 0 - 113/181.
                [0.0030487804878049, -0.6243093922651934],
                # IoU - (C - U) / C: 0.125 - 4/20 and 0 - 84/90.
                [-0.075, -0.9333333333333333],
            ]
        ]
        for iou_loss_imp, expected_result in zip(losses, expected_results):
            with self.subTest(op=iou_loss_imp.__name__):
                loss = iou_loss_imp(boxes1, boxes2)
                self.assertAllCloseAccordingToType(loss, expected_result)

    @parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
    def test_different_shapes(self, dtype):
        """Inputs of shape [2, 1, 4] and [1, 1, 4] that must broadcast."""
        boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]], dtype=dtype)
        boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0]], dtype=dtype)
        expand_boxes1 = tf.expand_dims(boxes1, -2)
        expand_boxes2 = tf.expand_dims(boxes2, 0)
        losses = [iou, ciou, diou, giou]
        expected_results = [
            tf.constant(expected_result, dtype=dtype)
            for expected_result in [
                [0.125, 0.0625],
                # DIoU minus the aspect-ratio penalty alpha * v.
                [0.0006869475038519, -0.1202268105964954],
                # IoU - d^2 / c^2: 0.125 - 5/41 and 0.0625 - 9.25/65.
                [0.0030487804878049, -0.0798076923076923],
                # IoU - (C - U) / C: 0.125 - 4/20 and 0.0625 - 12/28.
                [-0.075, -0.3660714285714286],
            ]
        ]
        for iou_loss_imp, expected_result in zip(losses, expected_results):
            with self.subTest(op=iou_loss_imp.__name__):
                loss = iou_loss_imp(expand_boxes1, expand_boxes2)
                self.assertAllCloseAccordingToType(loss, expected_result)

    @parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
    def test_one_bbox(self, dtype):
        """A single unbatched pair of boxes, each of shape [4]."""
        boxes1 = tf.constant([4.0, 3.0, 7.0, 5.0], dtype=dtype)
        boxes2 = tf.constant([3.0, 4.0, 6.0, 8.0], dtype=dtype)
        losses = [iou, ciou, diou, giou]
        expected_results = [
            tf.constant(expected_result, dtype=dtype)
            for expected_result in [0.125, 0.000686947503852, 0.0030487804878, -0.075]
        ]
        for iou_loss_imp, expected_result in zip(losses, expected_results):
            with self.subTest(op=iou_loss_imp.__name__):
                loss = iou_loss_imp(boxes1, boxes2)
                self.assertAllCloseAccordingToType(loss, expected_result)


# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == "__main__":
    tf.test.main()
9 changes: 5 additions & 4 deletions tensorflow_addons/losses/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ py_library(
"__init__.py",
"contrastive.py",
"focal_loss.py",
"giou_loss.py",
"iou_loss.py",
"lifted.py",
"metric_learning.py",
"npairs.py",
Expand All @@ -18,6 +18,7 @@ py_library(
],
deps = [
"//tensorflow_addons/activations",
"//tensorflow_addons/image",
"//tensorflow_addons/utils",
],
)
Expand Down Expand Up @@ -47,12 +48,12 @@ py_test(
)

py_test(
name = "giou_loss_test",
name = "iou_loss_test",
size = "small",
srcs = [
"giou_loss_test.py",
"iou_loss_test.py",
],
main = "giou_loss_test.py",
main = "iou_loss_test.py",
deps = [
":losses",
],
Expand Down
5 changes: 4 additions & 1 deletion tensorflow_addons/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
sigmoid_focal_crossentropy,
SigmoidFocalCrossEntropy,
)
from tensorflow_addons.losses.giou_loss import giou_loss, GIoULoss
from tensorflow_addons.losses.iou_loss import iou_loss, IoULoss
from tensorflow_addons.losses.iou_loss import ciou_loss, CIoULoss
from tensorflow_addons.losses.iou_loss import diou_loss, DIoULoss
from tensorflow_addons.losses.iou_loss import giou_loss, GIoULoss
from tensorflow_addons.losses.lifted import lifted_struct_loss, LiftedStructLoss
from tensorflow_addons.losses.sparsemax_loss import sparsemax_loss, SparsemaxLoss
from tensorflow_addons.losses.triplet import (
Expand Down
Loading