diff --git a/smdebug/tensorflow/utils.py b/smdebug/tensorflow/utils.py
index 90dfa5d29..863a5f862 100644
--- a/smdebug/tensorflow/utils.py
+++ b/smdebug/tensorflow/utils.py
@@ -339,7 +339,7 @@ def get_layer_call_fn(layer: tf.keras.layers.Layer) -> Callable[[tf.Tensor], tf.Tensor]:
     def call(inputs, *args, **kwargs) -> tf.Tensor:
         layer_input = inputs
-        layer_output = old_call_fn(inputs)
+        layer_output = old_call_fn(inputs, *args, **kwargs)
         for hook in layer._hooks:
             hook_result = hook(inputs, layer_input=layer_input, layer_output=layer_output)
             if hook_result is not None:
                 layer_output = hook_result
diff --git a/tests/tensorflow2/test_model_subclassing.py b/tests/tensorflow2/test_model_subclassing.py
index d90b931e2..c9f92999b 100644
--- a/tests/tensorflow2/test_model_subclassing.py
+++ b/tests/tensorflow2/test_model_subclassing.py
@@ -13,6 +13,16 @@ def __init__(self):
         self.conv1 = Conv2D(
             32, 3, activation="relu", kernel_initializer=tf.keras.initializers.GlorotNormal(seed=12)
         )
+        self.original_call = self.conv1.call
+
+        def new_call(inputs, *args, **kwargs):
+            # Since we wrap the layer's call fn, assert that these
+            # parameters are actually passed through to the original call fn.
+            assert kwargs["input_one"] == 1
+            kwargs.pop("input_one")
+            return self.original_call(inputs, *args, **kwargs)
+
+        self.conv1.call = new_call
         self.conv0 = Conv2D(
             32, 3, activation="relu", kernel_initializer=tf.keras.initializers.GlorotNormal(seed=12)
         )
@@ -26,8 +36,7 @@ def __init__(self):
     def first(self, x):
         with tf.name_scope("first"):
             tf.print("mymodel.first")
-            x = self.conv1(x)
-            # x = self.bn(x)
+            x = self.conv1(x, input_one=1)
             return self.flatten(x)
 
     def second(self, x):
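
For context, here is a minimal, self-contained sketch of the wrapping pattern the utils.py hunk fixes. The wrap_call helper name is hypothetical (not smdebug's public API), and the _hooks attribute simply mirrors the one used in the diff. The point it demonstrates: forwarding *args and **kwargs lets caller-supplied parameters, such as Keras's training flag or the input_one kwarg exercised by the test, actually reach the original call fn instead of being silently dropped.

import tensorflow as tf

def wrap_call(layer: tf.keras.layers.Layer):
    # Capture the unwrapped call fn, as get_layer_call_fn does above.
    old_call_fn = layer.call

    def call(inputs, *args, **kwargs) -> tf.Tensor:
        layer_input = inputs
        # The fix: forward *args and **kwargs so extra parameters are
        # not swallowed before reaching the original call fn.
        layer_output = old_call_fn(inputs, *args, **kwargs)
        for hook in getattr(layer, "_hooks", []):
            hook_result = hook(inputs, layer_input=layer_input, layer_output=layer_output)
            if hook_result is not None:
                layer_output = hook_result
        return layer_output

    return call

layer = tf.keras.layers.Dropout(0.5)
layer.call = wrap_call(layer)
x = tf.ones((2, 4))
# Without the forwarding, training=True never reached Dropout.call, so
# dropout was never applied in training mode; with it, units are zeroed.
y = layer.call(x, training=True)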