From 4dfa66f94d23221652b8f8dab7433becab941689 Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Thu, 26 Sep 2024 12:36:43 -0700 Subject: [PATCH] Port Mask R-CNN to Keras3 (#2483) * Port Faster R-CNN to Keras3 * Port Mask R-CNN to Keras3 * added the processing of mask predictions * "Revert changes "# * added faster_rcnn as backbone for mask_rcnn * add mask predictions in faster_rcnn * remove multiple variable declaration in roi_sampler * removing changes to nms.py and roi_sampler.py * add newline at end to revert * removed extraneous whitespace * removing changes to fasterrcnn.py * Move files to maskrcnn folder and addressed all the required changes * Restructuring and Backbone implementation changes * address format issues * adding test cases * adding maskrcnn into workflow * Fix order of decorators and jax integer dtype error * Fix format * Fix tests for GPU runs * Revert keras version to 3.3.3 in build system * Avoid TimeDistributed layers to fix for keras 3.3.3 and Acknowledge randomness for test --- .github/workflows/actions.yml | 1 + .kokoro/github/ubuntu/gpu/build.sh | 2 + keras_cv/api/models/__init__.py | 2 + keras_cv/api/models/mask_rcnn/__init__.py | 7 + .../api/models/object_detection/__init__.py | 1 + .../object_detection/mask_rcnn/__init__.py | 17 + .../object_detection/mask_rcnn/mask_head.py | 117 +++ .../mask_rcnn/mask_head_test.py | 56 + .../object_detection/mask_rcnn/mask_rcnn.py | 958 ++++++++++++++++++ .../mask_rcnn/mask_rcnn_test.py | 441 ++++++++ .../mask_rcnn/non_max_suppression.py | 565 +++++++++++ .../mask_rcnn/non_max_suppression_test.py | 74 ++ .../object_detection/mask_rcnn/roi_sampler.py | 320 ++++++ .../mask_rcnn/roi_sampler_test.py | 313 ++++++ 14 files changed, 2874 insertions(+) create mode 100644 keras_cv/api/models/mask_rcnn/__init__.py create mode 100644 keras_cv/src/models/object_detection/mask_rcnn/__init__.py create mode 100644 keras_cv/src/models/object_detection/mask_rcnn/mask_head.py create mode 100644 keras_cv/src/models/object_detection/mask_rcnn/mask_head_test.py create mode 100644 keras_cv/src/models/object_detection/mask_rcnn/mask_rcnn.py create mode 100644 keras_cv/src/models/object_detection/mask_rcnn/mask_rcnn_test.py create mode 100644 keras_cv/src/models/object_detection/mask_rcnn/non_max_suppression.py create mode 100644 keras_cv/src/models/object_detection/mask_rcnn/non_max_suppression_test.py create mode 100644 keras_cv/src/models/object_detection/mask_rcnn/roi_sampler.py create mode 100644 keras_cv/src/models/object_detection/mask_rcnn/roi_sampler_test.py diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index e274eb6a34..12fbf5a9df 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -95,6 +95,7 @@ jobs: keras_cv/src/models/object_detection/retinanet \ keras_cv/src/models/object_detection/yolo_v8 \ keras_cv/src/models/object_detection/faster_rcnn \ + keras_cv/src/models/object_detection/mask_rcnn \ keras_cv/src/models/object_detection_3d \ keras_cv/src/models/segmentation \ --durations 0 diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh index 00c442edc8..38bdd1f825 100644 --- a/.kokoro/github/ubuntu/gpu/build.sh +++ b/.kokoro/github/ubuntu/gpu/build.sh @@ -71,6 +71,7 @@ then keras_cv/src/models/object_detection/retinanet \ keras_cv/src/models/object_detection/yolo_v8 \ keras_cv/src/models/object_detection/faster_rcnn \ + keras_cv/src/models/object_detection/mask_rcnn \ keras_cv/src/models/object_detection_3d \ 
keras_cv/src/models/segmentation \ keras_cv/src/models/feature_extractor/clip \ @@ -88,6 +89,7 @@ else keras_cv/src/models/object_detection/retinanet \ keras_cv/src/models/object_detection/yolo_v8 \ keras_cv/src/models/object_detection/faster_rcnn \ + keras_cv/src/models/object_detection/mask_rcnn \ keras_cv/src/models/object_detection_3d \ keras_cv/src/models/segmentation \ keras_cv/src/models/feature_extractor/clip \ diff --git a/keras_cv/api/models/__init__.py b/keras_cv/api/models/__init__.py index 6276b685f3..72e565bc05 100644 --- a/keras_cv/api/models/__init__.py +++ b/keras_cv/api/models/__init__.py @@ -7,6 +7,7 @@ from keras_cv.api.models import classification from keras_cv.api.models import faster_rcnn from keras_cv.api.models import feature_extractor +from keras_cv.api.models import mask_rcnn from keras_cv.api.models import object_detection from keras_cv.api.models import retinanet from keras_cv.api.models import segmentation @@ -209,6 +210,7 @@ from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import ( FasterRCNN, ) +from keras_cv.src.models.object_detection.mask_rcnn.mask_rcnn import MaskRCNN from keras_cv.src.models.object_detection.retinanet.retinanet import RetinaNet from keras_cv.src.models.object_detection.yolo_v8.yolo_v8_backbone import ( YOLOV8Backbone, diff --git a/keras_cv/api/models/mask_rcnn/__init__.py b/keras_cv/api/models/mask_rcnn/__init__.py new file mode 100644 index 0000000000..e7a4108b07 --- /dev/null +++ b/keras_cv/api/models/mask_rcnn/__init__.py @@ -0,0 +1,7 @@ +"""DO NOT EDIT. + +This file was autogenerated. Do not edit it by hand, +since your modifications would be overwritten. +""" + +from keras_cv.src.models.object_detection.mask_rcnn.mask_head import MaskHead diff --git a/keras_cv/api/models/object_detection/__init__.py b/keras_cv/api/models/object_detection/__init__.py index c49389c0b4..0bbde8bd47 100644 --- a/keras_cv/api/models/object_detection/__init__.py +++ b/keras_cv/api/models/object_detection/__init__.py @@ -7,6 +7,7 @@ from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import ( FasterRCNN, ) +from keras_cv.src.models.object_detection.mask_rcnn.mask_rcnn import MaskRCNN from keras_cv.src.models.object_detection.retinanet.retinanet import RetinaNet from keras_cv.src.models.object_detection.yolo_v8.yolo_v8_detector import ( YOLOV8Detector, diff --git a/keras_cv/src/models/object_detection/mask_rcnn/__init__.py b/keras_cv/src/models/object_detection/mask_rcnn/__init__.py new file mode 100644 index 0000000000..adcab1399d --- /dev/null +++ b/keras_cv/src/models/object_detection/mask_rcnn/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from keras_cv.src.models.object_detection.mask_rcnn.mask_head import MaskHead +from keras_cv.src.models.object_detection.mask_rcnn.mask_rcnn import MaskRCNN diff --git a/keras_cv/src/models/object_detection/mask_rcnn/mask_head.py b/keras_cv/src/models/object_detection/mask_rcnn/mask_head.py new file mode 100644 index 0000000000..9aec892611 --- /dev/null +++ b/keras_cv/src/models/object_detection/mask_rcnn/mask_head.py @@ -0,0 +1,117 @@ +# Copyright 2024 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_cv.src.api_export import keras_cv_export +from keras_cv.src.backend import keras + + +@keras_cv_export( + "keras_cv.models.mask_rcnn.MaskHead", + package="keras_cv.models.mask_rcnn", +) +class MaskHead(keras.layers.Layer): + """A Keras layer implementing the R-CNN Mask Head. + + The architecture is adapted from Matterport's Mask R-CNN implementation + https://github.com/matterport/Mask_RCNN/blob/master/mrcnn/model.py. + + Args: + num_classes: The number of object classes that are being detected, + excluding the background class. + stackwise_num_conv_filters: (Optional) a list of integers specifying + the number of filters for each convolutional layer. Defaults + to [256, 256]. + num_deconv_filters: (Optional) the number of filters to use in the + upsampling convolutional layer. Defaults to 256.
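+
+    Example:
+
+    A minimal usage sketch (the sizes below are arbitrary); the input is
+    a batch of ROI-aligned feature maps of shape
+    `(batch_size, num_rois, height, width, channels)`:
+
+    ```python
+    mask_head = keras_cv.models.mask_rcnn.MaskHead(num_classes=20)
+    roi_features = keras.ops.ones((2, 64, 7, 7, 256))
+    # The head upsamples 2x and adds a background channel, so the
+    # output has shape (2, 64, 14, 14, 21).
+    masks = mask_head(roi_features)
+    ```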
+ """ + + def __init__( + self, + num_classes, + stackwise_num_conv_filters=[256, 256], + num_deconv_filters=256, + **kwargs, + ): + super().__init__(**kwargs) + self.num_classes = num_classes + self.stackwise_num_conv_filters = stackwise_num_conv_filters + self.num_deconv_filters = num_deconv_filters + self.layers = [] + for num_filters in stackwise_num_conv_filters: + conv = keras.layers.Conv2D( + filters=num_filters, + kernel_size=3, + padding="same", + ) + batchnorm = keras.layers.BatchNormalization() + activation = keras.layers.Activation("relu") + self.layers.extend([conv, batchnorm, activation]) + + self.deconv = keras.layers.Conv2DTranspose( + num_deconv_filters, + kernel_size=2, + strides=2, + activation="relu", + padding="valid", + ) + # we do not use a final sigmoid activation, since we use + # from_logits=True during training + self.segmentation_mask_output = keras.layers.Conv2D( + num_classes + 1, + kernel_size=1, + strides=1, + activation="linear", + ) + + def call(self, feature_map, training=False): + # reshape batch and ROI axes into one axis to obtain a suitable + # shape for conv layers + num_rois = keras.ops.shape(feature_map)[1] + x = keras.ops.reshape(feature_map, (-1, *feature_map.shape[2:])) + for layer in self.layers: + x = layer(x, training=training) + x = self.deconv(x) + segmentation_mask = self.segmentation_mask_output(x) + segmentation_mask = keras.ops.reshape( + segmentation_mask, (-1, num_rois, *segmentation_mask.shape[1:]) + ) + return segmentation_mask + + def build(self, input_shape): + if input_shape[0] is None or input_shape[1] is None: + intermediate_shape = (None, *input_shape[2:]) + else: + intermediate_shape = ( + input_shape[0] * input_shape[1], + *input_shape[2:], + ) + for idx, num_filters in enumerate(self.stackwise_num_conv_filters): + self.layers[idx * 3].build(intermediate_shape) + intermediate_shape = tuple(intermediate_shape[:-1]) + (num_filters,) + self.layers[idx * 3 + 1].build(intermediate_shape) + self.deconv.build(intermediate_shape) + intermediate_shape = tuple(intermediate_shape[:-3]) + ( + intermediate_shape[-3] * 2, + intermediate_shape[-2] * 2, + self.num_deconv_filters, + ) + self.segmentation_mask_output.build(intermediate_shape) + self.built = True + + def get_config(self): + config = super().get_config() + config["num_classes"] = self.num_classes + config["stackwise_num_conv_filters"] = self.stackwise_num_conv_filters + config["num_deconv_filters"] = self.num_deconv_filters + return config diff --git a/keras_cv/src/models/object_detection/mask_rcnn/mask_head_test.py b/keras_cv/src/models/object_detection/mask_rcnn/mask_head_test.py new file mode 100644 index 0000000000..7d804a81a2 --- /dev/null +++ b/keras_cv/src/models/object_detection/mask_rcnn/mask_head_test.py @@ -0,0 +1,56 @@ +# Copyright 2024 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from absl.testing import parameterized + +from keras_cv.src.backend import ops +from keras_cv.src.backend.config import keras_3 +from keras_cv.src.models.object_detection.mask_rcnn import MaskHead +from keras_cv.src.tests.test_case import TestCase + + +class MaskHeadTest(TestCase): + @parameterized.parameters( + (2, 256, 20, 7, 256), + (1, 512, 80, 14, 512), + ) + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_mask_head_output_shapes( + self, + batch_size, + num_rois, + num_classes, + roi_align_target_size, + num_filters, + ): + layer = MaskHead(num_classes) + + inputs = ops.ones( + shape=( + batch_size, + num_rois, + roi_align_target_size, + roi_align_target_size, + num_filters, + ) + ) + outputs = layer(inputs) + + mask_size = roi_align_target_size * 2 + + self.assertEqual( + (batch_size, num_rois, mask_size, mask_size, num_classes + 1), + outputs.shape, + ) diff --git a/keras_cv/src/models/object_detection/mask_rcnn/mask_rcnn.py b/keras_cv/src/models/object_detection/mask_rcnn/mask_rcnn.py new file mode 100644 index 0000000000..9c67367493 --- /dev/null +++ b/keras_cv/src/models/object_detection/mask_rcnn/mask_rcnn.py @@ -0,0 +1,958 @@ +# Copyright 2024 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import tree + +from keras_cv.src import losses +from keras_cv.src.api_export import keras_cv_export +from keras_cv.src.backend import keras +from keras_cv.src.backend import ops +from keras_cv.src.bounding_box import convert_format +from keras_cv.src.bounding_box.converters import _decode_deltas_to_boxes +from keras_cv.src.bounding_box.utils import _clip_boxes +from keras_cv.src.layers.object_detection.anchor_generator import ( + AnchorGenerator, +) +from keras_cv.src.layers.object_detection.box_matcher import BoxMatcher +from keras_cv.src.layers.object_detection.roi_align import ROIAligner +from keras_cv.src.layers.object_detection.roi_generator import ROIGenerator +from keras_cv.src.layers.object_detection.rpn_label_encoder import ( + RpnLabelEncoder, +) +from keras_cv.src.models.object_detection.faster_rcnn import FeaturePyramid +from keras_cv.src.models.object_detection.faster_rcnn import RCNNHead +from keras_cv.src.models.object_detection.faster_rcnn import RPNHead +from keras_cv.src.models.object_detection.mask_rcnn.mask_head import MaskHead +from keras_cv.src.models.object_detection.mask_rcnn.non_max_suppression import ( + NonMaxSuppression, +) +from keras_cv.src.models.object_detection.mask_rcnn.roi_sampler import ( + ROISampler, +) +from keras_cv.src.models.task import Task +from keras_cv.src.utils.train import get_feature_extractor + +BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2] + + +@keras_cv_export( + [ + "keras_cv.models.MaskRCNN", + "keras_cv.models.object_detection.MaskRCNN", + ] +) +class MaskRCNN(Task): + """A Keras model implementing the Mask R-CNN architecture. + + This model is compatible with Keras 3 only.
Mask R-CNN is an extension of + Faster R-CNN, providing an additional mask head that predicts segmentation + masks. The constructor requires `num_classes`, `bounding_box_format`, + and a backbone. Optionally, a custom label encoder and prediction decoder + may be provided. + + Args: + backbone: `keras.Model`. If the default `feature_pyramid` is used, + must implement the `pyramid_level_inputs` property with keys "P2", + "P3", "P4", and "P5" and layer names as values. A sensible backbone + to use in many cases is + `keras_cv.models.ResNetBackbone.from_preset("resnet50_imagenet")`. + num_classes: the number of classes in your dataset excluding the + background class. Classes should be represented by integers in the + range [1, num_classes]. + bounding_box_format: The format of bounding boxes of the input dataset. + Refer + [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) + for more details on supported bounding box formats. + anchor_generator: (Optional) a `keras_cv.layers.AnchorGenerator`. If + provided, the anchor generator will be passed to both the + `label_encoder` and the `prediction_decoder`. Only to be used when + `label_encoder` and `prediction_decoder` are both `None`. + Defaults to an anchor generator built for the pyramid levels + `fpn_min_level` to `fpn_max_level + 1`, with stride `2**level` and + size `2**(3 + level)` for each level, using the given + `anchor_scales` and `anchor_aspect_ratios`. + anchor_scales: (Optional) list of anchor scales for + the default anchor generator. + anchor_aspect_ratios: (Optional) list of anchor aspect ratios for + the default anchor generator. + feature_pyramid: (Optional) A `keras.layers.Layer` that produces + a list of 4D feature maps (batch dimension included) + when called on the pyramid-level outputs of the `backbone`. + If not provided, the reference implementation from the paper will be used. + fpn_min_level: (Optional) the minimum level of the feature pyramid. + fpn_max_level: (Optional) the maximum level of the feature pyramid. + rpn_head: (Optional) A `keras.Layer` that performs regression and + classification (background or foreground) of the bounding boxes. + If not provided, a simple ConvNet with 3 layers will be used. + rpn_label_encoder_posistive_threshold: (Optional) the float threshold + above which an anchor is assigned as a positive match to a ground + truth box. + rpn_label_encoder_negative_threshold: (Optional) the float threshold + below which an anchor is assigned as a negative match to a ground + truth box. + rpn_label_encoder_samples_per_image: (Optional) for each image, the number of + positive and negative samples to generate. + rpn_label_encoder_positive_fraction: (Optional) the fraction of positive + samples among the total samples. + rcnn_head: (Optional) A `keras.Layer` that performs regression and + classification (final prediction) of the bounding boxes. + If not provided, a simple network with 2 dense layers, followed by + a box regression head and a classification head, will be used. + label_encoder: (Optional) a keras.Layer that accepts an image Tensor, a + bounding box Tensor and a bounding box class Tensor to its `call()` + method, and returns RPN training targets. By default, a + KerasCV standard `RpnLabelEncoder` is created and used. + Results of this object's `call()` method are passed as the `y_true` + argument to the `rpn_box_loss` and `rpn_classification_loss`.
+ prediction_decoder: (Optional) A `keras.layers.Layer` that is + responsible for transforming MaskRCNN predictions into usable + bounding box Tensors. If not provided, a default is used. The + default `prediction_decoder` layer is a + `NonMaxSuppression` layer, which uses + non-max suppression for box pruning. + num_max_decoder_detections: the maximum number of detections to consider + after NMS is applied. A large number may trigger significant + memory overhead. Defaults to 100. + """ # noqa: E501 + + def __init__( + self, + backbone, + num_classes, + bounding_box_format, + anchor_generator=None, + anchor_scales=[1], + anchor_aspect_ratios=[0.5, 1.0, 2.0], + feature_pyramid=None, + fpn_min_level=2, + fpn_max_level=5, + rpn_head=None, + rpn_filters=256, + rpn_kernel_size=3, + rpn_label_encoder_posistive_threshold=0.7, + rpn_label_encoder_negative_threshold=0.3, + rpn_label_encoder_samples_per_image=256, + rpn_label_encoder_positive_fraction=0.5, + rcnn_head=None, + mask_head=None, + num_sampled_rois=512, + label_encoder=None, + prediction_decoder=None, + num_max_decoder_detections=100, + *args, + **kwargs, + ): + # Backbone + extractor_levels = [ + f"P{level}" for level in range(fpn_min_level, fpn_max_level + 1) + ] + extractor_layer_names = [ + backbone.pyramid_level_inputs[i] for i in extractor_levels + ] + feature_extractor = get_feature_extractor( + backbone, extractor_layer_names, extractor_levels + ) + + # Feature Pyramid + feature_pyramid = feature_pyramid or FeaturePyramid( + min_level=fpn_min_level, max_level=fpn_max_level + ) + + # Anchors + anchor_generator = ( + anchor_generator + or MaskRCNN.default_anchor_generator( + fpn_min_level, + fpn_max_level + 1, + anchor_scales, + anchor_aspect_ratios, + "yxyx", + ) + ) + + # RPN Head + num_anchors_per_location = len(anchor_scales) * len( + anchor_aspect_ratios + ) + rpn_head = rpn_head or RPNHead( + num_anchors_per_location=num_anchors_per_location, + num_filters=rpn_filters, + kernel_size=rpn_kernel_size, + ) + + # RoI Generator + roi_generator = ROIGenerator( + bounding_box_format="yxyx", + nms_score_threshold_train=float("-inf"), + nms_score_threshold_test=float("-inf"), + nms_from_logits=True, + name="roi_generator", + ) + + # RoI Align + roi_aligner = ROIAligner(bounding_box_format="yxyx", name="roi_align") + + # R-CNN Head + rcnn_head = rcnn_head or RCNNHead(num_classes, name="rcnn_head") + + # Mask Head + mask_head = mask_head or MaskHead(num_classes, name="mask_head") + + # Begin construction of forward pass + image_shape = feature_extractor.input_shape[1:] + if None in image_shape: + raise ValueError( + "Found `None` in image_shape. To build anchors, `image_shape` " + "is required without any `None` values. Make sure to pass " + "`image_shape` to the backbone preset when constructing " + "the Mask R-CNN detector."
+ ) + + images = keras.layers.Input( + image_shape, + name="images", + ) + + # Forward through backbone + backbone_outputs = feature_extractor(images) + + # Forward through FPN decoder + feature_map = feature_pyramid(backbone_outputs) + + # [P2, P3, P4, P5, P6] -> ([BS, num_anchors, 4], [BS, num_anchors, 1]) + # Pass through RPN Head + rpn_boxes, rpn_scores = rpn_head(feature_map) + + # Reshape and Concatenate all the output boxes of all levels + for lvl in rpn_boxes: + rpn_boxes[lvl] = keras.layers.Reshape(target_shape=(-1, 4))( + rpn_boxes[lvl] + ) + + for lvl in rpn_scores: + rpn_scores[lvl] = keras.layers.Reshape(target_shape=(-1, 1))( + rpn_scores[lvl] + ) + rpn_cls_pred = keras.layers.Concatenate( + axis=1, name="rpn_classification" + )(tree.flatten(rpn_scores)) + rpn_box_pred = keras.layers.Concatenate(axis=1, name="rpn_box")( + tree.flatten(rpn_boxes) + ) + + anchors = anchor_generator(image_shape=image_shape) + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=anchors, + boxes_delta=rpn_boxes, + anchor_format="yxyx", + box_format="yxyx", + variance=BOX_VARIANCE, + ) + + rois, roi_scores = roi_generator(decoded_rpn_boxes, rpn_scores) + rois = _clip_boxes(rois, "yxyx", image_shape) + + feature_map = roi_aligner(features=feature_map, boxes=rois) + + # Pass final feature map to the mask head for + # segmentation mask prediction + segmask_pred = mask_head(feature_map=feature_map) + + # Reshape the feature map [BS, H*W*K] + feature_map = keras.layers.Reshape( + target_shape=( + rois.shape[1], + (roi_aligner.target_size**2) * rpn_head.num_filters, + ) + )(feature_map) + + # Pass final feature map to RCNN Head for predictions + box_pred, cls_pred = rcnn_head(feature_map=feature_map) + + box_pred = keras.layers.Concatenate(axis=1, name="box")([box_pred]) + cls_pred = keras.layers.Concatenate(axis=1, name="classification")( + [cls_pred] + ) + segmask_pred = keras.layers.Concatenate(axis=1, name="segmask")( + [segmask_pred] + ) + + inputs = {"images": images} + outputs = { + "rpn_box": rpn_box_pred, + "rpn_classification": rpn_cls_pred, + "box": box_pred, + "classification": cls_pred, + "segmask": segmask_pred, + } + + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + self.bounding_box_format = bounding_box_format + self.anchor_generator = anchor_generator + self.num_classes = num_classes + self.feature_extractor = feature_extractor + self.backbone = backbone + self.feature_pyramid = feature_pyramid + self.rpn_head = rpn_head + self.label_encoder = label_encoder or RpnLabelEncoder( + anchor_format="yxyx", + ground_truth_box_format=bounding_box_format, + positive_threshold=rpn_label_encoder_posistive_threshold, + negative_threshold=rpn_label_encoder_negative_threshold, + samples_per_image=rpn_label_encoder_samples_per_image, + positive_fraction=rpn_label_encoder_positive_fraction, + box_variance=BOX_VARIANCE, + ) + self.roi_generator = roi_generator + self.box_matcher = BoxMatcher( + thresholds=[0.0, 0.5], match_values=[-2, -1, 1] + ) + self.roi_sampler = ROISampler( + roi_bounding_box_format="yxyx", + gt_bounding_box_format=bounding_box_format, + roi_matcher=self.box_matcher, + num_sampled_rois=num_sampled_rois, + ) + + self.roi_aligner = roi_aligner + self.rcnn_head = rcnn_head + self.mask_head = mask_head + self._prediction_decoder = prediction_decoder or NonMaxSuppression( + bounding_box_format=bounding_box_format, + from_logits=False, + max_detections=num_max_decoder_detections, + ) + self.build(backbone.input_shape) + + def compile( + self, + 
rpn_box_loss=None, + rpn_classification_loss=None, + box_loss=None, + classification_loss=None, + mask_loss=None, + weight_decay=0.0001, + loss=None, + metrics=None, + **kwargs, + ): + # We define a `compute_loss` method, which performs all steps necessary + # for computing the `rpn_box`, `rpn_classification`, `box`, + # `classification` and `segmask` outputs. We will then arrange these + # outputs in a dictionary and specify the relevant losses in a + # dictionary to have Keras compute the loss for each of the outputs. + if loss is not None: + raise ValueError( + "`MaskRCNN` does not accept a `loss` to `compile()`. " + "Instead, please pass `rpn_box_loss`, `rpn_classification_loss`, " + "`box_loss`, `classification_loss` and `mask_loss`. " + "`loss` will be ignored during training." + ) + if ( + rpn_box_loss is None + or rpn_classification_loss is None + or box_loss is None + or classification_loss is None + or mask_loss is None + ): + raise ValueError( + "`MaskRCNN` expects all of `rpn_box_loss`, " + "`rpn_classification_loss`, " + "`box_loss`, `classification_loss` and " + "`mask_loss` to not be `None`." + ) + + rpn_box_loss = _parse_box_loss(rpn_box_loss) + rpn_classification_loss = _parse_rpn_classification_loss( + rpn_classification_loss + ) + + if hasattr(rpn_classification_loss, "from_logits"): + if not rpn_classification_loss.from_logits: + raise ValueError( + "MaskRCNN.compile() expects `from_logits` to be True for " + "`rpn_classification_loss`. Got " + "`rpn_classification_loss.from_logits=" + f"{rpn_classification_loss.from_logits}`" + ) + box_loss = _parse_box_loss(box_loss) + classification_loss = _parse_classification_loss(classification_loss) + mask_loss = _parse_mask_loss(mask_loss) + + if hasattr(classification_loss, "from_logits"): + if not classification_loss.from_logits: + raise ValueError( + "MaskRCNN.compile() expects `from_logits` to be True for " + "`classification_loss`. Got " + "`classification_loss.from_logits=" + f"{classification_loss.from_logits}`" + ) + if hasattr(mask_loss, "from_logits"): + if not mask_loss.from_logits: + raise ValueError( + "MaskRCNN.compile() expects `from_logits` to be True for " + "`mask_loss`. Got " + "`mask_loss.from_logits=" + f"{mask_loss.from_logits}`" + ) + if hasattr(box_loss, "bounding_box_format"): + if box_loss.bounding_box_format != self.bounding_box_format: + raise ValueError( + "Wrong `bounding_box_format` passed to `box_loss` in " + "`MaskRCNN.compile()`. Got " + "`box_loss.bounding_box_format=" + f"{box_loss.bounding_box_format}`, want " + "`box_loss.bounding_box_format=" + f"{self.bounding_box_format}`" + ) + + self.rpn_box_loss = rpn_box_loss + self.rpn_cls_loss = rpn_classification_loss + self.box_loss = box_loss + self.cls_loss = classification_loss + self.mask_loss = mask_loss + self.weight_decay = weight_decay + losses = { + "rpn_box": self.rpn_box_loss, + "rpn_classification": self.rpn_cls_loss, + "box": self.box_loss, + "classification": self.cls_loss, + "segmask": self.mask_loss, + } + self._has_user_metrics = metrics is not None and len(metrics) != 0 + self._user_metrics = metrics + super().compile(loss=losses, **kwargs) + + def compute_loss( + self, x, y, y_pred, sample_weight, training=True, **kwargs + ): + + # 1. Unpack the inputs + images = x + gt_boxes = y["boxes"] + if ops.ndim(y["classes"]) != 2: + raise ValueError( + "Expected 'classes' to be a Tensor of rank 2. " + f"Got y['classes'].shape={ops.shape(y['classes'])}."
+ ) + + gt_classes = y["classes"] + gt_classes = ops.expand_dims(gt_classes, axis=-1) + + gt_masks = y["segmask"] + + ####################################################################### + # Generate Anchors and Generate RPN Targets + ####################################################################### + local_batch = ops.shape(images)[0] + image_shape = ops.shape(images)[1:] + anchors = self.anchor_generator(image_shape=image_shape) + + # 2. Label with the anchors -- exclusive to compute_loss + ( + rpn_box_targets, + rpn_box_weights, + rpn_cls_targets, + rpn_cls_weights, + ) = self.label_encoder( + anchors_dict=ops.concatenate( + tree.flatten(anchors), + axis=0, + ), + gt_boxes=gt_boxes, + gt_classes=gt_classes, + ) + + # 3. Computing the weights + rpn_box_weights /= ( + self.label_encoder.samples_per_image * local_batch * 0.25 + ) + rpn_cls_weights /= self.label_encoder.samples_per_image * local_batch + + ####################################################################### + # Call Backbone, FPN and RPN Head + ####################################################################### + + backbone_outputs = self.feature_extractor(images) + feature_map = self.feature_pyramid(backbone_outputs) + rpn_boxes, rpn_scores = self.rpn_head(feature_map) + + for lvl in rpn_boxes: + rpn_boxes[lvl] = keras.layers.Reshape(target_shape=(-1, 4))( + rpn_boxes[lvl] + ) + + for lvl in rpn_scores: + rpn_scores[lvl] = keras.layers.Reshape(target_shape=(-1, 1))( + rpn_scores[lvl] + ) + + # [BS, num_anchors, 4], [BS, num_anchors, 1] + rpn_cls_pred = keras.layers.Concatenate( + axis=1, name="rpn_classification" + )(tree.flatten(rpn_scores)) + rpn_box_pred = keras.layers.Concatenate(axis=1, name="rpn_box")( + tree.flatten(rpn_boxes) + ) + + ####################################################################### + # Generate RoI's and RoI Sampling + ####################################################################### + + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=anchors, + boxes_delta=rpn_boxes, + anchor_format="yxyx", + box_format="yxyx", + variance=BOX_VARIANCE, + ) + + rois, _ = self.roi_generator( + decoded_rpn_boxes, rpn_scores, training=training + ) + rois = _clip_boxes(rois, "yxyx", image_shape) + + # 4. Stop gradient from flowing into the ROI + # -- exclusive to compute_loss + rois = ops.stop_gradient(rois) + + # 5. Sample the ROIS -- exclusive to compute_loss + ( + rois, + box_targets, + box_weights, + cls_targets, + cls_weights, + segmask_targets, + segmask_weights, + ) = self.roi_sampler(rois, gt_boxes, gt_classes, gt_masks) + + cls_targets = ops.squeeze(cls_targets, axis=-1) + cls_weights = ops.squeeze(cls_weights, axis=-1) + + # 6. 
Box and class weights -- exclusive to compute loss + box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25 + cls_weights /= self.roi_sampler.num_sampled_rois * local_batch + cls_targets_numeric = cls_targets + cls_targets = ops.one_hot(cls_targets, num_classes=self.num_classes + 1) + + ####################################################################### + # Call RoI Aligner and RCNN Head + ####################################################################### + + feature_map = self.roi_aligner(features=feature_map, boxes=rois) + + segmask_pred = self.mask_head(feature_map=feature_map) + # we only consider the mask prediction for the groundtruth class + segmask_pred = ops.reshape( + segmask_pred, (-1, *ops.shape(segmask_pred)[2:]) + ) + segmask_pred_ind = ops.reshape( + cls_targets_numeric, (ops.shape(segmask_pred)[0], 1, 1, -1) + ) + segmask_pred_ind = ops.cast(segmask_pred_ind, "int32") + segmask_pred = ops.take_along_axis( + segmask_pred, segmask_pred_ind, axis=-1 + ) + # flatten each ROI's segmask to perform averaging (instead of + # summation) over all pixels during loss computation + segmask_pred = ops.reshape( + segmask_pred, (local_batch, self.roi_sampler.num_sampled_rois, -1) + ) + segmask_targets = ops.reshape( + segmask_targets, (*segmask_targets.shape[:2], -1) + ) + + # [BS, H*W*K] + feature_map = ops.reshape( + feature_map, + newshape=ops.shape(rois)[:2] + (-1,), + ) + + # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] + box_pred, cls_pred = self.rcnn_head(feature_map=feature_map) + + y_true = { + "rpn_box": rpn_box_targets, + "rpn_classification": rpn_cls_targets, + "box": box_targets, + "classification": cls_targets, + "segmask": segmask_targets, + } + y_pred = { + "rpn_box": rpn_box_pred, + "rpn_classification": rpn_cls_pred, + "box": box_pred, + "classification": cls_pred, + "segmask": segmask_pred, + } + weights = { + "rpn_box": rpn_box_weights, + "rpn_classification": rpn_cls_weights, + "box": box_weights, + "classification": cls_weights, + "segmask": segmask_weights, + } + + return super().compute_loss( + x=x, y=y_true, y_pred=y_pred, sample_weight=weights, **kwargs + ) + + def train_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = unpack_input(data) + return super().train_step(*args, (x, y)) + + def test_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = unpack_input(data) + return super().test_step(*args, (x, y)) + + def predict_step(self, *args): + outputs = super().predict_step(*args) + if type(outputs) is tuple: + return self.decode_predictions(outputs[0], args[-1]), outputs[1] + else: + return self.decode_predictions(outputs, args[-1]) + + @property + def prediction_decoder(self): + return self._prediction_decoder + + @prediction_decoder.setter + def prediction_decoder(self, prediction_decoder): + if prediction_decoder.bounding_box_format != self.bounding_box_format: + raise ValueError( + "Expected `prediction_decoder` and MaskRCNN to " + "use the same `bounding_box_format`, but got " + "`prediction_decoder.bounding_box_format=" + f"{prediction_decoder.bounding_box_format}`, and " + "`self.bounding_box_format=" + f"{self.bounding_box_format}`." 
+ ) + self._prediction_decoder = prediction_decoder + self.make_predict_function(force=True) + self.make_train_function(force=True) + self.make_test_function(force=True) + + def decode_predictions(self, predictions, images): + image_shape = ops.shape(images)[1:] + anchors = self.anchor_generator(image_shape=image_shape) + rpn_boxes, rpn_scores = ( + predictions["rpn_box"], + predictions["rpn_classification"], + ) + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=ops.concatenate( + tree.flatten(anchors), + axis=0, + ), + boxes_delta=rpn_boxes, + anchor_format="yxyx", + box_format="yxyx", + variance=BOX_VARIANCE, + ) + + rois, _ = self.roi_generator( + decoded_rpn_boxes, rpn_scores, training=False + ) + rois = _clip_boxes(rois, "yxyx", image_shape) + box_pred, cls_pred = predictions["box"], predictions["classification"] + + # box_pred is on "center_yxhw" format, convert to target format. + box_pred = _decode_deltas_to_boxes( + anchors=rois, + boxes_delta=box_pred, + anchor_format=self.roi_aligner.bounding_box_format, + box_format=self.bounding_box_format, + variance=BOX_VARIANCE, + image_shape=image_shape, + ) + + box_pred = convert_format( + box_pred, + source=self.bounding_box_format, + target=self.prediction_decoder.bounding_box_format, + image_shape=image_shape, + ) + cls_pred = ops.softmax(cls_pred) + cls_pred = ops.slice( + cls_pred, + start_indices=[0, 0, 1], + shape=[cls_pred.shape[0], cls_pred.shape[1], cls_pred.shape[2] - 1], + ) + + y_pred = self.prediction_decoder( + box_pred, + cls_pred, + mask_prediction=predictions.get("segmask"), + image_shape=image_shape, + ) + + y_pred["classes"] = ops.where( + y_pred["classes"] == -1, -1, y_pred["classes"] + 1 + ) + + y_pred["boxes"] = convert_format( + y_pred["boxes"], + source=self.prediction_decoder.bounding_box_format, + target=self.bounding_box_format, + image_shape=image_shape, + ) + + segmask_pred = ops.sigmoid(y_pred["segmask"]) + y_pred["segmask"] = self.decode_segmentation_masks( + segmask_pred=segmask_pred, + class_pred=y_pred["classes"], + decoded_boxes=y_pred["boxes"], + bbox_format=self.bounding_box_format, + image_shape=image_shape, + ) + return y_pred + + def _resize_and_pad_mask( + self, segmask_pred, class_pred, decoded_boxes, image_shape + ): + num_rois = ops.shape(segmask_pred)[0] + image_height, image_width = image_shape[:2] + + # Initialize a list to store the padded masks + padded_masks_list = [] + + # Iterate over the batch and place the resized masks into the correct + # position + for i in range(num_rois): + bounding_box = ops.maximum(ops.cast(decoded_boxes[i], "int32"), 0) + bounding_box = ops.minimum( + bounding_box, [image_height, image_width] * 2 + ) + y1, x1, y2, x2 = ops.unstack(bounding_box) + box_height = y2 - y1 + box_width = x2 - x1 + + def do_resize(): + # Resize the mask to the size of the bounding box + resized_mask = ops.image.resize( + segmask_pred[i], size=(box_height, box_width) + ) + resized_mask = ops.squeeze(resized_mask, axis=-1) + + # Place the resized mask into the correct position + # in the final mask + padded_mask = ops.pad( + resized_mask, + ( + (y1, image_height - y1 - box_height), + (x1, image_width - x1 - box_width), + ), + ) + return padded_mask + + # Only consider bounding boxes for valid predictions + valid_boxes = ops.all( + ops.stack([class_pred[i] != -1, box_height > 0, box_width > 0]) + ) + padded_mask = ops.cond( + valid_boxes, + do_resize, + lambda: ops.zeros( + (image_height, image_width), dtype=segmask_pred.dtype + ), + ) + + # Append the padded mask to the list + 
padded_masks_list.append(padded_mask) + + # Stack the list of masks into a single tensor + final_masks = ops.max(padded_masks_list, axis=0) + + return final_masks + + def decode_segmentation_masks( + self, segmask_pred, class_pred, decoded_boxes, bbox_format, image_shape + ): + """Decode the predicted segmentation mask output, combining all + masks in one mask for each image.""" + + decoded_boxes = convert_format( + decoded_boxes, source=bbox_format, target="yxyx" + ) + # pick the mask prediction for the predicted class + segmask_pred = ops.take_along_axis( + segmask_pred, class_pred[:, :, None, None, None], axis=-1 + ) + + final_masks = [] + for i in range(segmask_pred.shape[0]): + # resize the mask according to the bounding box + image_masks = self._resize_and_pad_mask( + segmask_pred[i], class_pred[i], decoded_boxes[i], image_shape + ) + final_masks.append(image_masks) + + return ops.stack(final_masks, axis=0) + + def compute_metrics(self, x, y, y_pred, sample_weight): + metrics = {} + metrics.update(super().compute_metrics(x, {}, {}, sample_weight={})) + + if not self._has_user_metrics: + return metrics + + y_pred = self.decode_predictions(y_pred, x) + + for metric in self._user_metrics: + metric.update_state(y, y_pred, sample_weight=sample_weight) + + for metric in self._user_metrics: + result = metric.result() + if isinstance(result, dict): + metrics.update(result) + else: + metrics[metric.name] = result + return metrics + + @staticmethod + def default_anchor_generator( + min_level, max_level, scales, aspect_ratios, bounding_box_format + ): + strides = {f"P{i}": 2**i for i in range(min_level, max_level + 1)} + sizes = {f"P{i}": 2 ** (3 + i) for i in range(min_level, max_level + 1)} + return AnchorGenerator( + bounding_box_format=bounding_box_format, + sizes=sizes, + aspect_ratios=aspect_ratios, + scales=scales, + strides=strides, + clip_boxes=True, + name="anchor_generator", + ) + + def get_config(self): + return { + "num_classes": self.num_classes, + "bounding_box_format": self.bounding_box_format, + "backbone": keras.saving.serialize_keras_object(self.backbone), + "label_encoder": keras.saving.serialize_keras_object( + self.label_encoder + ), + "rpn_head": keras.saving.serialize_keras_object(self.rpn_head), + "prediction_decoder": self._prediction_decoder, + "rcnn_head": self.rcnn_head, + "mask_head": self.mask_head, + } + + @classmethod + def from_config(cls, config): + if "rpn_head" in config and isinstance(config["rpn_head"], dict): + config["rpn_head"] = keras.layers.deserialize(config["rpn_head"]) + if "label_encoder" in config and isinstance( + config["label_encoder"], dict + ): + config["label_encoder"] = keras.layers.deserialize( + config["label_encoder"] + ) + if "prediction_decoder" in config and isinstance( + config["prediction_decoder"], dict + ): + config["prediction_decoder"] = keras.layers.deserialize( + config["prediction_decoder"] + ) + if "rcnn_head" in config and isinstance(config["rcnn_head"], dict): + config["rcnn_head"] = keras.layers.deserialize(config["rcnn_head"]) + if "mask_head" in config and isinstance(config["mask_head"], dict): + config["mask_head"] = keras.layers.deserialize(config["mask_head"]) + + return super().from_config(config) + + +def _parse_box_loss(loss): + # support arbitrary callables + if not isinstance(loss, str): + return loss + + # case insensitive comparison + if loss.lower() == "smoothl1": + return losses.SmoothL1Loss(l1_cutoff=1.0, reduction="sum") + if loss.lower() == "huber": + return keras.losses.Huber(reduction="sum") + + raise 
ValueError( + "Expected `box_loss` to be either a Keras Loss, " + f"callable, or the string 'SmoothL1' or 'Huber'. Got loss={loss}." + ) + + +def _parse_rpn_classification_loss(loss): + # support arbitrary callables + if not isinstance(loss, str): + return loss + + if loss.lower() == "binarycrossentropy": + return keras.losses.BinaryCrossentropy( + reduction="sum", from_logits=True + ) + + raise ValueError( + f"Expected `rpn_classification_loss` to be either a BinaryCrossentropy" + f" loss, a callable, or the string 'BinaryCrossentropy'. Got loss={loss}." + ) + + +def _parse_classification_loss(loss): + # support arbitrary callables + if not isinstance(loss, str): + return loss + + # case insensitive comparison + if loss.lower() == "focal": + return losses.FocalLoss(reduction="sum", from_logits=True) + if loss.lower() == "categoricalcrossentropy": + return keras.losses.CategoricalCrossentropy( + reduction="sum", from_logits=True + ) + + raise ValueError( + f"Expected `classification_loss` to be either a Keras Loss, " + f"callable, or the string 'Focal' or 'CategoricalCrossentropy'. " + f"Got loss={loss}." + ) + + +def _parse_mask_loss(loss): + if not isinstance(loss, str): + # support arbitrary callables + return loss + + if loss.lower() == "binarycrossentropy": + return keras.losses.BinaryCrossentropy( + reduction="sum", from_logits=True + ) + + raise ValueError( + f"Expected `mask_loss` to be either a BinaryCrossentropy" + f" loss, a callable, or the string 'BinaryCrossentropy'. Got loss={loss}." + ) + + +def unpack_input(data): + if type(data) is dict: + return data["images"], data["bounding_boxes"] + else: + return data diff --git a/keras_cv/src/models/object_detection/mask_rcnn/mask_rcnn_test.py b/keras_cv/src/models/object_detection/mask_rcnn/mask_rcnn_test.py new file mode 100644 index 0000000000..f577b7e3ec --- /dev/null +++ b/keras_cv/src/models/object_detection/mask_rcnn/mask_rcnn_test.py @@ -0,0 +1,441 @@ +# Copyright 2024 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +import pytest +import tensorflow as tf + +import keras_cv +from keras_cv.src.backend import keras +from keras_cv.src.backend import ops +from keras_cv.src.backend.config import keras_3 +from keras_cv.src.models.object_detection.mask_rcnn import MaskRCNN +from keras_cv.src.tests.test_case import TestCase + + +def _create_bounding_box_segmask_dataset( + bounding_box_format, + image_shape=(512, 512, 3), + use_dictionary_box_format=False, +): + # Just about the easiest dataset you can have: all classes are 0, and + # every image has the same three boxes, given in rel_xywh coordinates.
+ # segmentation masks cover the entire bounding box of the respective object + xs = np.random.normal(size=(1,) + image_shape) + xs = np.tile(xs, [5, 1, 1, 1]) + + y_classes = np.zeros((5, 3), "float32") + + ys = np.array( + [ + [0.1, 0.1, 0.23, 0.23], + [0.67, 0.75, 0.23, 0.23], + [0.25, 0.25, 0.23, 0.23], + ], + "float32", + ) + + ys = np.expand_dims(ys, axis=0) + + ys_yxyx = ops.convert_to_numpy( + keras_cv.bounding_box.convert_format( + ys, + source="rel_xywh", + target="yxyx", + images=xs, + dtype="float32", + ) + ) + ys_yxyx = ys_yxyx.astype(int) + segmask = np.zeros((image_shape[0], image_shape[1]), dtype=np.uint8) + for object_idx, (obj_y1, obj_x1, obj_y2, obj_x2) in enumerate(ys_yxyx[0]): + segmask[obj_y1:obj_y2, obj_x1:obj_x2] = object_idx + 1 + segmask = np.expand_dims(segmask, axis=0) + + ys = np.tile(ys, [5, 1, 1]) + segmask = np.tile(segmask, [5, 1, 1]) + ys = ops.convert_to_numpy( + keras_cv.bounding_box.convert_format( + ys, + source="rel_xywh", + target=bounding_box_format, + images=xs, + dtype="float32", + ) + ) + num_dets = np.ones([5]) + + if use_dictionary_box_format: + return tf.data.Dataset.from_tensor_slices( + { + "images": xs, + "bounding_boxes": { + "boxes": ys, + "classes": y_classes, + "num_dets": num_dets, + "segmask": segmask, + }, + } + ).batch(5, drop_remainder=True) + else: + return xs, {"boxes": ys, "classes": y_classes, "segmask": segmask} + + +class MaskRCNNTest(TestCase): + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_mask_rcnn_construction(self): + mask_rcnn = MaskRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(32, 32, 3) + ), + num_sampled_rois=256, + ) + mask_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="CategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + mask_loss="BinaryCrossentropy", + ) + + @pytest.mark.extra_large() + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_mask_rcnn_call(self): + mask_rcnn = MaskRCNN( + num_classes=3, + bounding_box_format="xywh", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(32, 32, 3) + ), + num_sampled_rois=256, + ) + images = np.random.uniform(size=(1, 32, 32, 3)) + _ = mask_rcnn(images) + _ = mask_rcnn.predict(images) + + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_wrong_logits(self): + mask_rcnn = MaskRCNN( + num_classes=80, + bounding_box_format="xywh", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(32, 32, 3) + ), + num_sampled_rois=256, + ) + + with self.assertRaisesRegex( + ValueError, + "from_logits", + ): + mask_rcnn.compile( + optimizer=keras.optimizers.SGD(learning_rate=0.25), + box_loss=keras_cv.losses.SmoothL1Loss( + l1_cutoff=1.0, reduction="none" + ), + classification_loss=keras_cv.losses.FocalLoss( + from_logits=False, reduction="none" + ), + rpn_box_loss=keras_cv.losses.SmoothL1Loss( + l1_cutoff=1.0, reduction="none" + ), + rpn_classification_loss=keras_cv.losses.FocalLoss( + from_logits=False, reduction="none" + ), + mask_loss="BinaryCrossentropy", + ) + with self.assertRaisesRegex( + ValueError, + "from_logits", + ): + mask_rcnn.compile( + optimizer=keras.optimizers.SGD(learning_rate=0.25), + box_loss="Huber", + classification_loss="CategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + mask_loss=keras_cv.losses.FocalLoss( + from_logits=False, 
reduction="none" + ), + ) + + @pytest.mark.large() + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_weights_contained_in_trainable_variables(self): + bounding_box_format = "xyxy" + mask_rcnn = MaskRCNN( + num_classes=80, + bounding_box_format=bounding_box_format, + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(32, 32, 3) + ), + num_sampled_rois=256, + ) + mask_rcnn.backbone.trainable = False + mask_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="CategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + mask_loss="BinaryCrossentropy", + ) + xs, ys = _create_bounding_box_segmask_dataset( + bounding_box_format, image_shape=(32, 32, 3) + ) + + # call once + _ = mask_rcnn(xs) + self.assertEqual(len(mask_rcnn.trainable_variables), 42) + + @pytest.mark.extra_large # Fit is slow, so mark these large. + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_no_nans(self): + mask_rcnn = MaskRCNN( + num_classes=5, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(32, 32, 3) + ), + num_sampled_rois=16, + ) + mask_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="CategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + mask_loss="BinaryCrossentropy", + ) + + # only a -1 box + xs = np.ones((1, 32, 32, 3), "float32") + ys = { + "classes": np.array([[-1]], "float32"), + "boxes": np.array([[[0, 0, 0, 0]]], "float32"), + "segmask": np.zeros((1, 32, 32), dtype="float32"), + } + ds = tf.data.Dataset.from_tensor_slices((xs, ys)) + ds = ds.repeat(1) + ds = ds.batch(1, drop_remainder=True) + mask_rcnn.fit(ds, epochs=1) + + weights = mask_rcnn.get_weights() + for weight in weights: + self.assertFalse(ops.any(ops.isnan(weight))) + + @pytest.mark.extra_large # Fit is slow, so mark these large. 
+ @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_weights_change(self): + mask_rcnn = MaskRCNN( + num_classes=3, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(128, 128, 3) + ), + num_sampled_rois=16, + ) + mask_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="CategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + mask_loss="BinaryCrossentropy", + ) + + ds = _create_bounding_box_segmask_dataset( + "xyxy", image_shape=(128, 128, 3), use_dictionary_box_format=True + ) + + # call once + _ = mask_rcnn(ops.ones((1, 128, 128, 3))) + original_fpn_weights = mask_rcnn.feature_pyramid.get_weights() + original_rpn_head_weights = mask_rcnn.rpn_head.get_weights() + original_rcnn_head_weights = mask_rcnn.rcnn_head.get_weights() + original_mask_head_weights = mask_rcnn.mask_head.get_weights() + + mask_rcnn.fit(ds, epochs=1) + fpn_after_fit = mask_rcnn.feature_pyramid.get_weights() + rpn_head_after_fit_weights = mask_rcnn.rpn_head.get_weights() + rcnn_head_after_fit_weights = mask_rcnn.rcnn_head.get_weights() + mask_head_after_fit_weights = mask_rcnn.mask_head.get_weights() + + for w1, w2 in zip( + original_rcnn_head_weights, + rcnn_head_after_fit_weights, + ): + self.assertNotAllClose(w1, w2) + for w1, w2 in zip( + original_mask_head_weights, + mask_head_after_fit_weights, + ): + self.assertNotAllClose(w1, w2) + for w1, w2 in zip( + original_rpn_head_weights, rpn_head_after_fit_weights + ): + self.assertNotAllClose(w1, w2) + + for w1, w2 in zip(original_fpn_weights, fpn_after_fit): + self.assertNotAllClose(w1, w2) + + @pytest.mark.large # Saving is slow, so mark these large. + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_saved_model(self): + model = MaskRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(32, 32, 3) + ), + ) + input_batch = ops.ones(shape=(1, 32, 32, 3)) + model_output = model(input_batch) + save_path = os.path.join(self.get_temp_dir(), "mask_rcnn.keras") + model.save(save_path) + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, MaskRCNN) + + # Check that output matches. 
+ restored_output = restored_model(input_batch) + self.assertAllClose( + tf.nest.map_structure(ops.convert_to_numpy, model_output), + tf.nest.map_structure(ops.convert_to_numpy, restored_output), + ) + + @pytest.mark.large + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_mask_rcnn_infer(self): + model = MaskRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(128, 128, 3) + ), + ) + images = ops.ones((1, 128, 128, 3)) + outputs = model(images, training=False) + # 1000 proposals in inference + self.assertAllEqual([1, 1000, 81], outputs["classification"].shape) + self.assertAllEqual([1, 1000, 4], outputs["box"].shape) + self.assertAllEqual([1, 1000, 14, 14, 81], outputs["segmask"].shape) + + @pytest.mark.large + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_mask_rcnn_train(self): + model = MaskRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(128, 128, 3) + ), + ) + images = ops.ones((1, 128, 128, 3)) + outputs = model(images, training=True) + self.assertAllEqual([1, 1000, 81], outputs["classification"].shape) + self.assertAllEqual([1, 1000, 4], outputs["box"].shape) + self.assertAllEqual([1, 1000, 14, 14, 81], outputs["segmask"].shape) + + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_invalid_compile(self): + model = MaskRCNN( + num_classes=80, + bounding_box_format="yxyx", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(32, 32, 3) + ), + num_sampled_rois=256, + ) + with self.assertRaisesRegex(ValueError, "expects"): + model.compile(rpn_box_loss="binary_crossentropy") + with self.assertRaisesRegex(ValueError, "from_logits"): + model.compile( + box_loss="Huber", + classification_loss="CategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss=keras.losses.BinaryCrossentropy( + from_logits=False + ), + mask_loss="BinaryCrossentropy", + ) + + @pytest.mark.extra_large # Fit is slow, so mark these large. + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_mask_rcnn_with_dictionary_input_format(self): + mask_rcnn = MaskRCNN( + num_classes=3, + bounding_box_format="xywh", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(32, 32, 3) + ), + num_sampled_rois=16, + ) + + images, boxes = _create_bounding_box_segmask_dataset( + "xywh", image_shape=(32, 32, 3) + ) + dataset = tf.data.Dataset.from_tensor_slices( + {"images": images, "bounding_boxes": boxes} + ).batch(1, drop_remainder=True) + + mask_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="CategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + mask_loss="BinaryCrossentropy", + ) + + mask_rcnn.fit(dataset, epochs=1) + + @pytest.mark.extra_large # Fit is slow, so mark these large. 
+ @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_fit_with_no_valid_gt_bbox(self): + bounding_box_format = "xywh" + mask_rcnn = MaskRCNN( + num_classes=2, + bounding_box_format=bounding_box_format, + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(32, 32, 3) + ), + num_sampled_rois=16, + ) + + mask_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="CategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + mask_loss="BinaryCrossentropy", + ) + xs, ys = _create_bounding_box_segmask_dataset( + bounding_box_format, image_shape=(32, 32, 3) + ) + xs = ops.convert_to_tensor(xs) + # Make all bounding_boxes invalid and filter them out + ys["classes"] = -ops.ones_like(ys["classes"]) + + mask_rcnn.fit(x=xs, y=ys, epochs=1, batch_size=1) + + +# TODO: add presets test cases once model training is done. diff --git a/keras_cv/src/models/object_detection/mask_rcnn/non_max_suppression.py b/keras_cv/src/models/object_detection/mask_rcnn/non_max_suppression.py new file mode 100644 index 0000000000..2ae6f84bbe --- /dev/null +++ b/keras_cv/src/models/object_detection/mask_rcnn/non_max_suppression.py @@ -0,0 +1,565 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import tensorflow as tf + +from keras_cv.src import bounding_box +from keras_cv.src.backend import keras +from keras_cv.src.backend import ops +from keras_cv.src.backend.config import keras_3 + +EPSILON = 1e-8 + + +@keras.utils.register_keras_serializable(package="keras_cv") +class NonMaxSuppression(keras.layers.Layer): + """A Keras layer that decodes predictions of an object detection model. + + Args: + bounding_box_format: The format of bounding boxes of input dataset. Refer + [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) + for more details on supported bounding box + formats. + from_logits: boolean, True means input score is logits, False means + confidence. + iou_threshold: a float value in the range [0, 1] representing the minimum + IoU threshold for two boxes to be considered same for suppression. + Defaults to 0.5. + confidence_threshold: a float value in the range [0, 1]. All boxes with + confidence below this value will be discarded, defaults to 0.5. + max_detections: the maximum detections to consider after nms is applied. A + large number may trigger significant memory overhead, defaults to 100. 
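+
+    Example:
+
+    A minimal decoding sketch; the shapes and random values below are
+    arbitrary, and `num_detections` marks how many rows of each output
+    are valid:
+
+    ```python
+    import numpy as np
+
+    decoder = NonMaxSuppression("xyxy", from_logits=False)
+    top_left = np.random.uniform(0, 64, size=(1, 500, 2))
+    wh = np.random.uniform(8, 32, size=(1, 500, 2))
+    boxes = np.concatenate([top_left, top_left + wh], axis=-1)
+    scores = np.random.uniform(size=(1, 500, 20))
+    detections = decoder(
+        boxes.astype("float32"),
+        scores.astype("float32"),
+        image_shape=(128, 128, 3),
+    )
+    # detections["boxes"] has shape (1, 100, 4); entries beyond
+    # detections["num_detections"] are masked with -1.
+    ```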
+ """ # noqa: E501 + + def __init__( + self, + bounding_box_format, + from_logits, + iou_threshold=0.5, + confidence_threshold=0.5, + max_detections=100, + **kwargs, + ): + super().__init__(**kwargs) + self.bounding_box_format = bounding_box_format + self.from_logits = from_logits + self.iou_threshold = iou_threshold + self.confidence_threshold = confidence_threshold + self.max_detections = max_detections + self.built = True + + def call( + self, + box_prediction, + class_prediction, + mask_prediction=None, + images=None, + image_shape=None, + ): + """Accepts images and raw predictions, and returns bounding box + predictions. + + Args: + box_prediction: Dense Tensor of shape [batch, boxes, 4] in the + `bounding_box_format` specified in the constructor. + class_prediction: Dense Tensor of shape [batch, boxes, num_classes]. + mask_prediction: Dense Tensor of shape [batch, boxes, mask_height, + mask_width]. + """ + target_format = "yxyx" + if bounding_box.is_relative(self.bounding_box_format): + target_format = bounding_box.as_relative(target_format) + + box_prediction = bounding_box.convert_format( + box_prediction, + source=self.bounding_box_format, + target=target_format, + images=images, + image_shape=image_shape, + ) + if self.from_logits: + class_prediction = ops.sigmoid(class_prediction) + + confidence_prediction = ops.max(class_prediction, axis=-1) + + if not keras_3() or keras.backend.backend() == "tensorflow": + idx, valid_det = tf.image.non_max_suppression_padded( + box_prediction, + confidence_prediction, + max_output_size=self.max_detections, + iou_threshold=self.iou_threshold, + score_threshold=self.confidence_threshold, + pad_to_max_output_size=True, + sorted_input=False, + ) + elif keras.backend.backend() == "torch": + # Since TorchVision has a nice efficient NMS op, we might as well + # use it! 
+ import torchvision + + batch_size = box_prediction.shape[0] + idx = ops.zeros((batch_size, self.max_detections)) + valid_det = ops.zeros((batch_size), "int32") + + for batch_idx in range(batch_size): + conf_mask = ( + confidence_prediction[batch_idx] > self.confidence_threshold + ) + conf_mask_idx = ops.squeeze(ops.nonzero(conf_mask), axis=0) + conf_i = confidence_prediction[batch_idx][conf_mask] + box_i = box_prediction[batch_idx][conf_mask] + + idx_i = torchvision.ops.nms( + box_i, conf_i, iou_threshold=self.iou_threshold + ) + + idx_i = conf_mask_idx[idx_i] + + num_boxes = idx_i.shape[0] + if num_boxes >= self.max_detections: + idx_i = idx_i[: self.max_detections] + num_boxes = self.max_detections + + valid_det[batch_idx] = ops.cast(ops.size(idx_i), "int32") + idx[batch_idx, :num_boxes] = idx_i + else: + idx, valid_det = non_max_suppression( + box_prediction, + confidence_prediction, + max_output_size=self.max_detections, + iou_threshold=self.iou_threshold, + score_threshold=self.confidence_threshold, + ) + + box_prediction = ops.take_along_axis( + box_prediction, ops.expand_dims(idx, axis=-1), axis=1 + ) + box_prediction = ops.reshape( + box_prediction, (-1, self.max_detections, 4) + ) + confidence_prediction = ops.take_along_axis( + confidence_prediction, idx, axis=1 + ) + class_prediction = ops.take_along_axis( + class_prediction, ops.expand_dims(idx, axis=-1), axis=1 + ) + + if mask_prediction is not None: + mask_prediction = ops.take_along_axis( + mask_prediction, idx[..., None, None, None], axis=1 + ) + + box_prediction = bounding_box.convert_format( + box_prediction, + source=target_format, + target=self.bounding_box_format, + images=images, + image_shape=image_shape, + ) + bounding_boxes = { + "boxes": box_prediction, + "confidence": confidence_prediction, + "classes": ops.argmax(class_prediction, axis=-1), + "num_detections": valid_det, + } + if mask_prediction is not None: + bounding_boxes["segmask"] = mask_prediction + + # this is required to comply with KerasCV bounding box format. + return bounding_box.mask_invalid_detections( + bounding_boxes, output_ragged=False + ) + + def get_config(self): + config = { + "bounding_box_format": self.bounding_box_format, + "from_logits": self.from_logits, + "iou_threshold": self.iou_threshold, + "confidence_threshold": self.confidence_threshold, + "max_detections": self.max_detections, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + +def non_max_suppression( + boxes, + scores, + max_output_size, + iou_threshold=0.5, + score_threshold=0.0, + tile_size=512, +): + # Box format must be yxyx + """Non-maximum suppression. + Ported from https://github.com/tensorflow/tensorflow/blob/v2.12.0/tensorflow/python/ops/image_ops_impl.py#L5368-L5458 + + Args: + boxes: a tensor of rank 2 or higher with a shape of [..., num_boxes, 4]. + Dimensions except the last two are batch dimensions. The last dimension + represents box coordinates in yxyx format. + scores: a tensor of rank 1 or higher with a shape of [..., num_boxes]. + max_output_size: a scalar integer tensor representing the maximum number + of boxes to be selected by non max suppression. + iou_threshold: a float representing the threshold for deciding whether boxes + overlap too much with respect to IoU (intersection over union). + score_threshold: a float representing the threshold for box scores. Boxes + with a score that is not larger than this threshold will be suppressed. 
+ tile_size: an integer representing the number of boxes in a tile, i.e.,
+ the maximum number of boxes per image that can be used to suppress other
+ boxes in parallel; larger tile_size means larger parallelism and
+ potentially more redundant work.
+
+ Returns:
+ idx: a tensor with a shape of [..., max_output_size] representing the
+ indices selected by non-max suppression. The leading dimensions
+ are the batch dimensions of the input boxes. All numbers are within
+ [0, num_boxes). For each image (i.e., idx[i]), only the first num_valid[i]
+ indices (i.e., idx[i][:num_valid[i]]) are valid.
+ num_valid: a tensor of rank 0 or higher with a shape of [...]
+ representing the number of valid indices in idx. Its dimensions are the
+ batch dimensions of the input boxes.
+ """ # noqa: E501
+
+ def _sort_scores_and_boxes(scores, boxes):
+ """Sort boxes based on their score from highest to lowest.
+
+ Args:
+ scores: a tensor with a shape of [batch_size, num_boxes] representing
+ the scores of boxes.
+ boxes: a tensor with a shape of [batch_size, num_boxes, 4] representing
+ the boxes.
+
+ Returns:
+ sorted_scores: a tensor with a shape of [batch_size, num_boxes]
+ representing the sorted scores.
+ sorted_boxes: a tensor representing the sorted boxes.
+ sorted_scores_indices: a tensor with a shape of [batch_size, num_boxes]
+ representing the indices of the scores sorted in descending order.
+ """ # noqa: E501
+ with ops.name_scope("sort_scores_and_boxes"):
+ sorted_scores_indices = ops.flip(
+ ops.cast(ops.argsort(scores, axis=1), "int32"), axis=1
+ )
+ sorted_scores = ops.take_along_axis(
+ scores,
+ sorted_scores_indices,
+ axis=1,
+ )
+ sorted_boxes = ops.take_along_axis(
+ boxes,
+ ops.expand_dims(sorted_scores_indices, axis=-1),
+ axis=1,
+ )
+ return sorted_scores, sorted_boxes, sorted_scores_indices
+
+ batch_dims = ops.shape(boxes)[:-2]
+ num_boxes = boxes.shape[-2]
+ boxes = ops.reshape(boxes, [-1, num_boxes, 4])
+ scores = ops.reshape(scores, [-1, num_boxes])
+ batch_size = boxes.shape[0]
+ if score_threshold != float("-inf"):
+ with ops.name_scope("filter_by_score"):
+ score_mask = ops.cast(scores > score_threshold, scores.dtype)
+ scores *= score_mask
+ box_mask = ops.expand_dims(ops.cast(score_mask, boxes.dtype), 2)
+ boxes *= box_mask
+
+ scores, boxes, sorted_indices = _sort_scores_and_boxes(scores, boxes)
+
+ pad = (
+ math.ceil(max(num_boxes, max_output_size) / tile_size) * tile_size
+ - num_boxes
+ )
+ boxes = ops.pad(ops.cast(boxes, "float32"), [[0, 0], [0, pad], [0, 0]])
+ scores = ops.pad(ops.cast(scores, "float32"), [[0, 0], [0, pad]])
+ num_boxes_after_padding = num_boxes + pad
+ num_iterations = num_boxes_after_padding // tile_size
+
+ def _loop_cond(unused_boxes, unused_threshold, output_size, idx):
+ return ops.logical_and(
+ ops.min(output_size) < ops.cast(max_output_size, "int32"),
+ ops.cast(idx, "int32") < num_iterations,
+ )
+
+ def suppression_loop_body(boxes, iou_threshold, output_size, idx):
+ return _suppression_loop_body(
+ boxes, iou_threshold, output_size, idx, tile_size
+ )
+
+ selected_boxes, _, output_size, _ = ops.while_loop(
+ _loop_cond,
+ suppression_loop_body,
+ [
+ boxes,
+ iou_threshold,
+ ops.zeros([batch_size], "int32"),
+ ops.array(0),
+ ],
+ )
+ num_valid = ops.minimum(output_size, max_output_size)
+ idx = num_boxes_after_padding - ops.cast(
+ ops.top_k(
+ ops.cast(ops.any(selected_boxes > 0, [2]), "int32")
+ * ops.cast(
+ ops.expand_dims(ops.arange(num_boxes_after_padding, 0, -1), 0),
+ "int32",
+ ),
+ max_output_size,
+ )[0],
+ "int32",
+ )
+ idx = 
ops.minimum(idx, num_boxes - 1)
+
+ index_offsets = ops.cast(ops.arange(batch_size) * num_boxes, "int32")
+ take_along_axis_idx = ops.reshape(
+ idx + ops.expand_dims(index_offsets, 1), [-1]
+ )
+
+ # TODO(ianstenbit): Fix bug in tfnp.take_along_axis that causes this hack.
+ # (This will be removed anyway when we use built-in NMS for TF.)
+ if keras_3() and keras.backend.backend() != "tensorflow":
+ idx = ops.take_along_axis(
+ ops.reshape(sorted_indices, [-1]), take_along_axis_idx
+ )
+ else:
+ import tensorflow as tf
+
+ idx = tf.gather(ops.reshape(sorted_indices, [-1]), take_along_axis_idx)
+ idx = ops.reshape(idx, [batch_size, -1])
+
+ invalid_index = ops.zeros([batch_size, max_output_size], dtype="int32")
+ idx_index = ops.cast(
+ ops.expand_dims(ops.arange(max_output_size), 0), "int32"
+ )
+ num_valid_expanded = ops.expand_dims(num_valid, 1)
+ idx = ops.where(idx_index < num_valid_expanded, idx, invalid_index)
+
+ num_valid = ops.reshape(num_valid, batch_dims)
+ return idx, num_valid
+
+
+def _bbox_overlap(boxes_a, boxes_b):
+ """Calculates the overlap (IoU, intersection over union) between boxes_a and boxes_b.
+
+ Args:
+ boxes_a: a tensor with a shape of [batch_size, N, 4]. N is the number of
+ boxes per image. The last dimension is the pixel coordinates in
+ [ymin, xmin, ymax, xmax] form.
+ boxes_b: a tensor with a shape of [batch_size, M, 4]. M is the number of
+ boxes. The last dimension is the pixel coordinates in
+ [ymin, xmin, ymax, xmax] form.
+
+ Returns:
+ intersection_over_union: a tensor with a shape of [batch_size, N, M],
+ representing the ratio of intersection area over union area (IoU) between
+ each pair of boxes.
+ """ # noqa: E501
+ with ops.name_scope("bbox_overlap"):
+ if len(boxes_a.shape) == 4:
+ boxes_a = ops.squeeze(boxes_a, axis=0)
+ a_y_min, a_x_min, a_y_max, a_x_max = ops.split(boxes_a, 4, axis=2)
+ b_y_min, b_x_min, b_y_max, b_x_max = ops.split(boxes_b, 4, axis=2)
+
+ # Calculates the intersection area.
+ i_xmin = ops.maximum(a_x_min, ops.transpose(b_x_min, [0, 2, 1]))
+ i_xmax = ops.minimum(a_x_max, ops.transpose(b_x_max, [0, 2, 1]))
+ i_ymin = ops.maximum(a_y_min, ops.transpose(b_y_min, [0, 2, 1]))
+ i_ymax = ops.minimum(a_y_max, ops.transpose(b_y_max, [0, 2, 1]))
+ i_area = ops.maximum((i_xmax - i_xmin), 0) * ops.maximum(
+ (i_ymax - i_ymin), 0
+ )
+
+ # Calculates the union area.
+ a_area = (a_y_max - a_y_min) * (a_x_max - a_x_min)
+ b_area = (b_y_max - b_y_min) * (b_x_max - b_x_min)
+
+ # Adds a small epsilon to avoid divide-by-zero.
+ u_area = a_area + ops.transpose(b_area, [0, 2, 1]) - i_area + EPSILON
+
+ intersection_over_union = i_area / u_area
+
+ return intersection_over_union
+
+
+def _self_suppression(iou, _, iou_sum, iou_threshold):
+ """Suppress boxes in the same tile.
+
+ Compute boxes that cannot be suppressed by others (i.e.,
+ can_suppress_others), and then use them to suppress boxes in the same tile.
+
+ Args:
+ iou: a tensor of shape [batch_size, num_boxes, num_boxes] representing
+ the pairwise intersection over union within the tile.
+ _: a boolean placeholder for the loop condition; unused.
+ iou_sum: a tensor of shape [batch_size] holding the per-image iou sum.
+ iou_threshold: a scalar tensor.
+
+ Returns:
+ iou_suppressed: a tensor of shape [batch_size, num_boxes, num_boxes].
+ iou_diff: a boolean scalar tensor indicating whether any box was
+ suppressed in this step.
+ iou_sum_new: a tensor of shape [batch_size] that represents
+ the iou sum after suppression.
+ iou_threshold: a scalar tensor.
+ """ # noqa: E501 + batch_size = ops.shape(iou)[0] + can_suppress_others = ops.cast( + ops.reshape(ops.max(iou, 1) < iou_threshold, [batch_size, -1, 1]), + iou.dtype, + ) + iou_after_suppression = ( + ops.reshape( + ops.cast( + ops.max(can_suppress_others * iou, 1) < iou_threshold, iou.dtype + ), + [batch_size, -1, 1], + ) + * iou + ) + iou_sum_new = ops.sum(iou_after_suppression, [1, 2]) + return [ + iou_after_suppression, + ops.any(iou_sum - iou_sum_new > iou_threshold), + iou_sum_new, + iou_threshold, + ] + + +def _cross_suppression(boxes, box_slice, iou_threshold, inner_idx, tile_size): + """Suppress boxes between different tiles. + + Args: + boxes: a tensor of shape [batch_size, num_boxes_with_padding, 4] + box_slice: a tensor of shape [batch_size, tile_size, 4] + iou_threshold: a scalar tensor + inner_idx: a scalar tensor representing the tile index of the tile + that is used to supress box_slice + tile_size: an integer representing the number of boxes in a tile + + Returns: + boxes: unchanged boxes as input + box_slice_after_suppression: box_slice after suppression + iou_threshold: unchanged + """ + slice_index = ops.expand_dims( + ops.expand_dims( + ops.cast( + ops.linspace( + inner_idx * tile_size, + (inner_idx + 1) * tile_size - 1, + tile_size, + ), + "int32", + ), + axis=0, + ), + axis=-1, + ) + new_slice = ops.expand_dims( + ops.take_along_axis(boxes, slice_index, axis=1), 0 + ) + iou = _bbox_overlap(new_slice, box_slice) + box_slice_after_suppression = ( + ops.expand_dims( + ops.cast(ops.all(iou < iou_threshold, [1]), box_slice.dtype), 2 + ) + * box_slice + ) + return boxes, box_slice_after_suppression, iou_threshold, inner_idx + 1 + + +def _suppression_loop_body(boxes, iou_threshold, output_size, idx, tile_size): + """Process boxes in the range [idx*tile_size, (idx+1)*tile_size). + + Args: + boxes: a tensor with a shape of [batch_size, anchors, 4]. + iou_threshold: a float representing the threshold for deciding whether boxes + overlap too much with respect to IOU. + output_size: an int32 tensor of size [batch_size]. Representing the number + of selected boxes for each batch. + idx: an integer scalar representing induction variable. + tile_size: an integer representing the number of boxes in a tile + + Returns: + boxes: updated boxes. + iou_threshold: pass down iou_threshold to the next iteration. + output_size: the updated output_size. + idx: the updated induction variable. + """ # noqa: E501 + with ops.name_scope("suppression_loop_body"): + num_tiles = boxes.shape[1] // tile_size + batch_size = boxes.shape[0] + + def cross_suppression_func(boxes, box_slice, iou_threshold, inner_idx): + return _cross_suppression( + boxes, box_slice, iou_threshold, inner_idx, tile_size + ) + + # Iterates over tiles that can possibly suppress the current tile. + slice_index = ops.expand_dims( + ops.expand_dims( + ops.cast( + ops.linspace( + idx * tile_size, (idx + 1) * tile_size - 1, tile_size + ), + "int32", + ), + axis=0, + ), + axis=-1, + ) + box_slice = ops.take_along_axis(boxes, slice_index, axis=1) + _, box_slice, _, _ = ops.while_loop( + lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx, + cross_suppression_func, + [boxes, box_slice, iou_threshold, ops.array(0)], + ) + + # Iterates over the current tile to compute self-suppression. 
+ iou = _bbox_overlap(box_slice, box_slice) + mask = ops.expand_dims( + ops.reshape(ops.arange(tile_size), [1, -1]) + > ops.reshape(ops.arange(tile_size), [-1, 1]), + 0, + ) + iou *= ops.cast(ops.logical_and(mask, iou >= iou_threshold), iou.dtype) + suppressed_iou, _, _, _ = ops.while_loop( + lambda _iou, loop_condition, _iou_sum, _: loop_condition, + _self_suppression, + [iou, ops.array(True), ops.sum(iou, [1, 2]), iou_threshold], + ) + suppressed_box = ops.sum(suppressed_iou, 1) > 0 + box_slice *= ops.expand_dims( + 1.0 - ops.cast(suppressed_box, box_slice.dtype), 2 + ) + + # Uses box_slice to update the input boxes. + mask = ops.reshape( + ops.cast(ops.equal(ops.arange(num_tiles), idx), boxes.dtype), + [1, -1, 1, 1], + ) + boxes = ops.tile( + ops.expand_dims(box_slice, 1), [1, num_tiles, 1, 1] + ) * mask + ops.reshape(boxes, [batch_size, num_tiles, tile_size, 4]) * ( + 1 - mask + ) + boxes = ops.reshape(boxes, [batch_size, -1, 4]) + + # Updates output_size. + output_size += ops.cast( + ops.sum(ops.any(box_slice > 0, [2]), [1]), "int32" + ) + return boxes, iou_threshold, output_size, idx + 1 diff --git a/keras_cv/src/models/object_detection/mask_rcnn/non_max_suppression_test.py b/keras_cv/src/models/object_detection/mask_rcnn/non_max_suppression_test.py new file mode 100644 index 0000000000..7709e04e4a --- /dev/null +++ b/keras_cv/src/models/object_detection/mask_rcnn/non_max_suppression_test.py @@ -0,0 +1,74 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numpy as np
+
+from keras_cv.src.backend import ops
+from keras_cv.src.models.object_detection.mask_rcnn.non_max_suppression import (
+ NonMaxSuppression,
+)
+from keras_cv.src.tests.test_case import TestCase
+
+
+class NonMaxSuppressionTest(TestCase):
+ def test_confidence_threshold(self):
+ boxes = np.random.uniform(low=0, high=1, size=(2, 5, 4))
+ classes = ops.expand_dims(
+ np.array(
+ [[0.1, 0.1, 0.4, 0.9, 0.5], [0.7, 0.5, 0.3, 0.0, 0.0]],
+ "float32",
+ ),
+ axis=-1,
+ )
+
+ nms = NonMaxSuppression(
+ bounding_box_format="yxyx",
+ from_logits=False,
+ iou_threshold=1.0,
+ confidence_threshold=0.45,
+ max_detections=2,
+ )
+
+ outputs = nms(boxes, classes)
+
+ self.assertAllClose(
+ outputs["boxes"], [boxes[0][-2:, ...], boxes[1][:2, ...]]
+ )
+ self.assertAllClose(outputs["classes"], [[0.0, 0.0], [0.0, 0.0]])
+ self.assertAllClose(outputs["confidence"], [[0.9, 0.5], [0.7, 0.5]])
+
+ def test_max_detections(self):
+ boxes = np.random.uniform(low=0, high=1, size=(2, 5, 4))
+ classes = ops.expand_dims(
+ np.array(
+ [[0.1, 0.1, 0.4, 0.5, 0.9], [0.7, 0.5, 0.3, 0.0, 0.0]],
+ "float32",
+ ),
+ axis=-1,
+ )
+
+ nms = NonMaxSuppression(
+ bounding_box_format="yxyx",
+ from_logits=False,
+ iou_threshold=1.0,
+ confidence_threshold=0.1,
+ max_detections=1,
+ )
+
+ outputs = nms(boxes, classes)
+
+ self.assertAllClose(
+ outputs["boxes"], [boxes[0][-1:, ...], boxes[1][:1, ...]]
+ )
+ self.assertAllClose(outputs["classes"], [[0.0], [0.0]])
+ self.assertAllClose(outputs["confidence"], [[0.9], [0.7]])
diff --git a/keras_cv/src/models/object_detection/mask_rcnn/roi_sampler.py b/keras_cv/src/models/object_detection/mask_rcnn/roi_sampler.py
new file mode 100644
index 0000000000..86e66cb7ba
--- /dev/null
+++ b/keras_cv/src/models/object_detection/mask_rcnn/roi_sampler.py
@@ -0,0 +1,320 @@
+# Copyright 2022 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_cv.src import bounding_box
+from keras_cv.src.backend import keras
+from keras_cv.src.backend import ops
+from keras_cv.src.bounding_box import iou
+from keras_cv.src.layers.object_detection import box_matcher
+from keras_cv.src.layers.object_detection import sampling
+from keras_cv.src.utils import target_gather
+
+
+@keras.utils.register_keras_serializable(package="keras_cv")
+class ROISampler(keras.layers.Layer):
+ """
+ Sample ROIs for loss related calculation.
+
+ With proposals (ROIs) and ground truth, it performs the following:
+ 1) computes the IOU similarity matrix,
+ 2) matches each proposal to a ground truth box based on IOU,
+ 3) samples positive and negative matches and returns them.
+
+ `append_gt_boxes` augments proposals with ground truth boxes. This is
+ useful in 2 stage detection networks during initialization where the
+ 1st stage often cannot produce good proposals for the 2nd stage. Setting
+ it to True will allow it to generate more reasonable proposals at the
+ beginning.
+
+ `background_class` allows users to set the labels for background proposals.
+ Defaults to 0, so users need to manually shift the incoming `gt_classes`
+ if their range is [0, num_classes).
+
+ Args:
+ roi_bounding_box_format: The format of roi bounding boxes. Refer
+ [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/)
+ for more details on supported bounding box formats.
+ gt_bounding_box_format: The format of ground truth bounding boxes.
+ roi_matcher: a `BoxMatcher` object that matches proposals with ground
+ truth boxes. The positive match value must be 1 and the negative match
+ value must be -1; this assumption is not validated here.
+ positive_fraction: the positive ratio w.r.t `num_sampled_rois`, defaults
+ to 0.25.
+ background_class: the class value assigned to sampled ground truth
+ entries that are matched as background.
+ num_sampled_rois: the number of sampled proposals per image for
+ further (loss) calculation, defaults to 256.
+ append_gt_boxes: boolean, whether gt_boxes will be appended to rois
+ before sampling the rois, defaults to True.
+ mask_shape: The shape of segmentation masks used for training,
+ defaults to (14, 14).
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ roi_bounding_box_format: str,
+ gt_bounding_box_format: str,
+ roi_matcher: box_matcher.BoxMatcher,
+ positive_fraction: float = 0.25,
+ background_class: int = 0,
+ num_sampled_rois: int = 256,
+ append_gt_boxes: bool = True,
+ mask_shape=(14, 14),
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.roi_bounding_box_format = roi_bounding_box_format
+ self.gt_bounding_box_format = gt_bounding_box_format
+ self.roi_matcher = roi_matcher
+ self.positive_fraction = positive_fraction
+ self.background_class = background_class
+ self.num_sampled_rois = num_sampled_rois
+ self.append_gt_boxes = append_gt_boxes
+ self.mask_shape = mask_shape
+ self.seed_generator = keras.random.SeedGenerator()
+ self.built = True
+ # for debugging.
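+ # Track the mean number of positive/negative matches per image so
+ # that degenerate sampling is visible during training.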
+ self._positives = keras.metrics.Mean()
+ self._negatives = keras.metrics.Mean()
+
+ def call(self, rois, gt_boxes, gt_classes, gt_masks=None):
+ """
+ Args:
+ rois: [batch_size, num_rois, 4]
+ gt_boxes: [batch_size, num_gt, 4]
+ gt_classes: [batch_size, num_gt, 1]
+ gt_masks: [batch_size, num_gt, height, width]
+ Returns:
+ sampled_rois: [batch_size, num_sampled_rois, 4]
+ sampled_gt_boxes: [batch_size, num_sampled_rois, 4]
+ sampled_box_weights: [batch_size, num_sampled_rois, 1]
+ sampled_gt_classes: [batch_size, num_sampled_rois, 1]
+ sampled_class_weights: [batch_size, num_sampled_rois, 1]
+ sampled_gt_masks:
+ [batch_size, num_sampled_rois, mask_height, mask_width]
+ sampled_mask_weights: [batch_size, num_sampled_rois, 1]
+ """
+ rois = bounding_box.convert_format(
+ rois, source=self.roi_bounding_box_format, target="yxyx"
+ )
+ gt_boxes = bounding_box.convert_format(
+ gt_boxes, source=self.gt_bounding_box_format, target="yxyx"
+ )
+ if self.append_gt_boxes:
+ # num_rois += num_gt
+ rois = ops.concatenate([rois, gt_boxes], axis=1)
+ num_rois = ops.shape(rois)[1]
+ if num_rois is None:
+ raise ValueError(
+ f"`rois` must have static shape, got {ops.shape(rois)}"
+ )
+ if num_rois < self.num_sampled_rois:
+ raise ValueError(
+ "`num_sampled_rois` must be less than or equal to the "
+ f"number of rois, got num_sampled_rois="
+ f"{self.num_sampled_rois} and num_rois={num_rois}"
+ )
+ # [batch_size, num_rois, num_gt]
+ similarity_mat = iou.compute_iou(
+ rois, gt_boxes, bounding_box_format="yxyx", use_masking=True
+ )
+ # [batch_size, num_rois] | [batch_size, num_rois]
+ matched_gt_cols, matched_vals = self.roi_matcher(similarity_mat)
+ # [batch_size, num_rois]
+ positive_matches = ops.equal(matched_vals, 1)
+ negative_matches = ops.equal(matched_vals, -1)
+ self._positives.update_state(
+ ops.sum(ops.cast(positive_matches, "float32"), axis=-1)
+ )
+ self._negatives.update_state(
+ ops.sum(ops.cast(negative_matches, "float32"), axis=-1)
+ )
+ # [batch_size, num_rois, 1]
+ background_mask = ops.expand_dims(
+ ops.logical_not(positive_matches), axis=-1
+ )
+ # [batch_size, num_rois, 1]
+ matched_gt_classes = target_gather._target_gather(
+ gt_classes, matched_gt_cols
+ )
+ # also set all background matches to `background_class`
+ matched_gt_classes = ops.where(
+ background_mask,
+ ops.cast(
+ self.background_class * ops.ones_like(matched_gt_classes),
+ gt_classes.dtype,
+ ),
+ matched_gt_classes,
+ )
+ # [batch_size, num_rois, 4]
+ matched_gt_boxes = target_gather._target_gather(
+ gt_boxes, matched_gt_cols
+ )
+ encoded_matched_gt_boxes = bounding_box._encode_box_to_deltas(
+ anchors=rois,
+ boxes=matched_gt_boxes,
+ anchor_format="yxyx",
+ box_format="yxyx",
+ variance=[0.1, 0.1, 0.2, 0.2],
+ )
+ # also set all background matches to 0 coordinates
+ encoded_matched_gt_boxes = ops.where(
+ background_mask,
+ ops.zeros_like(matched_gt_boxes),
+ encoded_matched_gt_boxes,
+ )
+ # [batch_size, num_rois]
+ sampled_indicators = sampling.balanced_sample(
+ positive_matches,
+ negative_matches,
+ self.num_sampled_rois,
+ self.positive_fraction,
+ seed=self.seed_generator,
+ )
+ # [batch_size, num_sampled_rois] in the range of [0, num_rois)
+ sampled_indicators, sampled_indices = ops.top_k(
+ sampled_indicators, k=self.num_sampled_rois, sorted=True
+ )
+ # [batch_size, num_sampled_rois, 4]
+ sampled_rois = target_gather._target_gather(rois, sampled_indices)
+ # [batch_size, num_sampled_rois, 4]
+ sampled_gt_boxes = target_gather._target_gather(
+ encoded_matched_gt_boxes, sampled_indices
+ )
+ # [batch_size, num_sampled_rois, 1]
+ sampled_gt_classes = 
target_gather._target_gather( + matched_gt_classes, sampled_indices + ) + # [batch_size, num_sampled_rois, 1] + # all negative samples will be ignored in regression + sampled_box_weights = target_gather._target_gather( + ops.cast(positive_matches[..., None], gt_boxes.dtype), + sampled_indices, + ) + # [batch_size, num_sampled_rois, 1] + sampled_indicators = sampled_indicators[..., None] + sampled_class_weights = ops.cast(sampled_indicators, gt_classes.dtype) + + if gt_masks is not None: + sampled_gt_cols = target_gather._target_gather( + matched_gt_cols[:, :, None], sampled_indices + ) + + # [batch_size, num_sampled_rois, height, width] + cropped_and_resized_masks = crop_and_resize( + ops.expand_dims(gt_masks, axis=-1), + bounding_boxes=sampled_rois, + target_size=self.mask_shape, + ) + cropped_and_resized_masks = ops.squeeze( + cropped_and_resized_masks, axis=-1 + ) + + sampled_gt_masks = ops.equal( + cropped_and_resized_masks, sampled_gt_cols[..., None] + 1 + ) + sampled_gt_masks = ops.cast(sampled_gt_masks, "float32") + + # Mask weights: 1 for positive samples, 0 for background + sampled_mask_weights = sampled_box_weights + + sampled_data = ( + sampled_rois, + sampled_gt_boxes, + sampled_box_weights, + sampled_gt_classes, + sampled_class_weights, + ) + if gt_masks is not None: + sampled_data = sampled_data + ( + sampled_gt_masks, + sampled_mask_weights, + ) + return sampled_data + + def get_config(self): + config = super().get_config() + config["roi_bounding_box_format"] = self.roi_bounding_box_format + config["gt_bounding_box_format"] = self.gt_bounding_box_format + config["positive_fraction"] = self.positive_fraction + config["background_class"] = self.background_class + config["num_sampled_rois"] = self.num_sampled_rois + config["append_gt_boxes"] = self.append_gt_boxes + config["mask_shape"] = self.mask_shape + config["roi_matcher"] = self.roi_matcher.get_config() + return config + + @classmethod + def from_config(cls, config, custom_objects=None): + roi_matcher_config = config.pop("roi_matcher") + roi_matcher = box_matcher.BoxMatcher(**roi_matcher_config) + return cls(roi_matcher=roi_matcher, **config) + + +def crop_and_resize(images, bounding_boxes, target_size): + """ + A utility function to crop and resize bounding boxes from + images to a given size. + + `bounding_boxes` is expected to be in yxyx format. 
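+
+ Args:
+ images: [num_images, height, width, channels] tensor.
+ bounding_boxes: [num_images, num_boxes, 4] tensor of pixel-space
+ yxyx boxes; coordinates are cast to int32 before cropping.
+ target_size: (height, width) of the returned crops.
+
+ Returns:
+ A [num_images, num_boxes, *target_size, channels] tensor of crops,
+ resized with nearest-neighbor interpolation.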
+ """ + + num_images, num_boxes = ops.shape(bounding_boxes)[:2] + bounding_boxes = ops.cast(bounding_boxes, "int32") + channels = ops.shape(images)[3] + + cropped_and_resized_images = [] + for image_idx in range(num_images): + for box_idx in range(num_boxes): + y1, x1, y2, x2 = ops.unstack(bounding_boxes[image_idx, box_idx]) + # crop to the bounding box + slice_y = ops.maximum(y1, 0) + slice_x = ops.maximum(x1, 0) + cropped_image = ops.slice( + images[image_idx], + (slice_y, slice_x, ops.cast(0, slice_y.dtype)), + (y2 - slice_y, x2 - slice_x, channels), + ) + # pad if the bounding box goes beyond the image + pad_y = -ops.minimum(y1, 0) + pad_x = -ops.minimum(x1, 0) + cropped_image = ops.pad( + cropped_image, + ( + ( + pad_y, + ops.maximum(y2 - y1, 1) + - ops.shape(cropped_image)[0] + - pad_y, + ), + ( + pad_x, + ops.maximum(x2 - x1, 1) + - ops.shape(cropped_image)[1] + - pad_x, + ), + (0, 0), + ), + ) + # resize to the target size + resized_image = ops.image.resize( + cropped_image, target_size, interpolation="nearest" + ) + cropped_and_resized_images.append(resized_image) + + cropped_and_resized_images = ops.stack(cropped_and_resized_images, axis=0) + + target_shape = (num_images, num_boxes, *target_size, channels) + cropped_and_resized_images = ops.reshape( + cropped_and_resized_images, target_shape + ) + return cropped_and_resized_images diff --git a/keras_cv/src/models/object_detection/mask_rcnn/roi_sampler_test.py b/keras_cv/src/models/object_detection/mask_rcnn/roi_sampler_test.py new file mode 100644 index 0000000000..d382926855 --- /dev/null +++ b/keras_cv/src/models/object_detection/mask_rcnn/roi_sampler_test.py @@ -0,0 +1,313 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras_cv.src.backend import ops +from keras_cv.src.backend.config import keras_3 +from keras_cv.src.layers.object_detection.box_matcher import BoxMatcher +from keras_cv.src.models.object_detection.mask_rcnn.roi_sampler import ( + ROISampler, +) +from keras_cv.src.tests.test_case import TestCase + + +class ROISamplerTest(TestCase): + @parameterized.parameters((0,), (1,), (2,)) + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_roi_sampler(self, mask_value): + box_matcher = BoxMatcher(thresholds=[0.3], match_values=[-1, 1]) + roi_sampler = ROISampler( + roi_bounding_box_format="xyxy", + gt_bounding_box_format="xyxy", + roi_matcher=box_matcher, + positive_fraction=0.5, + num_sampled_rois=2, + append_gt_boxes=False, + ) + rois = np.array( + [ + [0, 0, 5, 5], + [2.5, 2.5, 7.5, 7.5], + [5, 5, 10, 10], + [7.5, 7.5, 12.5, 12.5], + ] + ) + rois = rois[np.newaxis, ...] + # the 3rd box will generate 0 IOUs and not sampled. + gt_boxes = np.array( + [[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]] + ) + gt_boxes = gt_boxes[np.newaxis, ...] 
+ gt_classes = np.array([[2, 10, -1]], dtype=np.int32) + gt_classes = gt_classes[..., np.newaxis] + gt_masks = mask_value * np.ones((1, 20, 20), dtype=np.uint8) + _, sampled_gt_boxes, _, sampled_gt_classes, _, sampled_gt_masks, _ = ( + roi_sampler(rois, gt_boxes, gt_classes, gt_masks) + ) + # given we only choose 1 positive sample, and `append_label` is False, + # only the 2nd ROI is chosen. + expected_gt_boxes = np.array([[0.0, 0.0, 0, 0.0], [0.0, 0.0, 0, 0.0]]) + expected_gt_boxes = expected_gt_boxes[np.newaxis, ...] + # only the 2nd ROI is chosen, and the negative ROI is mapped to 0. + expected_gt_classes = np.array([[10], [0]], dtype=np.int32) + expected_gt_classes = expected_gt_classes[np.newaxis, ...] + self.assertAllClose( + np.max(expected_gt_boxes), + np.max(ops.convert_to_numpy(sampled_gt_boxes)), + ) + self.assertAllClose( + np.min(expected_gt_classes), + np.min(ops.convert_to_numpy(sampled_gt_classes)), + ) + # the sampled mask is only set to 1 if the ground truth + # mask indicates object 2 + sampled_index = ops.where(sampled_gt_classes[0, :, 0] == 10)[0][0] + self.assertAllClose( + sampled_gt_masks[0, sampled_index], + (mask_value == 2) * np.ones((14, 14)), + ) + + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_roi_sampler_small_threshold(self): + self.skipTest( + "TODO: resolving flaky test, https://github.com/keras-team/keras-cv/issues/2336" # noqa + ) + box_matcher = BoxMatcher(thresholds=[0.1], match_values=[-1, 1]) + roi_sampler = ROISampler( + roi_bounding_box_format="xyxy", + gt_bounding_box_format="xyxy", + roi_matcher=box_matcher, + positive_fraction=0.5, + num_sampled_rois=2, + append_gt_boxes=False, + ) + rois = np.array( + [ + [0, 0, 5, 5], + [2.5, 2.5, 7.5, 7.5], + [5, 5, 10, 10], + [7.5, 7.5, 12.5, 12.5], + ] + ) + rois = rois[np.newaxis, ...] + # the 3rd box will generate 0 IOUs and not sampled. + gt_boxes = np.array( + [[10, 10, 15, 15], [2.6, 2.6, 7.6, 7.6], [-1, -1, -1, -1]] + ) + gt_boxes = gt_boxes[np.newaxis, ...] + gt_classes = np.array([[2, 10, -1]], dtype=np.int32) + gt_classes = gt_classes[..., np.newaxis] + sampled_rois, sampled_gt_boxes, _, sampled_gt_classes, _ = roi_sampler( + rois, gt_boxes, gt_classes + ) + # given we only choose 1 positive sample, and `append_label` is False, + # only the 2nd ROI is chosen. No negative samples exist given we + # select positive_threshold to be 0.1. (the minimum IOU is 1/7) + # given num_sampled_rois=2, it selects the 1st ROI as well. + expected_rois = np.array([[5, 5, 10, 10], [0.0, 0.0, 5.0, 5.0]]) + expected_rois = expected_rois[np.newaxis, ...] + # all ROIs are matched to the 2nd gt box. + # the boxes are encoded by dimensions, so the result is + # tx, ty = (5.1 - 5.0) / 5 = 0.02, tx, ty = (5.1 - 2.5) / 5 = 0.52 + # then divide by 0.1 as box variance. + expected_gt_boxes = ( + np.array([[0.02, 0.02, 0.0, 0.0], [0.52, 0.52, 0.0, 0.0]]) / 0.1 + ) + expected_gt_boxes = expected_gt_boxes[np.newaxis, ...] + # only the 2nd ROI is chosen, and the negative ROI is mapped to 0. + expected_gt_classes = np.array([[10], [10]], dtype=np.int32) + expected_gt_classes = expected_gt_classes[np.newaxis, ...] 
+ self.assertAllClose(np.max(expected_rois, 1), np.max(sampled_rois, 1)) + self.assertAllClose( + np.max(expected_gt_boxes, 1), + np.max(sampled_gt_boxes, 1), + ) + self.assertAllClose(expected_gt_classes, sampled_gt_classes) + + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_roi_sampler_large_threshold(self): + # the 2nd roi and 2nd gt box has IOU of 0.923, setting + # positive_threshold to 0.95 to ignore it. + box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) + roi_sampler = ROISampler( + roi_bounding_box_format="xyxy", + gt_bounding_box_format="xyxy", + roi_matcher=box_matcher, + positive_fraction=0.5, + num_sampled_rois=2, + append_gt_boxes=False, + ) + rois = np.array( + [ + [0, 0, 5, 5], + [2.5, 2.5, 7.5, 7.5], + [5, 5, 10, 10], + [7.5, 7.5, 12.5, 12.5], + ] + ) + rois = rois[np.newaxis, ...] + # the 3rd box will generate 0 IOUs and not sampled. + gt_boxes = np.array( + [[10, 10, 15, 15], [2.6, 2.6, 7.6, 7.6], [-1, -1, -1, -1]] + ) + gt_boxes = gt_boxes[np.newaxis, ...] + gt_classes = np.array([[2, 10, -1]], dtype=np.int32) + gt_classes = gt_classes[..., np.newaxis] + _, sampled_gt_boxes, _, sampled_gt_classes, _ = roi_sampler( + rois, gt_boxes, gt_classes + ) + # all ROIs are negative matches, so they are mapped to 0. + expected_gt_boxes = np.zeros([1, 2, 4], dtype=np.float32) + # only the 2nd ROI is chosen, and the negative ROI is mapped to 0. + expected_gt_classes = np.array([[0], [0]], dtype=np.int32) + expected_gt_classes = expected_gt_classes[np.newaxis, ...] + # self.assertAllClose(expected_rois, sampled_rois) + self.assertAllClose(expected_gt_boxes, sampled_gt_boxes) + self.assertAllClose(expected_gt_classes, sampled_gt_classes) + + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_roi_sampler_large_threshold_custom_bg_class(self): + # the 2nd roi and 2nd gt box has IOU of 0.923, setting + # positive_threshold to 0.95 to ignore it. + box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) + roi_sampler = ROISampler( + roi_bounding_box_format="xyxy", + gt_bounding_box_format="xyxy", + roi_matcher=box_matcher, + positive_fraction=0.5, + background_class=-1, + num_sampled_rois=2, + append_gt_boxes=False, + ) + rois = np.array( + [ + [0, 0, 5, 5], + [2.5, 2.5, 7.5, 7.5], + [5, 5, 10, 10], + [7.5, 7.5, 12.5, 12.5], + ] + ) + rois = rois[np.newaxis, ...] + # the 3rd box will generate 0 IOUs and not sampled. + gt_boxes = np.array( + [[10, 10, 15, 15], [2.6, 2.6, 7.6, 7.6], [-1, -1, -1, -1]] + ) + gt_boxes = gt_boxes[np.newaxis, ...] + gt_classes = np.array([[2, 10, -1]], dtype=np.int32) + gt_classes = gt_classes[..., np.newaxis] + _, sampled_gt_boxes, _, sampled_gt_classes, _ = roi_sampler( + rois, gt_boxes, gt_classes + ) + # all ROIs are negative matches, so they are mapped to 0. + expected_gt_boxes = np.zeros([1, 2, 4], dtype=np.float32) + # only the 2nd ROI is chosen, and the negative ROI is mapped to -1 from + # customization. + expected_gt_classes = np.array([[-1], [-1]], dtype=np.int32) + expected_gt_classes = expected_gt_classes[np.newaxis, ...] + # self.assertAllClose(expected_rois, sampled_rois) + self.assertAllClose(expected_gt_boxes, sampled_gt_boxes) + self.assertAllClose(expected_gt_classes, sampled_gt_classes) + + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_roi_sampler_large_threshold_append_gt_boxes(self): + # the 2nd roi and 2nd gt box has IOU of 0.923, setting + # positive_threshold to 0.95 to ignore it. 
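+ # With `append_gt_boxes=True`, the gt boxes are appended to the
+ # candidate rois, so each gt box matches itself with IOU 1.0 even
+ # though the 0.95 threshold rules out all of the original rois.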
+ box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) + roi_sampler = ROISampler( + roi_bounding_box_format="xyxy", + gt_bounding_box_format="xyxy", + roi_matcher=box_matcher, + positive_fraction=0.5, + num_sampled_rois=2, + append_gt_boxes=True, + ) + rois = np.array( + [ + [0, 0, 5, 5], + [2.5, 2.5, 7.5, 7.5], + [5, 5, 10, 10], + [7.5, 7.5, 12.5, 12.5], + ] + ) + rois = rois[np.newaxis, ...] + # the 3rd box will generate 0 IOUs and not sampled. + gt_boxes = np.array( + [[10, 10, 15, 15], [2.6, 2.6, 7.6, 7.6], [-1, -1, -1, -1]] + ) + gt_boxes = gt_boxes[np.newaxis, ...] + gt_classes = np.array([[2, 10, -1]], dtype=np.int32) + gt_classes = gt_classes[..., np.newaxis] + _, sampled_gt_boxes, _, sampled_gt_classes, _ = roi_sampler( + rois, gt_boxes, gt_classes + ) + # the selected gt boxes should be [0, 0, 0, 0], and [10, 10, 15, 15] + # but the 2nd will be encoded to 0. + self.assertAllClose(np.min(ops.convert_to_numpy(sampled_gt_boxes)), 0) + self.assertAllClose(np.max(ops.convert_to_numpy(sampled_gt_boxes)), 0) + # the selected gt classes should be [0, 2 or 10] + self.assertAllLessEqual( + np.max(ops.convert_to_numpy(sampled_gt_classes)), 10 + ) + self.assertAllGreaterEqual( + np.min(ops.convert_to_numpy(sampled_gt_classes)), 0 + ) + + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_roi_sampler_large_num_sampled_rois(self): + box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) + roi_sampler = ROISampler( + roi_bounding_box_format="xyxy", + gt_bounding_box_format="xyxy", + roi_matcher=box_matcher, + positive_fraction=0.5, + num_sampled_rois=200, + append_gt_boxes=True, + ) + rois = np.array( + [ + [0, 0, 5, 5], + [2.5, 2.5, 7.5, 7.5], + [5, 5, 10, 10], + [7.5, 7.5, 12.5, 12.5], + ] + ) + rois = rois[np.newaxis, ...] + # the 3rd box will generate 0 IOUs and not sampled. + gt_boxes = np.array( + [[10, 10, 15, 15], [2.6, 2.6, 7.6, 7.6], [-1, -1, -1, -1]] + ) + gt_boxes = gt_boxes[np.newaxis, ...] + gt_classes = np.array([[2, 10, -1]], dtype=np.int32) + gt_classes = gt_classes[..., np.newaxis] + with self.assertRaisesRegex(ValueError, "must be less than"): + _, _, _ = roi_sampler(rois, gt_boxes, gt_classes) + + @pytest.mark.skipif(not keras_3(), reason="disabling test for Keras 2") + def test_serialization(self): + box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) + roi_sampler = ROISampler( + roi_bounding_box_format="xyxy", + gt_bounding_box_format="xyxy", + roi_matcher=box_matcher, + positive_fraction=0.5, + num_sampled_rois=200, + append_gt_boxes=True, + ) + sampler_config = roi_sampler.get_config() + new_sampler = ROISampler.from_config(sampler_config) + self.assertAllEqual(new_sampler.roi_matcher.match_values, [-1, 1])
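+
+
+# A minimal end-to-end sketch of the sampler, kept as a comment since
+# sampling is random; shapes follow the tests above:
+#
+#   matcher = BoxMatcher(thresholds=[0.5], match_values=[-1, 1])
+#   sampler = ROISampler(
+#       roi_bounding_box_format="xyxy",
+#       gt_bounding_box_format="xyxy",
+#       roi_matcher=matcher,
+#       num_sampled_rois=4,
+#       append_gt_boxes=True,
+#   )
+#   outputs = sampler(rois, gt_boxes, gt_classes, gt_masks)
+#   # 7-tuple: rois, boxes, box weights, classes, class weights,
+#   # masks, mask weights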