In short, I've been trying to use Conv2DTranspose in my model, and even for a very simple case it takes forever to compile.
```python
import time

import jax
from jax import lax, random

# Directly implement the Conv2DTranspose in JAX
def toy_model_jax(x, params):
    return lax.conv_transpose(x, params["kernel"], strides=(16, 16), padding="VALID")

# Initialize parameters for the toy model
def initialize_params(key):
    kernel_shape = (32, 32, 128, 32)  # (height, width, in_channels, out_channels)
    kernel = random.normal(key, kernel_shape)
    return {"kernel": kernel}

# Generate random input and params
start_time = time.time()
key = random.PRNGKey(42)
x = random.normal(key, (1, 8, 8, 128))  # NHWC input
params = initialize_params(key)
end_time = time.time()
print(f"Initialization Run Time: {end_time - start_time:.6f} seconds")

# JIT-compile the model
toy_model_jax_jitted = jax.jit(toy_model_jax)

# Time the compilation: the warm-up call traces the function and compiles it with XLA
start_time = time.time()
_ = toy_model_jax_jitted(x, params).block_until_ready()
end_time = time.time()
print(f"JAX Compilation Time: {end_time - start_time:.6f} seconds")

# Time the compiled run; block_until_ready() accounts for JAX's asynchronous dispatch
start_time = time.time()
o = toy_model_jax_jitted(x, params).block_until_ready()
end_time = time.time()
print("input_shape", x.shape)
print("output_shape", o.shape)
print(f"JITted Run Time: {end_time - start_time:.6f} seconds")
```
This is apparently due to convolution autotuning: some of the cudnn algorithms are very slow, and we try all of them during autotuning. Once autotuning has run, we choose a fast algorithm.

It seems that in this case heuristics_mode_a and heuristics_mode_b return the same algorithms, so deduplicating the candidates before autotuning halves the compile time. That still leaves compilation slow, but it is a step in the right direction. There is also an idea to speed things up further by aborting an autotuning attempt once it exceeds the best known runtime, but that will take a bit longer to implement.
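Until those improvements land, one possible stopgap is to lower XLA's GPU autotuning level so that fewer cudnn algorithms are tried at compile time, at the cost of possibly ending up with a slower algorithm at runtime. A sketch, assuming a recent jaxlib (`--xla_gpu_autotune_level` is a real XLA flag, but its default and exact semantics depend on the version):

```python
import os

# Must be set before jax initializes its backend.
# Level 0 disables convolution autotuning; higher levels try more candidates.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0"

import jax  # imported after setting XLA_FLAGS on purpose
```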
### Description

I initially submitted this issue at google-deepmind/dm-haiku#724, but then realized it is a jax issue.
For comparison, here is the PyTorch version:
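The original PyTorch snippet is not preserved here; the following is a hypothetical minimal equivalent using the same shapes as the JAX repro (PyTorch's `ConvTranspose2d`, which expects NCHW layout):

```python
import time

import torch

# Same configuration as the JAX repro: 128 -> 32 channels, 32x32 kernel, stride 16.
conv_t = torch.nn.ConvTranspose2d(
    in_channels=128, out_channels=32, kernel_size=32, stride=16, bias=False
).cuda()
x = torch.randn(1, 128, 8, 8, device="cuda")  # NCHW input

start_time = time.time()
o = conv_t(x)
torch.cuda.synchronize()  # wait for the asynchronous CUDA kernel to finish
end_time = time.time()
print("input_shape", tuple(x.shape))
print("output_shape", tuple(o.shape))  # expected: (1, 32, 144, 144)
print(f"Run Time: {end_time - start_time:.6f} seconds")
```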
Google Colab notebook:
https://colab.research.google.com/drive/15YkOuK0EjqZdBNaXpF2wpYexGqtjZjLr
### What jax/jaxlib version are you using?

Google Colab

### Which accelerator(s) are you using?

GPU

### Additional system info

Google Colab
### NVIDIA GPU info