Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow ShardedDeviceArrays to represent arbitrary data shardings. (#2142)
This change introduces ShardingSpec, a struct describing how an array should be sharded. This is integrated into ShardedDeviceArray to allow more flexible sharding. It supports partitioning (both "pmap-style", where an entire axis is decomposed into separate shards and doesn't appear in the on-device shape at all, and "sharded_jit-style", where an axis is chunked into shards but remains in the on-device shape) and replication. This removes the need for ChunkedDeviceArray, since a ShardedDeviceArray can now represent chunks. Here are pmap_benchmark times showing that the overall effect of this change neutral to positive (integer indexing is much faster!). **pmap_shard_args** ``` ---------Benchmark summary for pmap_shard_args--------- nargs nshards mean %std relative mean/baseline ------- --------- --------- --------- ---------- --------------- 10 8 0.041855 4.15223 1 1.01466 100 8 0.129884 4.85321 3.1032 0.988543 101 8 0.136347 6.20233 3.2576 0.967138 500 8 0.533207 3.6815 12.7394 1.0294 1000 8 1.10338 0.525193 26.362 0.960435 5000 8 5.33911 0 127.562 0.963319 100 2 0.0638619 10.7069 1.52579 1.0362 100 4 0.0868253 6.76701 2.07443 0.967323 100 8 0.128151 6.46004 3.06177 0.979742 100 100 1.22631 1.94885 29.299 1.00371 100 500 6.60746 0 157.865 0.956657 ``` **pmap_shard_outputs** ``` nouts nshards mean %std relative mean/baseline ------- --------- ---------- --------- ---------- --------------- 10 8 0.0664526 9.49251 1 0.938466 100 8 0.195711 2.19429 2.94512 1.04239 500 8 0.82577 0.330864 12.4265 0.994669 1000 8 1.68323 1.0516 25.3298 0.966915 5000 8 8.89032 0 133.784 0.998038 100 2 0.074806 10.1734 1.12571 0.980254 100 4 0.121334 5.76774 1.82588 1.02033 100 8 0.185253 5.45068 2.78775 1.01666 100 100 2.37076 0 35.6759 1.08629 100 500 17.0832 0 257.074 0.976879 ``` **ShardedDeviceArray_indexing** ``` indices_fn mean %std relative mean/baseline ------------------ ---------- ------- ---------- --------------- integer_indices 0.0603473 8.29159 1 0.359496 integer_2D_indices 18.0241 0 298.672 1.00583 ``` This is how I ran the benchmark: ``` TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py --baseline_dir=<results as of a3cc9a7> ```
- Loading branch information