Skip to content

Commit

Permalink
Refactored transformarion matrix creation.
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastian-sz committed Mar 18, 2023
1 parent 0be4811 commit 54b5e21
Showing 1 changed file with 52 additions and 21 deletions.
73 changes: 52 additions & 21 deletions keras_cv/layers/preprocessing/random_shear.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,23 +143,9 @@ def augment_ragged_image(self, image, transformation, **kwargs):

def augment_images(self, images, transformations, **kwargs):
x, y = transformations["shear_x"], transformations["shear_y"]
batch_size = tf.shape(images)[0]

base_transforms = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
base_transforms = tf.repeat(
[base_transforms], repeats=batch_size, axis=0
)
if x is not None:
# insert x into the 2nd column of base_transforms
# aka stack N vectors of [1.0, x, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
transforms_x = tf.concat(
[
tf.expand_dims(base_transforms[:, 0], axis=-1),
x,
base_transforms[:, 2:],
],
axis=-1,
)
transforms_x = self._build_shear_x_transform_matrix(x)
images = preprocessing.transform(
images=images,
transforms=transforms_x,
Expand All @@ -169,12 +155,7 @@ def augment_images(self, images, transformations, **kwargs):
)

if y is not None:
# insert y into the 4th column of base_transforms
# aka stack N vectors of [1.0, 0.0, 0.0, y, 1.0, 0.0, 0.0, 0.0]
transforms_y = tf.concat(
[base_transforms[:, :3], y, base_transforms[:, 4:]], axis=-1
)

transforms_y = self._build_shear_y_transform_matrix(y)
images = preprocessing.transform(
images=images,
transforms=transforms_y,
Expand All @@ -185,6 +166,56 @@ def augment_images(self, images, transformations, **kwargs):

return images

@staticmethod
def _build_shear_x_transform_matrix(shear_x):
"""Build transform matrix for horizontal shear.
The transform matrix looks like:
(1, x, 0)
(0, 1, 0)
(0, 0, 1)
where the last entry is implicit.
We flatten the matrix to `[1.0, x, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]` for
use with ImageProjectiveTransformV3.
"""
batch_size = tf.shape(shear_x)[0]
return tf.concat(
values=[
tf.ones((batch_size, 1), tf.float32),
shear_x,
tf.zeros((batch_size, 2), tf.float32),
tf.ones((batch_size, 1), tf.float32),
tf.zeros((batch_size, 3), tf.float32),
],
axis=1,
)

@staticmethod
def _build_shear_y_transform_matrix(shear_y):
"""Build transform matrix for vertical shear.
The transform matrix looks like:
(1, 0, 0)
(y, 1, 0)
(0, 0, 1)
where the last entry is implicit.
We flatten the matrix to `[1.0, 0.0, 0.0, y, 1.0, 0.0, 0.0, 0.0]` for
use ImageProjectiveTransformV3.
"""
batch_size = tf.shape(shear_y)[0]
return tf.concat(
values=[
tf.ones((batch_size, 1), tf.float32),
tf.zeros((batch_size, 2), tf.float32),
shear_y,
tf.ones((batch_size, 1), tf.float32),
tf.zeros((batch_size, 3), tf.float32),
],
axis=1,
)

def augment_labels(self, labels, transformations, **kwargs):
return labels

Expand Down

0 comments on commit 54b5e21

Please sign in to comment.