
Convolution may have some precision problem with autotuned cudnn #15638

Open
zixuanweeei opened this issue Jul 23, 2019 · 1 comment

Comments

@zixuanweeei
Contributor

Description

The convolution outputs of autotuned and non-autotuned cuDNN are inconsistent when grad_req is set to {"x": "null", "w": "null"}.
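
For reference, MXNET_CUDNN_AUTOTUNE_DEFAULT controls how the cuDNN convolution algorithm is chosen: 0 disables autotuning, 1 benchmarks algorithms within the default workspace limit, and 2 also considers algorithms that need a larger workspace. The same knob exists per operator as the cudnn_tune parameter ('off', 'limited_workspace', 'fastest'). A minimal sketch of the kind of comparison this issue is about (shapes follow the full script below; this is an illustration, not the exact failing path):

import mxnet as mx

x = mx.nd.random.normal(shape=(2, 64, 7, 7), ctx=mx.gpu(0))
w = mx.nd.random.normal(shape=(64, 64, 5, 5), ctx=mx.gpu(0))

def conv(tune):
    # cudnn_tune mirrors MXNET_CUDNN_AUTOTUNE_DEFAULT per operator
    return mx.nd.Convolution(x, w, num_filter=64, kernel=(5, 5), stride=(1, 1),
                             pad=(1, 1), no_bias=True, cudnn_tune=tune)

# Different algorithms may legitimately differ by float32 rounding; the issue
# is that the difference here exceeds rtol=atol=1e-3 for some elements.
print((conv('off') - conv('fastest')).abs().max().asscalar())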

Environment info

Test script follows. MXNet was built from source; build details are listed below.

from __future__ import print_function
from __future__ import division
import numpy as np
import mxnet as mx
import copy
import itertools
from numpy.testing import assert_allclose
from mxnet.test_utils import default_context
import os
import traceback


mx.test_utils.set_default_context(mx.gpu(0))

def conv_gen(kernel, stride, pad, num_filter, no_bias, x_shape, w_shape,
        args, grad, grad_req, autotune=0):
    # cuDNN autotune: 0 = off, 1 = limited workspace, 2 = fastest (full workspace)
    os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = str(autotune)
    # print("Autotune: ", autotune)
    # Deep-copy inputs so every autotune setting starts from identical tensors
    args_, grad_, grad_req_ = \
            copy.deepcopy(args), copy.deepcopy(grad), copy.deepcopy(grad_req)

    # Symbol definition
    x = mx.sym.Variable('x')
    w = mx.sym.Variable('w')
    b = mx.sym.Variable('b') if not no_bias else None
    conv = mx.sym.Convolution(x, w, b, num_filter=num_filter,
        kernel=kernel, stride=stride, pad=pad, no_bias=no_bias)
    
    dev = default_context()
    exe1 = conv.bind(dev, args_, args_grad=grad_, grad_req=grad_req_)
    exe1.forward(is_train=True)
    # Use the forward output itself as the head gradient for backward
    exe1.backward(exe1.outputs[0])
    mx.nd.waitall()
    return args_, grad_, exe1.outputs
    

def test_convolution_independent_gradients():
    reqs = ["null", "write", "add"]
    var_names = ["x", "w", "b"]
    # Parameter and input shapes
    kernel = (5, 5)
    stride = (1, 1)
    pad = (1, 1)
    num_filter = 64
    x_shape = (2, 64, 7, 7)
    w_shape = (64, 64, 5, 5)
    
    for x_req, w_req, b_req in itertools.product(reqs, repeat=3):
        for no_bias in [False, True]:
            # Binding args for conv with possible dependent gradients
            base_args = {
                'x': mx.nd.random.normal(shape=x_shape),
                'w': mx.nd.random.normal(shape=w_shape),
                'b': mx.nd.random.normal(shape=(num_filter, )) if not no_bias else None}
            grad = {
                'x': mx.nd.zeros(shape=x_shape),
                'w': mx.nd.zeros(shape=w_shape),
                'b': mx.nd.zeros(shape=(num_filter, )) if not no_bias else None}
            grad_req = {"x": x_req, "w": w_req, "b": b_req}

            try:
                # Run the same convolution with autotune off (0), limited workspace (1), and fastest (2)
                args0, grad0, out0 = conv_gen(kernel, stride, pad, num_filter, no_bias, x_shape, w_shape,
                    base_args, grad, grad_req, 0)
                args1, grad1, out1 = conv_gen(kernel, stride, pad, num_filter, no_bias, x_shape, w_shape,
                    base_args, grad, grad_req, 1)
                args2, grad2, out2 = conv_gen(kernel, stride, pad, num_filter, no_bias, x_shape, w_shape,
                    base_args, grad, grad_req, 2)
                
                # for var_name in var_names:
                #     if var_name == "b" and no_bias:
                #         continue
                #     assert_allclose(args0[var_name].asnumpy(), args1[var_name].asnumpy(), rtol=1.0e-3, atol=1.0e-3)
                #     assert_allclose(args1[var_name].asnumpy(), args2[var_name].asnumpy(), rtol=1.0e-3, atol=1.0e-3)
                #     assert_allclose(grad0[var_name].asnumpy(), grad1[var_name].asnumpy(), rtol=1.0e-3, atol=1.0e-3)
                #     assert_allclose(grad1[var_name].asnumpy(), grad2[var_name].asnumpy(), rtol=1.0e-3, atol=1.0e-3)
                
                for m0, m1, m2 in zip(out0, out1, out2):
                    assert_allclose(m0.asnumpy(), m1.asnumpy(), atol=1.0e-3, rtol=1.0e-3)
                    assert_allclose(m1.asnumpy(), m2.asnumpy(), atol=1.0e-3, rtol=1.0e-3)
            except AssertionError:
                print("==========================================================================================")
                traceback.print_exc()
                print("x_req: {}, w_req: {}, b_req: {}".format(x_req, w_req, "no_bias" if no_bias else b_req))
                print("==========================================================================================")


if __name__ == "__main__":
    test_convolution_independent_gradients()

Build info

Compiler gcc: gcc version 5.3.1 20160406 (Red Hat 5.3.1-6) (GCC)

MXNet commit hash: 77254f2

Build config:

make -j50 USE_PROFILER=0 USE_CUDA=1 USE_CUDNN=1 USE_MKLDNN=1 USE_BLAS=mkl USE_INTEL_PATH=/opt/intel USE_CUDA_PATH=/path/to/cuda-9.0 USE_CUDNN_PATH=/path/to/cudnn/cudnn-9.0-linux-x64-v7.1.2

Error Message:

Traceback (most recent call last):
  File "test_gpu_case_issue.py", line 77, in test_convolution_independent_gradients
    assert_allclose(m0.asnumpy(), m1.asnumpy(), atol=1.0e-3, rtol=1.0e-3)
  File "/home/zixuanwe/miniconda3/lib/python3.7/site-packages/numpy/testing/nose_tools/utils.py", line 1398, in assert_allclose
    verbose=verbose, header=header, equal_nan=equal_nan)
  File "/home/zixuanwe/miniconda3/lib/python3.7/site-packages/numpy/testing/nose_tools/utils.py", line 781, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=0.001, atol=0.001

(mismatch 9.0625%)
 x: array([[[[ 2.309166e+01,  1.388979e+01, -2.442778e+01, -1.903630e+01,
          -2.764324e+01],
         [ 3.716447e+00,  4.774844e+01, -5.421930e+01,  4.471700e+01,...
 y: array([[[[ 2.306942e+01,  1.390011e+01, -2.442946e+01, -1.903965e+01,
          -2.763445e+01],
         [ 3.733934e+00,  4.774546e+01, -5.421634e+01,  4.471697e+01,...
x_req: null, w_req: null, b_req: null
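
For reference, the "(mismatch 9.0625%)" figure is the fraction of elements violating NumPy's closeness criterion |x - y| <= atol + rtol * |y|. A small helper (our naming) to reproduce that number from the two outputs:

import numpy as np

def mismatch_percent(x, y, rtol=1e-3, atol=1e-3):
    # Percentage of elements failing numpy's isclose criterion,
    # i.e. |x - y| > atol + rtol * |y|
    bad = np.abs(x - y) > atol + rtol * np.abs(y)
    return 100.0 * bad.mean()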

Steps to reproduce

Build MXNet from source (build info above) and run the script.
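
If the numerical difference indeed comes from the autotuned algorithm choice, a possible workaround until this is fixed is to disable autotuning process-wide before any convolution is created (a sketch, not a verified fix for this issue):

import os
# Force the default (non-autotuned) cuDNN convolution algorithms
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"

import mxnet as mx
# ... build and run the model as usual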

@zixuanweeei
Contributor Author

@DickJC123 Could you take a look at this? Thanks.

@zachgk added the Backend label Jul 23, 2019
Labels
Backend, CUDA