Description
- For a specific reason, I need to run some operations on a specific device while keeping the other operations runnable with the global distribution. In torch, we can call `.to(device_id)` to place operations on a specific device. Is this possible in Keras yet? Maybe using `keras.distribution.initialize`? (See the sketch after this list for what I mean.)
- Also, the current distribution API only supports `jax`. Any update on supporting `tensorflow` and `torch`?
- Is there any known bug or limitation in the Keras distribution API when using the JAX backend? I realize this might sound a bit vague, but in one of my setups training works fine on a single GPU, whereas multi-GPU or TPU training with JAX goes wrong. Interestingly, multi-GPU training with TensorFlow's distribution strategy runs without issues. In short, training with JAX across multiple devices seems to corrupt the model weights. Reproducing the issue would take some time, but I wanted to check whether there are any known issues related to this.
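To illustrate the first question, here is a minimal sketch of the kind of per-device placement I mean, written at the raw JAX level rather than through any Keras API (the array shape and device index are just placeholders):

```python
import jax
import jax.numpy as jnp

devices = jax.devices()

# Rough analogue of torch's `.to(device_id)`: commit an array to one device so
# that the operations consuming it run there, while everything else stays
# under the globally configured distribution.
x = jnp.ones((8, 128))                  # placeholder input
x_dev0 = jax.device_put(x, devices[0])  # pin x to the first device
y = jnp.tanh(x_dev0)                    # this op executes on devices[0]
```

What I am asking is whether Keras exposes an equivalent, so that this kind of placement also works inside a distributed Keras model rather than only in eager JAX code.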
For the third question, this is the device setup I use on TPU (the device listing is shown as comments):

```python
import jax
import keras

devices = jax.devices("tpu")
devices
# [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
#  TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
#  TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
#  TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
#  TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
#  TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
#  TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
#  TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

# Shard training across all eight TPU cores with the Keras distribution API.
data_parallel = keras.distribution.DataParallel(devices=devices)
keras.distribution.set_distribution(data_parallel)
```
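The rest follows the usual compile/fit path. A minimal sketch of the shape of the run where I see the weight corruption (the architecture and synthetic data below are placeholders, not my actual setup):

```python
import numpy as np
import keras

# Placeholder model and data; my real setup is larger, but the structure is the same.
model = keras.Sequential([
    keras.layers.Input(shape=(32,)),
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

x = np.random.random((1024, 32)).astype("float32")
y = np.random.random((1024, 1)).astype("float32")

# With the DataParallel distribution set above, each batch is split across the
# eight TPU cores; this multi-device run is where results diverge from the
# single-GPU run.
model.fit(x, y, batch_size=64, epochs=2)
```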