I found that the conv2d_transpose op fails when the kernel size is 2x2 and the strides are (2,2).

Errors:

    tvm/src/pass/loop_partition.cc:544: Cannot prove: ((((floordiv((((40 - (floordiv((dh + 1), 2)*8))*(5 - floordiv((dw + 1), 2))) + 63), 64) - 1) - (((40 - (floordiv((dh + 1), 2)*8))*(5 - floordiv((dw + 1), 2))) - select((-63 <= ((40 - (floordiv((dh + 1), 2)*8))*(5 - floordiv((dw + 1), 2)))), (floordiv((((40 - (floordiv((dh + 1), 2)*8))*(5 - floordiv((dw + 1), 2))) + 63), 64)*63), 0))) + 1) >= 0), when generating the post doubt loop
    File "/root/workplace/tvm/src/pass/split_host_device.cc", line 135
    TVMError: Check failed: !use_count_.count(v): variable dh has been used before definition!
    During handling of the above exception, another exception occurred:
To reproduce the error:

    import tvm
    from tvm import relay
    import tensorflow as tf

    input_tensor = "input_1"
    # NHWC
    input_shape=(1,16,16,8)
    x = tf.compat.v1.placeholder(tf.float32, shape=input_shape, name=input_tensor)
    # HWOI
    w2 = tf.compat.v1.placeholder(tf.float32, shape=(2,2,3,8))
    out_shape = tf.compat.v1.placeholder(tf.int32, shape=(4))
    deconv = tf.compat.v1.nn.conv2d_transpose(x, w2, out_shape, (1,2,2,1), padding='VALID')

    sess = tf.compat.v1.Session()
    graph_def = sess.graph_def
    mod, params = relay.frontend.from_tensorflow(graph_def, layout='NCHW', shape={input_tensor: input_shape})

    target = tvm.target.cuda()
    from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
    set_cuda_target_arch('sm_70')

    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(mod, target, params=params)
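
For reference, the repro leaves `out_shape` as a placeholder, so it may help to state the output shape the op is expected to produce. The sketch below is only a hand calculation and not part of the failing script; the helper name `conv2d_transpose_out_size` is hypothetical, and it assumes the usual TensorFlow size convention for transposed convolutions (`in * stride + max(kernel - stride, 0)` for `VALID`, `in * stride` for `SAME`).

```python
def conv2d_transpose_out_size(in_size, kernel, stride, padding="VALID"):
    """Spatial output size of a transposed convolution (assumed TF convention)."""
    if padding == "VALID":
        return in_size * stride + max(kernel - stride, 0)
    if padding == "SAME":
        return in_size * stride
    raise ValueError("unsupported padding: %r" % (padding,))


# Shapes taken from the repro script.
n, h, w, c_in = (1, 16, 16, 8)        # input, NHWC
kh, kw, c_out, k_in = (2, 2, 3, 8)    # kernel, HWOI
stride_h, stride_w = 2, 2

out_h = conv2d_transpose_out_size(h, kh, stride_h, "VALID")
out_w = conv2d_transpose_out_size(w, kw, stride_w, "VALID")
expected_out_shape = (n, out_h, out_w, c_out)

print(expected_out_shape)  # (1, 32, 32, 3)
```

With a 2x2 kernel and stride 2 the kernel windows tile the output exactly, with no overlap, so each 16x16 input plane maps to a 32x32 output plane.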
I also tried other targets (llvm and arm_cpu) and they work fine. Only cuda fails.
Related to PR 4243
Discussions: https://discuss.tvm.ai/t/conv2d-transpose-kernel-2x2-strides-2-2-fails-for-cuda-cannot-prove/5020
https://discuss.tvm.ai/t/compile-error-for-cuda-target/4164
Workaround PR #4472 @vinx13 @tqchen @merrymercy @Huyuwei @optima2005