Not sure if this is a JAX issue or a dm-haiku one, but I've recently been trying to use Conv2DTranspose in my model, and even for a very simple case it takes forever to compile.
Here is an example:
```python
import time

import haiku as hk
import jax
from jax import random


def toy_model(x):
    x = hk.Conv2DTranspose(32, 32, stride=16, padding="VALID")(x)
    return x

# Transform the model to be JAX-compatible
toy_model_init = hk.transform(toy_model).init
toy_model_apply = hk.transform(toy_model).apply

# Generate random input and params
key = random.PRNGKey(42)
x = random.normal(key, (1, 8, 8, 128))

# Time the model initialization
start_time = time.time()
params = toy_model_init(key, x)
end_time = time.time()
print(f"initialization Time: {end_time - start_time:.6f} seconds")

# Time the model compilation
start_time = time.time()
compiled_apply = jax.jit(toy_model_apply)
# Warm-up call (this compiles the function); JAX dispatch is asynchronous,
# so block until the result is ready before stopping the timer
_ = compiled_apply(params, None, x).block_until_ready()
end_time = time.time()
print(f"Compilation Time: {end_time - start_time:.6f} seconds")

# Time the model run
start_time = time.time()
o = compiled_apply(params, None, x).block_until_ready()
end_time = time.time()
print("input_shape", x.shape)
print("output_shape", o.shape)
print(f"Run Time: {end_time - start_time:.6f} seconds")
```
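For reference, the spatial size of the output can be worked out by hand. The sketch below assumes Haiku's Conv2DTranspose follows the same rule as `jax.lax.conv_transpose` with `padding="VALID"` (output = input × stride + max(kernel − stride, 0)); the helper names are hypothetical. `jax.lax.conv_transpose` runs a regular convolution over the input with `stride - 1` zeros inserted between elements (`lhs_dilation`), so the second helper shows how large that intermediate input becomes for the numbers in this report.

```python
def conv_transpose_valid_size(in_size: int, kernel: int, stride: int) -> int:
    """Spatial output size of a VALID transposed convolution.

    Assumption: matches jax.lax.conv_transpose(padding="VALID"), i.e.
    in_size * stride + max(kernel - stride, 0).
    """
    return in_size * stride + max(kernel - stride, 0)


def lhs_dilated_size(in_size: int, stride: int) -> int:
    """Input size after inserting (stride - 1) zeros between elements."""
    return (in_size - 1) * stride + 1


# Numbers from the example above: 8x8 input, 32x32 kernel, stride 16.
print(conv_transpose_valid_size(8, kernel=32, stride=16))  # 144
print(lhs_dilated_size(8, stride=16))                      # 113
```

Under this assumption the example should print an output shape of `(1, 144, 144, 32)`, computed from a zero-dilated 113×113 intermediate rather than the original 8×8 input.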
Output:
For comparison, in PyTorch:
Google Colab notebook replicating the test:
https://colab.research.google.com/drive/15YkOuK0EjqZdBNaXpF2wpYexGqtjZjLr
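To check whether the time goes into tracing, XLA compilation, or execution, JAX's ahead-of-time API (`jax.jit(fn).lower(*args).compile()`) can time each stage separately. This is a minimal sketch, not part of the original report; the helper `time_compile_and_run` is hypothetical. Passing `toy_model_apply` with `(params, None, x)` should time the Conv2DTranspose case, while the stand-in below uses a small matmul so it finishes quickly.

```python
import time

import jax
import jax.numpy as jnp


def time_compile_and_run(fn, *args):
    """Return (lower_s, compile_s, run_s) for jax.jit(fn) on the given args."""
    jitted = jax.jit(fn)

    t0 = time.perf_counter()
    lowered = jitted.lower(*args)   # trace the Python function and lower it
    t1 = time.perf_counter()
    compiled = lowered.compile()    # XLA compilation only
    t2 = time.perf_counter()
    out = compiled(*args)
    jax.block_until_ready(out)      # dispatch is async; wait for the result
    t3 = time.perf_counter()
    return t1 - t0, t2 - t1, t3 - t2


def stand_in(a):
    # Hypothetical small workload used only to exercise the helper.
    return jnp.tanh(a @ a.T)


if __name__ == "__main__":
    lower_s, compile_s, run_s = time_compile_and_run(stand_in, jnp.ones((64, 64)))
    print(f"lower: {lower_s:.4f}s  compile: {compile_s:.4f}s  run: {run_s:.4f}s")
```

If most of the time lands in `compile()`, the slowdown is in XLA rather than in Haiku's Python-side tracing.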