From ea01ad28e11d13e41641bfe4f28d840832806e46 Mon Sep 17 00:00:00 2001 From: Sanskar Modi Date: Mon, 9 Sep 2024 13:28:13 +0530 Subject: [PATCH 1/3] added validation checks and raised error if an invalid input shape is passed to compute_output_shape func in UnitNormalization Layer --- .../normalization/unit_normalization.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/keras/src/layers/normalization/unit_normalization.py b/keras/src/layers/normalization/unit_normalization.py index 3b0b34f4d80..65720a1df74 100644 --- a/keras/src/layers/normalization/unit_normalization.py +++ b/keras/src/layers/normalization/unit_normalization.py @@ -43,6 +43,45 @@ def call(self, inputs): return ops.normalize(inputs, axis=self.axis, order=2, epsilon=1e-12) def compute_output_shape(self, input_shape): + """ + Compute the output shape of the layer. + + Parameters + ---------- + input_shape + Shape tuple (tuple of integers) or list of shape tuples (one per + output tensor of the layer). Shape tuples can include None for free + dimensions, instead of an integer. + + Returns + ------- + output_shape + Shape of the output of Unit Normalization Layer for + an input of given shape. + + Raises + ------ + ValueError + If an axis is out of bounds for the input shape. + TypeError + If the input shape is not a tuple or a list of tuples. + """ + if isinstance(input_shape, (tuple, list)): + input_shape = input_shape + else: + raise TypeError( + "Invalid input shape type: expected tuple or list. " + f"Received: {type(input_shape)}" + ) + + for axis in self.axis: + if axis >= len(input_shape) or axis < -len(input_shape): + raise ValueError( + f"Axis {axis} is out of bounds for " + "input shape {input_shape}. " + "Ensure axis is within the range of input dimensions." + ) + return input_shape def get_config(self): From e081c7ee10225213cb79e21e48a27f8b4dfe991b Mon Sep 17 00:00:00 2001 From: Sanskar Modi Date: Mon, 9 Sep 2024 13:53:21 +0530 Subject: [PATCH 2/3] updated my change to check if the input is int or an iterable before iterating --- .../normalization/unit_normalization.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/keras/src/layers/normalization/unit_normalization.py b/keras/src/layers/normalization/unit_normalization.py index 65720a1df74..9fdd13e33af 100644 --- a/keras/src/layers/normalization/unit_normalization.py +++ b/keras/src/layers/normalization/unit_normalization.py @@ -49,9 +49,9 @@ def compute_output_shape(self, input_shape): Parameters ---------- input_shape - Shape tuple (tuple of integers) or list of shape tuples (one per - output tensor of the layer). Shape tuples can include None for free - dimensions, instead of an integer. + Shape tuple (tuple of integers) or list of shape tuples + (one per output tensor of the layer). Shape tuples can + include None for free dimensions, instead of an integer. Returns ------- @@ -74,12 +74,19 @@ def compute_output_shape(self, input_shape): f"Received: {type(input_shape)}" ) - for axis in self.axis: + # Ensure axis is always treated as a list + if isinstance(self.axis, int): + axes = [self.axis] + else: + axes = self.axis + + for axis in axes: if axis >= len(input_shape) or axis < -len(input_shape): raise ValueError( f"Axis {axis} is out of bounds for " - "input shape {input_shape}. " - "Ensure axis is within the range of input dimensions." + f"input shape {input_shape}. " + "Ensure axis is within the range of input" + " dimensions." ) return input_shape From a3708041c48e01f9d952931e99435b3b2dafe744 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Chollet?= Date: Mon, 9 Sep 2024 10:44:08 -0700 Subject: [PATCH 3/3] Update unit_normalization.py --- .../normalization/unit_normalization.py | 38 +------------------ 1 file changed, 2 insertions(+), 36 deletions(-) diff --git a/keras/src/layers/normalization/unit_normalization.py b/keras/src/layers/normalization/unit_normalization.py index 9fdd13e33af..be77aa59c30 100644 --- a/keras/src/layers/normalization/unit_normalization.py +++ b/keras/src/layers/normalization/unit_normalization.py @@ -43,37 +43,6 @@ def call(self, inputs): return ops.normalize(inputs, axis=self.axis, order=2, epsilon=1e-12) def compute_output_shape(self, input_shape): - """ - Compute the output shape of the layer. - - Parameters - ---------- - input_shape - Shape tuple (tuple of integers) or list of shape tuples - (one per output tensor of the layer). Shape tuples can - include None for free dimensions, instead of an integer. - - Returns - ------- - output_shape - Shape of the output of Unit Normalization Layer for - an input of given shape. - - Raises - ------ - ValueError - If an axis is out of bounds for the input shape. - TypeError - If the input shape is not a tuple or a list of tuples. - """ - if isinstance(input_shape, (tuple, list)): - input_shape = input_shape - else: - raise TypeError( - "Invalid input shape type: expected tuple or list. " - f"Received: {type(input_shape)}" - ) - # Ensure axis is always treated as a list if isinstance(self.axis, int): axes = [self.axis] @@ -83,12 +52,9 @@ def compute_output_shape(self, input_shape): for axis in axes: if axis >= len(input_shape) or axis < -len(input_shape): raise ValueError( - f"Axis {axis} is out of bounds for " - f"input shape {input_shape}. " - "Ensure axis is within the range of input" - " dimensions." + f"Axis {self.axis} is out of bounds for " + f"input shape {input_shape}." ) - return input_shape def get_config(self):