Skip to content

Commit

Permalink
Fix converters (keras-team#1596)
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 authored Mar 28, 2023
1 parent f16d41f commit b17f823
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
7 changes: 3 additions & 4 deletions keras_cv/bounding_box/converters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The KerasCV Authors
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -511,9 +511,8 @@ def _image_shape(images, image_shape, boxes):
width = tf.reshape(
tf.reduce_max(images.row_lengths(axis=2), 1), (-1, 1)
)
if isinstance(boxes, tf.RaggedTensor):
height = tf.expand_dims(height, axis=-1)
width = tf.expand_dims(width, axis=-1)
height = tf.expand_dims(height, axis=-1)
width = tf.expand_dims(width, axis=-1)
else:
height, width = image_shape[0], image_shape[1]
return tf.cast(height, boxes.dtype), tf.cast(width, boxes.dtype)
28 changes: 25 additions & 3 deletions keras_cv/bounding_box/converters_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The KerasCV Authors
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -58,6 +58,8 @@

images = tf.ones([2, 1000, 1000, 3])

ragged_classes = tf.ragged.constant([[0], [0]], dtype=tf.float32)

boxes = {
"xyxy": xyxy_box,
"center_xywh": center_xywh_box,
Expand Down Expand Up @@ -106,8 +108,8 @@ def test_converters(self, source, target):

@parameterized.named_parameters(*test_image_ragged)
def test_converters_ragged_images(self, source, target):
source_box = boxes_ragged_images[source]
target_box = boxes_ragged_images[target]
source_box = _raggify(boxes_ragged_images[source])
target_box = _raggify(boxes_ragged_images[target])
self.assertAllClose(
bounding_box.convert_format(
source_box, source=source, target=target, images=ragged_images
Expand Down Expand Up @@ -190,6 +192,26 @@ def test_ragged_bounding_box_with_image_shape(self, source, target):
target_box,
)

@parameterized.named_parameters(*test_image_ragged)
def test_dense_bounding_box_with_ragged_images(self, source, target):
source_box = _raggify(boxes_ragged_images[source])
target_box = _raggify(boxes_ragged_images[target])
source_bounding_boxes = {"boxes": source_box, "classes": ragged_classes}
source_bounding_boxes = bounding_box.to_dense(source_bounding_boxes)

result_bounding_boxes = bounding_box.convert_format(
source_bounding_boxes,
source=source,
target=target,
images=ragged_images,
)
result_bounding_boxes = bounding_box.to_ragged(result_bounding_boxes)

self.assertAllClose(
result_bounding_boxes["boxes"],
target_box,
)


def _raggify(tensor):
tensor = tf.squeeze(tensor, axis=0)
Expand Down

0 comments on commit b17f823

Please sign in to comment.