Skip to content

Commit

Permalink
Fix CI
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Oct 2, 2024
1 parent 1071028 commit dfe8af7
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 27 deletions.
3 changes: 0 additions & 3 deletions keras/api/_tf_keras/keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,6 @@
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
RandomRotation,
)
from keras.src.layers.preprocessing.image_preprocessing.random_shear import (
RandomShear,
)
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
RandomTranslation,
)
Expand Down
3 changes: 0 additions & 3 deletions keras/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,6 @@
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
RandomRotation,
)
from keras.src.layers.preprocessing.image_preprocessing.random_shear import (
RandomShear,
)
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
RandomTranslation,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,16 @@ class RandomRotation(BaseImagePreprocessingLayer):
seed: Integer. Used to create a random seed.
fill_value: a float represents the value to be filled outside
the boundaries when `fill_mode="constant"`.
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, height, width, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch, channels, height, width)`. It defaults to the
`image_data_format` value found in your Keras config file at
`~/.keras/keras.json`. If you never set it, then it will be
`"channels_last"`.
"""

_VALUE_RANGE_VALIDATION_ERROR = (
"The `value_range` argument should be a list of two numbers. "
)

_SUPPORTED_FILL_MODE = ("reflect", "wrap", "constant", "nearest")
_SUPPORTED_INTERPOLATION = ("nearest", "bilinear")

Expand All @@ -82,14 +86,12 @@ def __init__(
interpolation="bilinear",
seed=None,
fill_value=0.0,
value_range=(0, 255),
data_format=None,
**kwargs,
):
super().__init__(factor=factor, data_format=data_format, **kwargs)
self.seed = seed
self.generator = SeedGenerator(seed)
self._set_value_range(value_range)
self.fill_mode = fill_mode
self.interpolation = interpolation
self.fill_value = fill_value
Expand All @@ -106,19 +108,6 @@ def __init__(
f"{self._SUPPORTED_INTERPOLATION}."
)

def _set_value_range(self, value_range):
if not isinstance(value_range, (tuple, list)):
raise ValueError(
self.value_range_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
if len(value_range) != 2:
raise ValueError(
self.value_range_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
self.value_range = sorted(value_range)

def transform_images(self, images, transformation, training=True):
images = self.backend.cast(images, self.compute_dtype)
if training:
Expand All @@ -138,7 +127,21 @@ def transform_labels(self, labels, transformation, training=True):
def transform_bounding_boxes(
self, bounding_boxes, transformation, training=True
):
raise NotImplementedError
boxes = bounding_boxes["boxes"]
shape = self.backend.shape(boxes)
ones = self.backend.ones((shape[0], shape[1], 1, 1))
homogeneous_boxes = self.backend.concatenate([boxes, ones], axis=2)
transformed_boxes = self.backend.matmul(
transformation["rotation_matrix"], homogeneous_boxes
)
# Convert back to xyxy format
transformed_boxes = (
transformed_boxes[:, :, :2, :] / transformed_boxes[:, :, 2:3, :]
)
transformed_boxes = self.backend.reshape(
transformed_boxes, (shape[0], shape[1], 4)
)
return {"boxes": transformed_boxes, "labels": bounding_boxes["labels"]}

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
Expand Down Expand Up @@ -233,7 +236,6 @@ def compute_output_shape(self, input_shape):
def get_config(self):
config = {
"factor": self.factor,
"value_range": self.value_range,
"data_format": self.data_format,
"fill_mode": self.fill_mode,
"fill_value": self.fill_value,
Expand Down

0 comments on commit dfe8af7

Please sign in to comment.