diff --git a/examples/tensorflow/scripts/train_imagenet_resnet_hvd.py b/examples/tensorflow/scripts/train_imagenet_resnet_hvd.py index 0184b3873..f1a1ae098 100644 --- a/examples/tensorflow/scripts/train_imagenet_resnet_hvd.py +++ b/examples/tensorflow/scripts/train_imagenet_resnet_hvd.py @@ -176,7 +176,7 @@ def batch_norm(self, inputs, **kwargs): ) def spatial_average2d(self, inputs): - shape = inpusmd.get_shape().as_list() + shape = inputs.get_shape().as_list() if self.data_format == "channels_last": n, h, w, c = shape else: