
Commit a2823c7

No public description
PiperOrigin-RevId: 569290058
tensorflower-gardener committed Sep 28, 2023
1 parent 2800390 commit a2823c7
Showing 3 changed files with 48 additions and 11 deletions.
7 changes: 7 additions & 0 deletions official/vision/configs/maskrcnn.py
@@ -46,6 +46,13 @@ class Parser(hyperparams.Config):
   rpn_batch_size_per_im: int = 256
   rpn_fg_fraction: float = 0.5
   mask_crop_size: int = 112
+  pad: bool = True  # Only support `pad = True`.
+
+  def __post_init__(self, *args, **kwargs):
+    """Validates the configuration."""
+    if not self.pad:
+      raise ValueError('`maskrcnn.Parser` only supports `pad = True`.')
+    super().__post_init__(*args, **kwargs)
 
 
 @dataclasses.dataclass
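
The new `__post_init__` makes the unsupported setting fail fast when the config is constructed. A minimal sketch of the behavior, assuming the Parser dataclass can be instantiated directly like other `hyperparams.Config` subclasses:

    from official.vision.configs import maskrcnn

    parser = maskrcnn.Parser()  # pad defaults to True; validation passes.
    try:
      maskrcnn.Parser(pad=False)
    except ValueError as e:
      print(e)  # `maskrcnn.Parser` only supports `pad = True`.
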
1 change: 1 addition & 0 deletions official/vision/configs/retinanet.py
@@ -59,6 +59,7 @@ class Parser(hyperparams.Config):
   max_num_instances: int = 100
   # Can choose AutoAugment and RandAugment.
   aug_type: Optional[common.Augmentation] = None
+  pad: bool = True
 
   # Keep for backward compatibility. Not used.
   aug_policy: Optional[str] = None
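
Unlike the Mask R-CNN parser, the RetinaNet parser accepts `pad = False`, which is what the serving changes below rely on to skip stride padding. An illustrative override, assuming a registered RetinaNet experiment (the experiment name here is just one example):

    from official.core import exp_factory

    exp_config = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
    exp_config.task.train_data.parser.pad = False  # export at the raw input size
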
51 changes: 40 additions & 11 deletions official/vision/serving/detection.py
@@ -14,11 +14,13 @@

"""Detection input and model functions for serving/inference."""

import math
from typing import Mapping, Tuple

from absl import logging
import tensorflow as tf

from official.core import config_definitions as cfg
from official.vision import configs
from official.vision.modeling import factory
from official.vision.ops import anchor
@@ -30,6 +32,34 @@
 class DetectionModule(export_base.ExportModule):
   """Detection Module."""
 
+  def __init__(
+      self,
+      params: cfg.ExperimentConfig,
+      *,
+      input_image_size: list[int],
+      **kwargs,
+  ):
+    """Initializes a detection module for export.
+
+    Args:
+      params: Experiment params.
+      input_image_size: List or Tuple of size of the input image. For 2D image,
+        it is [height, width].
+      **kwargs: All other kwargs are passed to `export_base.ExportModule`; see
+        the documentation on `export_base.ExportModule` for valid arguments.
+    """
+    if params.task.train_data.parser.pad:
+      self._padded_size = preprocess_ops.compute_padded_size(
+          input_image_size, 2**params.task.model.max_level
+      )
+    else:
+      self._padded_size = input_image_size
+    super().__init__(
+        params=params,
+        input_image_size=input_image_size,
+        **kwargs,
+    )
+
   def _build_model(self):
 
     nms_versions_supporting_dynamic_batch_size = {'batched', 'v2', 'v3'}
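
When `pad` is enabled, `compute_padded_size` rounds each spatial dimension up to the nearest multiple of the largest feature stride, `2**max_level`. A simplified sketch of that rounding (the real helper lives in `official.vision.ops.preprocess_ops`):

    import math

    def compute_padded_size(input_image_size, stride):
      # Round each dimension up to the nearest multiple of `stride`.
      return [int(math.ceil(s / stride)) * stride for s in input_image_size]

    print(compute_padded_size([640, 360], 2**6))  # [640, 384] with max_level=6
    print(compute_padded_size([512, 512], 2**6))  # [512, 512], already aligned
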
@@ -40,8 +70,8 @@ def _build_model(self):
                    'does not support with dynamic batch size.', nms_version)
       self.params.task.model.detection_generator.nms_version = 'batched'
 
-    input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
-                                            self._input_image_size + [3])
+    input_specs = tf.keras.layers.InputSpec(shape=[
+        self._batch_size, *self._padded_size, 3])
 
     if isinstance(self.params.task.model, configs.maskrcnn.MaskRCNN):
       model = factory.build_maskrcnn(
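
The serving signature is now built from the padded shape, so the exported graph's static input shape matches what preprocessing actually emits. With illustrative values input_image_size = [640, 360] and max_level = 6:

    import tensorflow as tf

    # Before: spec used the raw size,   [batch, 640, 360, 3].
    # After: spec uses the padded size, [batch, 640, 384, 3].
    input_specs = tf.keras.layers.InputSpec(shape=[1, 640, 384, 3])
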
@@ -64,23 +94,21 @@ def _build_anchor_boxes(self):
         num_scales=model_params.anchor.num_scales,
         aspect_ratios=model_params.anchor.aspect_ratios,
         anchor_size=model_params.anchor.anchor_size)
-    return input_anchor(
-        image_size=(self._input_image_size[0], self._input_image_size[1]))
+    return input_anchor(image_size=self._padded_size)
 
   def _build_inputs(self, image):
     """Builds detection model inputs for serving."""
     model_params = self.params.task.model
     # Normalizes image with mean and std pixel values.
     image = preprocess_ops.normalize_image(
         image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)
 
     image, image_info = preprocess_ops.resize_and_crop_image(
         image,
         self._input_image_size,
-        padded_size=preprocess_ops.compute_padded_size(
-            self._input_image_size, 2**model_params.max_level),
+        padded_size=self._padded_size,
         aug_scale_min=1.0,
-        aug_scale_max=1.0)
+        aug_scale_max=1.0,
+    )
     anchor_boxes = self._build_anchor_boxes()
 
     return image, anchor_boxes, image_info
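
`_build_inputs` now pads to the precomputed `self._padded_size` instead of recomputing it on every call. A sketch of what one image goes through, with illustrative sizes and `pad = True`:

    import tensorflow as tf
    from official.vision.ops import preprocess_ops

    image = tf.zeros([480, 270, 3], tf.float32)  # arbitrary input size
    image, image_info = preprocess_ops.resize_and_crop_image(
        image,
        [640, 360],              # desired size
        padded_size=[640, 384],  # next multiple of 2**max_level
        aug_scale_min=1.0,
        aug_scale_max=1.0)
    print(image.shape)  # (640, 384, 3): scaled to fit, then zero-padded
    # image_info records original size, desired size, scale, and offset.
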
@@ -128,7 +156,7 @@ def preprocess(
     images = tf.cast(images, dtype=tf.float32)
 
     # Tensor Specs for map_fn outputs (images, anchor_boxes, and image_info).
-    images_spec = tf.TensorSpec(shape=self._input_image_size + [3],
+    images_spec = tf.TensorSpec(shape=self._padded_size + [3],
                                 dtype=tf.float32)
 
     num_anchors = model_params.anchor.num_scales * len(
@@ -137,8 +165,9 @@
     for level in range(model_params.min_level, model_params.max_level + 1):
       anchor_level_spec = tf.TensorSpec(
           shape=[
-              self._input_image_size[0] // 2**level,
-              self._input_image_size[1] // 2**level, num_anchors
+              math.ceil(self._padded_size[0] / 2**level),
+              math.ceil(self._padded_size[1] / 2**level),
+              num_anchors,
           ],
           dtype=tf.float32)
       anchor_shapes.append((str(level), anchor_level_spec))
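
The switch from `//` to `math.ceil` matters when `pad = False`: an unpadded dimension need not be divisible by `2**level`, and floor division would understate the feature-map size the backbone actually produces. For example:

    import math

    size, level = 360, 6  # 360 is not a multiple of 2**6 = 64
    print(size // 2**level)            # 5: old computation, one row short
    print(math.ceil(size / 2**level))  # 6: matches the conv feature map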
