Running ops on a specific device (id) with a global distribution setting (multi-GPU/TPU) #21190

@innat

Description

  1. For a specific reason, I need to run some operations on a specific device while keeping the other operations runnable under the global distribution. In torch we can call .to(device_id) to place operations on a specific device. Is this possible in Keras yet? Maybe via keras.distribution.initialize? (A minimal sketch of what I mean follows right after this list.)

  2. Also, the current distribution API only supports JAX. Is there any update on supporting TensorFlow and torch?

  3. Is there any known bug or limitation in the Keras distribution API when using the JAX backend? I realize this might sound a bit vague, but in one of my setups training works fine on a single GPU. However, when I switch to multi-GPU or TPU training with JAX, things go wrong; interestingly, multi-GPU training with TensorFlow's distribution strategy runs without issues. In short, training with JAX across multiple devices seems to corrupt the model weights. Reproducing the issue will take some time, so I wanted to check whether there are any known issues related to this. My current multi-device setup is the snippet further below.
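
To make question 1 concrete, here is a minimal sketch of the behaviour I am after. Whether the keras.device() scope composes with a global distribution like this is an assumption on my part, not verified behaviour:

import keras
from keras import ops

# Global data-parallel distribution over all visible devices.
keras.distribution.set_distribution(keras.distribution.DataParallel())

x = keras.random.normal((8, 128))

# Most ops should follow the global distribution...
y = ops.matmul(x, ops.transpose(x))

# ...but this block should run on one specific device, similar to
# torch's .to(device_id). Assumption: keras.device() still works when a
# global distribution is set; that is exactly what I am asking about.
with keras.device("gpu:1"):
    z = ops.sum(y)

On the JAX backend I could presumably fall back to jax.device_put for individual arrays, but I would prefer to stay within the Keras API so the code keeps working across backends.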

# Setup used for the multi-device (TPU) runs described in question 3.
import jax
import keras

devices = jax.devices("tpu")
devices
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

data_parallel = keras.distribution.DataParallel(devices=devices)
keras.distribution.set_distribution(data_parallel)
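
For question 3, this is the kind of check I plan to use while preparing a reproduction: after setting the distribution as above, inspect the sharding of a freshly built weight. I am assuming that on the JAX backend a Keras variable's .value is a jax.Array exposing .sharding:

# Continuing from the setup above (JAX backend, DataParallel already set).
layer = keras.layers.Dense(4)
layer.build((None, 8))

# Under data parallelism every weight should be fully replicated across the
# eight TPU cores; anything else would hint at where the corruption starts.
kernel = layer.weights[0].value  # assumed to be a jax.Array on the JAX backend
print(kernel.sharding)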

Metadata

Labels

type:feature (The user is asking for a new feature.)
