diff --git a/keras_cv/layers/preprocessing/random_shear.py b/keras_cv/layers/preprocessing/random_shear.py index 9567cb65ca..0696f10ee0 100644 --- a/keras_cv/layers/preprocessing/random_shear.py +++ b/keras_cv/layers/preprocessing/random_shear.py @@ -143,23 +143,9 @@ def augment_ragged_image(self, image, transformation, **kwargs): def augment_images(self, images, transformations, **kwargs): x, y = transformations["shear_x"], transformations["shear_y"] - batch_size = tf.shape(images)[0] - base_transforms = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0] - base_transforms = tf.repeat( - [base_transforms], repeats=batch_size, axis=0 - ) if x is not None: - # insert x into the 2nd column of base_transforms - # aka stack N vectors of [1.0, x, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0] - transforms_x = tf.concat( - [ - tf.expand_dims(base_transforms[:, 0], axis=-1), - x, - base_transforms[:, 2:], - ], - axis=-1, - ) + transforms_x = self._build_shear_x_transform_matrix(x) images = preprocessing.transform( images=images, transforms=transforms_x, @@ -169,12 +155,7 @@ def augment_images(self, images, transformations, **kwargs): ) if y is not None: - # insert y into the 4th column of base_transforms - # aka stack N vectors of [1.0, 0.0, 0.0, y, 1.0, 0.0, 0.0, 0.0] - transforms_y = tf.concat( - [base_transforms[:, :3], y, base_transforms[:, 4:]], axis=-1 - ) - + transforms_y = self._build_shear_y_transform_matrix(y) images = preprocessing.transform( images=images, transforms=transforms_y, @@ -185,6 +166,56 @@ def augment_images(self, images, transformations, **kwargs): return images + @staticmethod + def _build_shear_x_transform_matrix(shear_x): + """Build transform matrix for horizontal shear. + + The transform matrix looks like: + (1, x, 0) + (0, 1, 0) + (0, 0, 1) + where the last entry is implicit. + + We flatten the matrix to `[1.0, x, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]` for + use with ImageProjectiveTransformV3. 
+ """ + batch_size = tf.shape(shear_x)[0] + return tf.concat( + values=[ + tf.ones((batch_size, 1), tf.float32), + shear_x, + tf.zeros((batch_size, 2), tf.float32), + tf.ones((batch_size, 1), tf.float32), + tf.zeros((batch_size, 3), tf.float32), + ], + axis=1, + ) + + @staticmethod + def _build_shear_y_transform_matrix(shear_y): + """Build transform matrix for vertical shear. + + The transform matrix looks like: + (1, 0, 0) + (y, 1, 0) + (0, 0, 1) + where the last entry is implicit. + + We flatten the matrix to `[1.0, 0.0, 0.0, y, 1.0, 0.0, 0.0, 0.0]` for + use with ImageProjectiveTransformV3. + """ + batch_size = tf.shape(shear_y)[0] + return tf.concat( + values=[ + tf.ones((batch_size, 1), tf.float32), + tf.zeros((batch_size, 2), tf.float32), + shear_y, + tf.ones((batch_size, 1), tf.float32), + tf.zeros((batch_size, 3), tf.float32), + ], + axis=1, + ) + def augment_labels(self, labels, transformations, **kwargs): return labels