diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index ecd63f60f2b7..0a6654206006 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -1132,14 +1132,12 @@ def unique_shape_func(attrs, inputs, _):
 @script
 def _gather_nd_shape(data_shape, indices_shape, batch_dims, gather_dim):
     ndim = data_shape.shape[0]
-    mdim = gather_dim
     # using mdim = indices_shape[0] wouldn't work because a rank cannot
     # depend on a runtime shape dimension of indices tensor, even if the
     # dimension is always a known, fixed value. As a workaround, we assume that
     # the fixed gather dimension (the size of an indexing tuple) is recorded
     # in `gather_nd` op attribute.
-    err_msg = "The recorded gather dimension and the actual dimension are different"
-    assert mdim == indices_shape[0], err_msg
+    mdim = gather_dim
     kdim = indices_shape.shape[0] - 1
     out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
     for i in range(1, kdim + 1):
@@ -1154,7 +1152,7 @@ def gather_nd_shape_func(attrs, inputs, _):
     """
     Shape func for ghater_nd operator.
     """
-    batch_dims = get_const_int(attrs.batch_dimss)
+    batch_dims = get_const_int(attrs.batch_dims)
     gather_dim = get_const_int(attrs.gather_dim)
     assert gather_dim > 0, "gather_dim needs to be specified for dynamic gather_nd"
     return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(gather_dim))]
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index fd6d7a9aeb14..07955943e341 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -26,6 +26,7 @@
 from tvm.error import TVMError
 from tvm.relay import create_executor, transform
 from tvm.relay.testing import check_grad, run_infer_type
+from utils import ref_funcs
 
 
 def test_zeros_ones():
@@ -1266,26 +1267,7 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0):
         else:
             y_data = np.random.randint(low=0, high=2, size=yshape, dtype="int32")
 
-        def gather_nd_batch_dims_1_ref(data, indices):
-            res = []
-            for i, row in enumerate(data):
-                indices_tuple = tuple(indices[:, i])  # the indices for the i-th batch
-                res.append(row[indices_tuple])
-            # stack on the batch dim
-            return np.stack(res, 0)
-
-        if batch_dims > 1:
-            x_data_reshape = np.reshape(x_data, (-1,) + xshape[batch_dims:])
-            y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dims + 1) :])
-
-            ref_res = gather_nd_batch_dims_1_ref(x_data_reshape, y_data_reshape)
-
-            out_shape = yshape[1 : (batch_dims + 1)] + ref_res.shape[1:]
-            ref_res = np.reshape(ref_res, out_shape)
-        elif batch_dims == 1:
-            ref_res = gather_nd_batch_dims_1_ref(x_data, y_data)
-        else:
-            ref_res = x_data[tuple(y_data)]
+        ref_res = ref_funcs.gather_nd(x_data, y_data, batch_dims)
 
         for target, dev in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
diff --git a/tests/python/relay/utils/ref_funcs.py b/tests/python/relay/utils/ref_funcs.py
new file mode 100644
index 000000000000..924805b2295e
--- /dev/null
+++ b/tests/python/relay/utils/ref_funcs.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+
+
+def gather_nd(data_np, indices_np, batch_dims=0):
+    """gather_nd implemented using numpy"""
+    data_shape = data_np.shape
+    indices_shape = indices_np.shape
+
+    def gather_nd_batch_dims_1_ref(data, indices):
+        res = []
+        for i, row in enumerate(data):
+            indices_tuple = tuple(indices[:, i])  # the indices for the i-th batch
+            res.append(row[indices_tuple])
+        # stack on the batch dim
+        return np.stack(res, 0)
+
+    if batch_dims > 1:
+        data_np_reshape = np.reshape(data_np, (-1,) + data_shape[batch_dims:])
+        indices_np_reshape = np.reshape(
+            indices_np, (indices_shape[0], -1) + indices_shape[(batch_dims + 1) :]
+        )
+
+        ref_res = gather_nd_batch_dims_1_ref(data_np_reshape, indices_np_reshape)
+
+        out_shape = indices_shape[1 : (batch_dims + 1)] + ref_res.shape[1:]
+        ref_res = np.reshape(ref_res, out_shape)
+    elif batch_dims == 1:
+        ref_res = gather_nd_batch_dims_1_ref(data_np, indices_np)
+    else:
+        ref_res = data_np[tuple(indices_np)]
+
+    return ref_res