Skip to content

Commit

Permalink
adding deeplab_v3_plus_presets (#2051)
Browse files Browse the repository at this point in the history
* adding deeplab_v3_plus_presets

* update formatting

* fixed error

* reformatted

* fix errors

* added num_classses arg

* add input_shape arg

* update test

* update input_shape

* updated preset

* update test

* move update to from_preset to task.py

* reformatted code

* updated doc string

* Update assert

* Update deeplab_v3_plus_test.py test

* Update task.py

* Update deeplab_v3_plus_test.py

* Update deeplab_v3_plus_test.py
  • Loading branch information
divyashreepathihalli authored Sep 20, 2023
1 parent 9363c80 commit 05263cf
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from keras_cv.models.backbones.backbone_presets import (
backbone_presets_with_weights,
)
from keras_cv.models.segmentation.deeplab_v3_plus.deeplab_v3_plus_presets import ( # noqa: E501
deeplab_v3_plus_presets,
)
from keras_cv.models.task import Task
from keras_cv.utils.python_utils import classproperty
from keras_cv.utils.train import get_feature_extractor
Expand Down Expand Up @@ -229,13 +232,15 @@ def from_config(cls, config):
@classproperty
def presets(cls):
"""Dictionary of preset names and configurations."""
return copy.deepcopy(backbone_presets)
return copy.deepcopy({**backbone_presets, **deeplab_v3_plus_presets})

@classproperty
def presets_with_weights(cls):
"""Dictionary of preset names and configurations that include
weights."""
return copy.deepcopy(backbone_presets_with_weights)
return copy.deepcopy(
{**backbone_presets_with_weights, **deeplab_v3_plus_presets}
)

@classproperty
def backbone_presets(cls):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DeepLabV3Plus presets."""

from keras_cv.models.backbones.resnet_v2 import resnet_v2_backbone_presets

deeplab_v3_plus_presets = {
"deeplab_v3_plus_resnet50_pascalvoc": {
"metadata": {
"description": (
"DeeplabV3Plus with a ResNet50 v2 backbone. "
"Trained on PascalVOC 2012 Semantic segmentation task, which "
"consists of 20 classes and one background class. This model "
"achieves a final categorical accuracy of 89.34% and mIoU of "
"0.6391 on evaluation dataset."
),
"params": 39191488,
"official_name": "DeepLabV3Plus",
"path": "deeplab_v3_plus",
},
"config": {
"backbone": resnet_v2_backbone_presets.backbone_presets[
"resnet50_v2_imagenet"
],
# 21 used as an implicit background class marginally improves
# performance.
"num_classes": 21,
},
"weights_url": "https://storage.googleapis.com/keras-cv/models/deeplab_v3_plus/voc/deeplabv3plus_resenet50_pascal_voc.weights.h5", # noqa: E501
"weights_hash": "9681410a57bea2bc5cb7d79a1802d872ac263faab749cfe5ffdae6d6c3082041", # noqa: E501
},
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ def test_weights_change(self):
self.assertNotAllEqual(w1, w2)
self.assertFalse(ops.any(ops.isnan(w2)))

@pytest.mark.large
def test_with_model_preset_forward_pass(self):
model = DeepLabV3Plus.from_preset(
"deeplab_v3_plus_resnet50_pascalvoc",
num_classes=21,
input_shape=[512, 512, 3],
)
image = np.ones((1, 512, 512, 3))
output = ops.expand_dims(ops.argmax(model(image), axis=-1), axis=-1)
expected_output = np.zeros((1, 512, 512, 1))
self.assertAllClose(output, expected_output)

@parameterized.named_parameters(
("tf_format", "tf", "model"),
("keras_format", "keras_v3", "model.keras"),
Expand Down
6 changes: 6 additions & 0 deletions keras_cv/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def from_preset(
cls,
preset,
load_weights=None,
input_shape=None,
**kwargs,
):
"""Instantiate {{model_name}} model from preset config and weights.
Expand All @@ -93,6 +94,9 @@ def from_preset(
load_weights: Whether to load pre-trained weights into model.
Defaults to `None`, which follows whether the preset has
pretrained weights available.
input_shape : input shape that will be passed to backbone
initialization, Defaults to `None`.If `None`, the preset
value will be used.
Examples:
```python
Expand Down Expand Up @@ -138,6 +142,8 @@ def from_preset(

# Otherwise must be one of class presets
config = metadata["config"]
if input_shape is not None:
config["backbone"]["config"]["input_shape"] = input_shape
model = cls.from_config({**config, **kwargs})

if preset not in cls.presets_with_weights or load_weights is False:
Expand Down

0 comments on commit 05263cf

Please sign in to comment.