From be479aed17f67f3ea27302aea53b76883d8f59b5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 13 Feb 2025 16:31:42 -0600 Subject: [PATCH 01/14] feat: use parameter shardings from XLA --- .github/workflows/CI.yml | 1 + src/Compiler.jl | 42 +++++++++++++++++++++++-------- src/ConcreteRArray.jl | 18 ++++++------- src/xla/Buffer.jl | 11 ++++++++ src/xla/LoadedExecutable.jl | 49 ++++++++++++++++++++---------------- src/xla/Sharding.jl | 50 ++++++++++++++++++++++++------------- test/sharding.jl | 13 ++++++---- 7 files changed, 118 insertions(+), 66 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 3751c30566..e4b04db5c6 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -157,6 +157,7 @@ jobs: env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager REACTANT_TEST_GROUP: ${{ matrix.test_group }} + XLA_FLAGS: "--xla_force_host_platform_device_count=8" - uses: julia-actions/julia-processcoverage@v1 if: steps.run_tests.outcome == 'success' - uses: codecov/codecov-action@v5 diff --git a/src/Compiler.jl b/src/Compiler.jl index adad293643..c7c0157766 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -986,7 +986,13 @@ The name is due to its similarity to the `flatten` function in `jax.tree_util.re The _linearized arguments_ do not directly refer to the are the arguments that have been flattened into a single list. """ function codegen_flatten!( - linear_args, seen_args, result_stores, is_sharded::Bool, mesh, client + linear_args, + seen_args, + result_stores, + is_sharded::Bool, + mesh, + linear_parameter_shardings, + client, ) flatten_names = Symbol[] flatten_code = Expr[] @@ -1019,34 +1025,38 @@ function codegen_flatten!( for p in path[3:end] flatcode = :(traced_getfield($flatcode, $(Meta.quot(p)))) end - push!(flatten_code, :($usbuf = $flatcode.data)) if is_sharded carg = inv_seen_args[arg] if Reactant.Sharding.is_sharded(carg) + push!(flatten_code, :($usbuf = $flatcode.data)) for j in 1:length(mesh) sbuf = Symbol(:sbuf_, i, "_", j) push!(flatten_names, sbuf) push!(flatten_code, :($sbuf = XLA.synced_buffer(getindex($usbuf, $j)))) end else - # Warn here first and then replicate the input across all devices on the - # mesh - @warn "Input $carg is not sharded, replicating across all devices. It \ - is recommended to replicate the input across all devices on the \ - mesh manually using `Reactant.Sharding.NamedSharding`" maxlog = 1 - buf = Symbol(:buf_, i) - push!(flatten_code, :($buf = XLA.synced_buffer(only($usbuf)))) + push!(flatten_code, :($usbuf = $flatcode)) + device_to_array_slices, _ = XLA.compute_array_indices_and_partition_spec( + linear_parameter_shardings[i], size(carg), mesh + ) for j in 1:length(mesh) + buf = Symbol(:buf_, i, :_, j) device_id = mesh.device_ids[j] + slice = device_to_array_slices[j] + push!( + flatten_code, + :($buf = XLA.synced_buffer(only($usbuf[$(slice)...].data))), + ) device_ordinal = XLA.device_ordinal(client, device_id) - sbuf = Symbol(:sbuf_, i, "_", j) + sbuf = Symbol(:sbuf_, i, :_, j) device = XLA.ClientGetAddressableDevice(client, device_ordinal) push!(flatten_names, sbuf) push!(flatten_code, :($sbuf = XLA.CopyBufferToDevice($buf, $device))) end end else + push!(flatten_code, :($usbuf = $flatcode.data)) sbuf = Symbol(:sbuf_, i) push!(flatten_names, sbuf) if arg isa TracedRArray || arg isa TracedRNumber @@ -1401,7 +1411,8 @@ function compile_xla(f, args; client=nothing, kwargs...) 
client, device, mod; - num_results=length(mlir_fn_res.linear_results), + num_outputs=length(mlir_fn_res.linear_results), + num_parameters=length(mlir_fn_res.linear_args), mlir_fn_res.is_sharded, mesh_ids, ) @@ -1428,6 +1439,14 @@ function compile(f, args; sync=false, kwargs...) result_stores = Dict{Tuple,Symbol}() path_to_shard_info = mlir_fn_res.is_sharded ? Dict{Tuple,Tuple}() : nothing + linear_parameter_shardings = if mlir_fn_res.is_sharded + # The sharding info here is exclusively used for the parameters that weren't sharded + # before hand + XLA.get_parameter_shardings(exec) + else + nothing + end + # generate Julia `Thunk` code flatten_arg_names, flatten_code = codegen_flatten!( linear_args, @@ -1435,6 +1454,7 @@ function compile(f, args; sync=false, kwargs...) result_stores, mlir_fn_res.is_sharded, mlir_fn_res.sharding_mesh, + linear_parameter_shardings, client, ) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index e5bc30e1ce..4ed43bed21 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -44,24 +44,20 @@ Base.isempty(x::WrappedConcreteRArray) = isempty(ancestor(x)) function Base.convert(::Type{<:Array}, X::ConcreteRArray{T,N}) where {T,N} data = Array{T,N}(undef, size(X)...) - XLA.await(X) if Sharding.is_sharded(X) - # TODO: We can we much more efficient here and only move data from the minimal - # slices that populates the array. + completed = Set{eltype(X.sharding.device_to_array_slices)}() for idx in 1:length(X.data) - buffer = X.data[idx].buffer - # We can't use a pointer to a subarray since BufferToHost expects the data to - # be contiguous. slice = X.sharding.device_to_array_slices[idx] - data_slice = data[slice...] - GC.@preserve data_slice buffer begin - XLA.BufferToHost(buffer, pointer(data_slice)) + if slice ∉ completed + push!(completed, slice) + else + continue end - data[slice...] = data_slice + data[slice...] = convert(Array{T}, X.data[idx]) end else - buf = only(X.data).buffer + buf = XLA.synced_buffer(only(X.data)) GC.@preserve data buf begin XLA.BufferToHost(buf, pointer(data)) end diff --git a/src/xla/Buffer.jl b/src/xla/Buffer.jl index 52a0718655..8f6bd4f033 100644 --- a/src/xla/Buffer.jl +++ b/src/xla/Buffer.jl @@ -70,6 +70,12 @@ function ArrayFromHostBuffer(client::Client, array::Array{T,N}, device) where {T return Buffer(buffer) end +function Base.convert(::Type{<:Array{T}}, buffer::Buffer) where {T} + arr = zeros(T, reverse(size(buffer))...) 
+ BufferToHost(buffer, arr) + return arr +end + function BufferToHost(buffer::Buffer, data) GC.@preserve buffer begin @ccall MLIR.API.mlir_c.BufferToHost( @@ -97,6 +103,11 @@ end const AsyncEmptyBuffer = AsyncBuffer(Buffer(C_NULL), nothing) +function Base.convert(::Type{<:Array{T}}, buffer::AsyncBuffer) where {T} + await(buffer) + return convert(Array{T}, buffer.buffer) +end + for op in (:(Base.ndims), :(Base.size), :device, :client) @eval $op(buffer::AsyncBuffer) = $op(buffer.buffer) end diff --git a/src/xla/LoadedExecutable.jl b/src/xla/LoadedExecutable.jl index 7ec4801ff7..5bd1321e9b 100644 --- a/src/xla/LoadedExecutable.jl +++ b/src/xla/LoadedExecutable.jl @@ -4,12 +4,15 @@ end mutable struct LoadedExecutable exec::Ptr{Cvoid} - num_results::Int64 + num_outputs::Int64 + num_parameters::Int64 is_sharded::Bool - function LoadedExecutable(exec::Ptr{Cvoid}, num_results::Int64, is_sharded::Bool) + function LoadedExecutable( + exec::Ptr{Cvoid}, num_outputs::Int64, num_parameters::Int64, is_sharded::Bool + ) @assert exec != C_NULL - return finalizer(free_exec, new(exec, num_results, is_sharded)) + return finalizer(free_exec, new(exec, num_outputs, num_parameters, is_sharded)) end end @@ -213,7 +216,8 @@ function Compile( is_sharded::Bool=false, mesh_ids::Vector{Int64}=Int64[], # mesh_shape::Vector{Int64}=Int64[], - num_results::Int64, + num_outputs::Int64, + num_parameters::Int64, ) device_id = is_sharded ? Int64(-1) : Int64(device_ordinal(client, device)) mesh_ids = Int64.(device_ordinal.((client,), mesh_ids)) @@ -223,31 +227,34 @@ function Compile( mod.module_::MLIR.API.MlirModule, device_id::Clong, is_sharded::Bool, - # mesh_shape::Ptr{Clong}, - # length(mesh_shape)::Clong, mesh_ids::Ptr{Clong}, length(mesh_ids)::Clong, CUDA_DATA_DIR[]::Cstring, )::Ptr{Cvoid} end - return LoadedExecutable(exec, num_results, is_sharded) + return LoadedExecutable(exec, num_outputs, num_parameters, is_sharded) end -function get_output_shardings(exec::LoadedExecutable) - exec.is_sharded || return OpSharding[] +for (jlop, xlaop, field) in ( + (:get_output_shardings, :PjRtLoadedExecutableGetOuputShardings, :num_outputs), + (:get_parameter_shardings, :PjRtLoadedExecutableGetParameterShardings, :num_parameters), +) + @eval function $(jlop)(exec::LoadedExecutable) + exec.is_sharded || return OpSharding[] - jl_op_shardings = [Ref{JLOpSharding}() for _ in 1:(exec.num_results)] - jl_op_shardings_ptr = [ - Base.unsafe_convert(Ptr{JLOpSharding}, sharding) for sharding in jl_op_shardings - ] + jl_op_shardings = [Ref{JLOpSharding}() for _ in 1:(exec.$(field))] + jl_op_shardings_ptr = [ + Base.unsafe_convert(Ptr{JLOpSharding}, sharding) for sharding in jl_op_shardings + ] - GC.@preserve jl_op_shardings begin - @ccall MLIR.API.mlir_c.PjRtLoadedExecutableGetOuputShardings( - exec.exec::Ptr{Cvoid}, - jl_op_shardings_ptr::Ptr{Ptr{JLOpSharding}}, - exec.num_results::Int32, - )::Cvoid - end + GC.@preserve jl_op_shardings begin + @ccall MLIR.API.mlir_c.$(xlaop)( + exec.exec::Ptr{Cvoid}, + jl_op_shardings_ptr::Ptr{Ptr{JLOpSharding}}, + exec.$(field)::Int32, + )::Cvoid + end - return map(OpSharding ∘ getindex, jl_op_shardings) + return map(OpSharding ∘ getindex, jl_op_shardings) + end end diff --git a/src/xla/Sharding.jl b/src/xla/Sharding.jl index 1dc20c4183..7f5f6a175b 100644 --- a/src/xla/Sharding.jl +++ b/src/xla/Sharding.jl @@ -77,14 +77,14 @@ function OpSharding(sharding::JLOpSharding) return OpSharding( int_to_op_sharding_type(sharding.type), - reverse(tile_dimensions), + tile_dimensions, layout_minor_to_major, 
sharding.replicate_on_last_tile_dim, - reverse(last_tile_dims), - reverse(tile_assignment_dimensions), - reverse(tile_assignment_devices), - reverse(iota_reshape_dims), - reverse(iota_transpose_perm), + last_tile_dims, + tile_assignment_dimensions, + tile_assignment_devices, + iota_reshape_dims, + iota_transpose_perm, sharding.is_shard_group, sharding.shard_group_id, int_to_shard_group_type(sharding.shard_group_type), @@ -153,21 +153,29 @@ function compute_array_indices_and_partition_spec( $(mesh.device_ids)" @assert isempty(sharding.tile_dimensions) "Tile dimensions are not supported yet! \ Open an issue with an MWE for this case." - @assert !sharding.replicate_on_last_tile_dim "Replication on the last tile \ - dimension is not supported yet! Open \ - an issue with an MWE for this case." - tile_assignment = reshape(device_list, sharding.tile_assignment_dimensions...) - tile_dimensions = div.(array_size, sharding.tile_assignment_dimensions) + tile_assignment = permutedims( + reshape(device_list, sharding.tile_assignment_dimensions...), + reverse(1:(length(sharding.tile_assignment_dimensions))), + ) + + if sharding.replicate_on_last_tile_dim + actual_tile_assignment_dimensions = size(tile_assignment)[2:end] + else + actual_tile_assignment_dimensions = size(tile_assignment) + end + tile_dimensions = div.(array_size, actual_tile_assignment_dimensions) mesh_devices = reshape([mesh.device_ids...], mesh.shape) # Match array dimensions to mesh axes by comparing device sequences used_axes = Set{Int}() partition_spec = ntuple(N) do dim - if dim <= length(sharding.tile_assignment_dimensions) && - sharding.tile_assignment_dimensions[dim] > 1 - tile_seq = __get_device_sequence(tile_assignment, dim) + if dim <= length(actual_tile_assignment_dimensions) && + actual_tile_assignment_dimensions[dim] > 1 + tile_seq = __get_device_sequence( + tile_assignment, dim + sharding.replicate_on_last_tile_dim + ) # For each unused mesh axis with matching size for (axis_idx, axis_name) in enumerate(mesh.axis_names) @@ -187,10 +195,16 @@ function compute_array_indices_and_partition_spec( device_to_array_indices = map(mesh.device_ids) do device_id tile_index = findfirst(==(device_id), tile_assignment) - @assert tile_index !== nothing "Device ID $device_id not found in tile assignment $tile_assignment" - tile_start = (tile_index.I .- 1) .* tile_dimensions .+ 1 - tile_end = tile_index.I .* tile_dimensions - ntuple(i -> tile_start[i]:tile_end[i], N) + @assert tile_index !== nothing "Device ID $device_id not found in tile \ + assignment $tile_assignment" + index_tup = if !sharding.replicate_on_last_tile_dim + Tuple(tile_index.I) + else + Tuple(tile_index.I[2:end]) + end + tile_start = (index_tup .- 1) .* tile_dimensions .+ 1 + tile_end = index_tup .* tile_dimensions + return ntuple(i -> tile_start[i]:tile_end[i], N) end return device_to_array_indices, partition_spec diff --git a/test/sharding.jl b/test/sharding.jl index 1afd03dec1..0d235f9f65 100644 --- a/test/sharding.jl +++ b/test/sharding.jl @@ -72,12 +72,15 @@ predict(samples, w1, w2) = sin.(w2 * (w1 * tanh.(samples))) samples_sharding = Sharding.NamedSharding(mesh, (nothing, "data")) w1_sharding = Sharding.NamedSharding(mesh, ("model", nothing)) - samples = ConcreteRArray(rand(Float32, 3, 12); sharding=samples_sharding) - w1 = ConcreteRArray(rand(Float32, 4, 3); sharding=w1_sharding) - w2 = ConcreteRArray(rand(Float32, 2, 4)) + samples = reshape(collect(Float32, 1:84), 7, 12) + w1 = reshape(collect(Float32, 1:28), 4, 7) + w2 = reshape(collect(Float32, 1:32), 8, 4) + + 
samples_ra = Reactant.to_rarray(samples; sharding=samples_sharding) + w1_ra = Reactant.to_rarray(w1; sharding=w1_sharding) + w2_ra = Reactant.to_rarray(w2) if !fake_run - @test Array(@jit(predict(samples, w1, w2))) ≈ - predict(Array(samples), Array(w1), Array(w2)) + @test Array(@jit(predict(samples_ra, w1_ra, w2_ra))) ≈ predict(samples, w1, w2) end end From f6f814da413bc150654235718f0a2385acc81324 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 14 Feb 2025 06:53:14 -0600 Subject: [PATCH 02/14] chore: bump jll --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b2f974b5f7..35a7ac0452 100644 --- a/Project.toml +++ b/Project.toml @@ -81,7 +81,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.5" -Reactant_jll = "0.0.66" +Reactant_jll = "0.0.67" Scratch = "1.2" Sockets = "1.10" SpecialFunctions = "2.4" From 9e2b1f101376996b88a10651b12cd30a67cc21a8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 14 Feb 2025 08:17:19 -0600 Subject: [PATCH 03/14] fix: handle dimensions --- deps/ReactantExtra/BUILD | 1 + src/xla/Sharding.jl | 38 +++++++++++++++++++++----------------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index c7ea6b18aa..3923e356cb 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -466,6 +466,7 @@ cc_library( "-Wl,-exported_symbol,_BufferNDimensions", "-Wl,-exported_symbol,_BufferPrimitiveType", "-Wl,-exported_symbol,_PjRtLoadedExecutableGetOuputShardings", +"-Wl,-exported_symbol,_PjRtLoadedExecutableGetParameterShardings", "-Wl,-exported_symbol,_ifrt_*", "-Wl,-exported_symbol,_reactant_*", ]}), diff --git a/src/xla/Sharding.jl b/src/xla/Sharding.jl index 7f5f6a175b..292bf80d7e 100644 --- a/src/xla/Sharding.jl +++ b/src/xla/Sharding.jl @@ -145,7 +145,6 @@ function compute_array_indices_and_partition_spec( ) elseif sharding.type == OpShardingType.Other # Other: Tiled sharding - # Reshape tile_assignment_devices into tile_assignment_dimensions device_list = generate_device_list(sharding) sorted_mesh_devices = sort(collect(Int64, mesh.device_ids)) @assert sort(device_list) == sorted_mesh_devices "Mismatched devices list: \ @@ -153,36 +152,41 @@ function compute_array_indices_and_partition_spec( $(mesh.device_ids)" @assert isempty(sharding.tile_dimensions) "Tile dimensions are not supported yet! \ Open an issue with an MWE for this case." + # Handle layout transformation + dims_order = if !isempty(sharding.layout_minor_to_major) + sharding.layout_minor_to_major + else + collect(1:length(sharding.tile_assignment_dimensions)) + end - tile_assignment = permutedims( - reshape(device_list, sharding.tile_assignment_dimensions...), - reverse(1:(length(sharding.tile_assignment_dimensions))), + # Reshape considering column-major order and layout + tile_assignment = reshape( + device_list, reverse(sharding.tile_assignment_dimensions)... 
) - if sharding.replicate_on_last_tile_dim - actual_tile_assignment_dimensions = size(tile_assignment)[2:end] - else - actual_tile_assignment_dimensions = size(tile_assignment) + # Apply layout transformation + if !isempty(dims_order) + tile_assignment = permutedims(tile_assignment, dims_order) end - tile_dimensions = div.(array_size, actual_tile_assignment_dimensions) + # Handle replication dimension + tile_dims = size(tile_assignment)[(1 + sharding.replicate_on_last_tile_dim):end] + + # Calculate tile sizes + tile_sizes = div.(array_size, tile_dims) mesh_devices = reshape([mesh.device_ids...], mesh.shape) - # Match array dimensions to mesh axes by comparing device sequences + # Match dimensions to mesh axes used_axes = Set{Int}() partition_spec = ntuple(N) do dim - if dim <= length(actual_tile_assignment_dimensions) && - actual_tile_assignment_dimensions[dim] > 1 + if dim <= length(tile_dims) && tile_dims[dim] > 1 tile_seq = __get_device_sequence( tile_assignment, dim + sharding.replicate_on_last_tile_dim ) - # For each unused mesh axis with matching size for (axis_idx, axis_name) in enumerate(mesh.axis_names) if axis_idx ∉ used_axes && size(mesh_devices, axis_idx) == length(tile_seq) mesh_seq = __get_device_sequence(mesh_devices, axis_idx) - - # Check if sequences match (allowing for reversal) if tile_seq == mesh_seq || tile_seq == reverse(mesh_seq) push!(used_axes, axis_idx) return axis_name @@ -202,8 +206,8 @@ function compute_array_indices_and_partition_spec( else Tuple(tile_index.I[2:end]) end - tile_start = (index_tup .- 1) .* tile_dimensions .+ 1 - tile_end = index_tup .* tile_dimensions + tile_start = (index_tup .- 1) .* tile_sizes .+ 1 + tile_end = index_tup .* tile_sizes return ntuple(i -> tile_start[i]:tile_end[i], N) end From 58c393f09d9e4a92abf4c270a55b45bf79c44a1f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 14 Feb 2025 10:25:00 -0600 Subject: [PATCH 04/14] feat: access hlo module from julia --- deps/ReactantExtra/API.cpp | 28 +++++++++++++++++++++++ deps/ReactantExtra/BUILD | 5 +++++ docs/src/api/api.md | 1 + src/Compiler.jl | 44 +++++++++++++++++++++++++++++++++++-- src/Reactant.jl | 18 +++++++++++++-- src/xla/HloModule.jl | 20 +++++++++++++++++ src/xla/LoadedExecutable.jl | 28 ++++++++++++++++++++++- src/xla/XLA.jl | 1 + 8 files changed, 140 insertions(+), 5 deletions(-) create mode 100644 src/xla/HloModule.jl diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index adf50ef80d..47a59ec56c 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -1035,6 +1035,14 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len, } } +extern "C" int PjRtLoadedExecutableNumReplicas(PjRtLoadedExecutable *exec) { + return exec->num_replicas(); +} + +extern "C" int PjRtLoadedExecutableNumPartitions(PjRtLoadedExecutable *exec) { + return exec->num_partitions(); +} + void prepareRegistry(mlir::DialectRegistry ®istry); extern "C" void RegisterDialects(MlirContext cctx) { @@ -1356,3 +1364,23 @@ ifrt_CopyArrayToHostBuffer(HeldValue> *array, return new FutureType( (*array)->CopyToHostBuffer(data, std::nullopt, semantics)); } + +extern "C" void +PjRtLoadedExecutableGetHloModules(xla::PjRtLoadedExecutable *exec, + void **hlo_modules, int32_t *nmodules) { + auto hlo_modules_vec = MyValueOrThrow(exec->GetHloModules()); + *nmodules = hlo_modules_vec.size(); + for (int i = 0; i < *nmodules; i++) { + hlo_modules[i] = reactant::capture(hlo_modules_vec[i]); + } +} + +extern "C" const char * +HloModuleToString(HeldValue> *hlo_module) { + 
return cstr_from_string(hlo_module->obj()->ToString()); +} + +extern "C" void +FreeHloModule(HeldValue> *hlo_module) { + delete hlo_module; +} diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 3923e356cb..551d93cd43 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -467,6 +467,11 @@ cc_library( "-Wl,-exported_symbol,_BufferPrimitiveType", "-Wl,-exported_symbol,_PjRtLoadedExecutableGetOuputShardings", "-Wl,-exported_symbol,_PjRtLoadedExecutableGetParameterShardings", +"-Wl,-exported_symbol,_PjRtLoadedExecutableGetHloModules", +"-Wl,-exported_symbol,_HloModuleToString", +"-Wl,-exported_symbol,_FreeHloModule", +"-Wl,-exported_symbol,_PjRtLoadedExecutableNumReplicas", +"-Wl,-exported_symbol,_PjRtLoadedExecutableNumPartitions", "-Wl,-exported_symbol,_ifrt_*", "-Wl,-exported_symbol,_reactant_*", ]}), diff --git a/docs/src/api/api.md b/docs/src/api/api.md index ce61e62322..d8746bddde 100644 --- a/docs/src/api/api.md +++ b/docs/src/api/api.md @@ -26,6 +26,7 @@ within_compile ```@docs @code_hlo @code_mhlo +@code_xla ``` ## Profile XLA diff --git a/src/Compiler.jl b/src/Compiler.jl index c7c0157766..123785e4fd 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -883,8 +883,48 @@ macro code_mhlo(args...) compile_expr, (; compiled) = compile_call_expr( __module__, compile_xla, default_options, args... ) - return esc(:($(compile_expr); - $(first)($(compiled)))) + #! format: off + return esc( + :( + $(compile_expr); + $(first)($(compiled)) + ) + ) + #! format: on +end + +""" + @code_xla [optimize = ...] [no_nan = ] f(args...) + +Similar to `@code_hlo`, but prints the HLO module. +""" +macro code_xla(args...) + default_options = Dict{Symbol,Any}( + :optimize => true, :no_nan => false, :client => nothing + ) + compile_expr, (; compiled) = compile_call_expr( + __module__, compile_xla, default_options, args... + ) + #! format: off + return esc( + :( + $(compile_expr); + exec = $(compiled)[2]; + hlo_modules = $(XLA.get_hlo_modules)(exec); + if length(hlo_modules) == 1 + hlo_module = only(hlo_modules) + println(hlo_module) + else + println("HLO modules:") + for (i, hlo_module) in enumerate(hlo_modules) + println("Partition $i:") + println(hlo_module) + println() + end + end + ) + ) + #! 
format: on end """ diff --git a/src/Reactant.jl b/src/Reactant.jl index cef6bc9331..a83152d570 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -149,9 +149,23 @@ function Enzyme.make_zero( end using .Compiler: - @compile, @code_hlo, @code_mhlo, @jit, traced_getfield, create_result, compile + @compile, + @code_hlo, + @code_mhlo, + @jit, + @code_xla, + traced_getfield, + create_result, + compile export ConcreteRArray, - ConcreteRNumber, @compile, @code_hlo, @code_mhlo, @jit, @trace, within_compile + ConcreteRNumber, + @compile, + @code_hlo, + @code_mhlo, + @code_xla, + @jit, + @trace, + within_compile const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}() diff --git a/src/xla/HloModule.jl b/src/xla/HloModule.jl new file mode 100644 index 0000000000..e9f2eb12e0 --- /dev/null +++ b/src/xla/HloModule.jl @@ -0,0 +1,20 @@ +mutable struct HloModule + ptr::Ptr{Cvoid} + + function HloModule(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + return finalizer(free_hlo_module, new(ptr)) + end +end + +function free_hlo_module(hlo_module) + @ccall MLIR.API.mlir_c.FreeHloModule(hlo_module.ptr::Ptr{Cvoid})::Cvoid +end + +function Base.show(io::IO, hlo_module::HloModule) + GC.@preserve hlo_module begin + str = @ccall MLIR.API.mlir_c.HloModuleToString(hlo_module.ptr::Ptr{Cvoid})::Cstring + end + print(io, unsafe_string(str)) + return nothing +end diff --git a/src/xla/LoadedExecutable.jl b/src/xla/LoadedExecutable.jl index 5bd1321e9b..4e23ee63dc 100644 --- a/src/xla/LoadedExecutable.jl +++ b/src/xla/LoadedExecutable.jl @@ -16,6 +16,17 @@ mutable struct LoadedExecutable end end +for (jlop, xlaop) in ( + (:num_replicas, :PjRtLoadedExecutableNumReplicas), + (:num_partitions, :PjRtLoadedExecutableNumPartitions), +) + @eval function $(jlop)(exec::LoadedExecutable) + GC.@preserve exec begin + return @ccall MLIR.API.mlir_c.$(xlaop)(exec.exec::Ptr{Cvoid})::Cint + end + end +end + function client(exec::LoadedExecutable) GC.@preserve exec begin return Client( @@ -215,7 +226,6 @@ function Compile( mod::MLIR.IR.Module; is_sharded::Bool=false, mesh_ids::Vector{Int64}=Int64[], - # mesh_shape::Vector{Int64}=Int64[], num_outputs::Int64, num_parameters::Int64, ) @@ -258,3 +268,19 @@ for (jlop, xlaop, field) in ( return map(OpSharding ∘ getindex, jl_op_shardings) end end + +function get_hlo_modules(exec::LoadedExecutable) + # If we had compiled with MPMD then we would need all the partitions to get hlo_modules + # but if we used SPMD we get only 1 module. 
To be safe we allocate for all the modules + # and use the ones assigned to by XLA + hlo_modules = Ref{NTuple{Int64(num_partitions(exec)),Ptr{Cvoid}}}() + nmodules = Ref{Int32}(0) + GC.@preserve exec hlo_modules begin + @ccall MLIR.API.mlir_c.PjRtLoadedExecutableGetHloModules( + exec.exec::Ptr{Cvoid}, + hlo_modules::Ptr{Ptr{Cvoid}}, + nmodules::Ptr{Cint}, + )::Cvoid + end + return map(HloModule, hlo_modules[][1:Int(nmodules[])]) +end diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index 2fe4cc9f46..3b719f31dc 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -28,6 +28,7 @@ include("Future.jl") include("Buffer.jl") include("Stats.jl") include("Utils.jl") +include("HloModule.jl") const backends = Dict{String,Client}() const default_backend = Ref{Client}() From b40e4b034dd1f07902b89ff78d7d19e19d53b705 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Feb 2025 08:06:12 -0600 Subject: [PATCH 05/14] fix: roundtrip from OpSharding --- src/Sharding.jl | 136 ++++++++++++++++++++++-------------- src/xla/LoadedExecutable.jl | 9 +-- src/xla/Sharding.jl | 17 ++--- test/sharding.jl | 9 +++ 4 files changed, 108 insertions(+), 63 deletions(-) diff --git a/src/Sharding.jl b/src/Sharding.jl index 2c36ae5f52..0a139d0d0d 100644 --- a/src/Sharding.jl +++ b/src/Sharding.jl @@ -90,67 +90,101 @@ struct NamedSharding{D1,D2,P<:Tuple,D3} <: AbstractSharding end end -function (sharding::NamedSharding)(client::XLA.Client, device, x::Number) +function named_sharding_to_opsharding(sharding::NamedSharding, shape::Dims) (; mesh, partition_spec) = sharding - @assert length(partition_spec) == 0 + @assert length(partition_spec) == length(shape) - data = map(mesh.device_ids) do device_id - return XLA.AsyncBuffer( - XLA.ArrayFromHostBuffer(client, fill(x), XLA.device_ordinal(client, device_id)), - nothing, + # Fast Path for replicating the input across all devices + if all(Base.Fix2(===, nothing), partition_spec) + return XLA.OpSharding( + XLA.OpShardingType.Replicated, + Int64[], + Int64[], + false, + XLA.OpShardingType.T[], + Int64[], + Int64[], + Int64[], + Int32[], + false, + -1, + XLA.ShardGroupType.As, ) end - return data, ShardInfo(sharding, ntuple(Returns(()), length(mesh))) -end -function (sharding::NamedSharding)(client::XLA.Client, ::Nothing, x::AbstractArray) - (; mesh, partition_spec) = sharding - @assert length(partition_spec) == ndims(x) + tile_dims = map(Base.Fix1(size, mesh), partition_spec) + num_tiles_before_replication = prod(tile_dims) + total_devices = length(mesh.device_ids) + replication_factor = cld(total_devices, num_tiles_before_replication) + replicate_on_last_tile_dim = replication_factor > 1 + replicate_on_last_tile_dim && (tile_dims = (replication_factor, tile_dims...)) + + # Create tile assignment array + tile_assignment = Array{Int}(undef, tile_dims...) + devices = reshape(collect(mesh.device_ids), size(mesh)) + + # Find axes not used in partition_spec for replication + unused_axes = filter(axis -> axis ∉ partition_spec, mesh.axis_names) + unused_dims = map(axis -> size(mesh, axis), unused_axes) + replication_indices = CartesianIndices(Tuple(unused_dims)) + + # Fill tile assignment array + for indices in CartesianIndices(tile_assignment) + index_tuple = Tuple(indices) + actual_indices = replicate_on_last_tile_dim ? index_tuple[2:end] : index_tuple + repl_idx = replicate_on_last_tile_dim ? 
index_tuple[1] : 1 + + # Initialize device index array + device_index = ones(Int, ndims(mesh)) + + # Map partition dimensions to device indices + for (tile_idx, (pspec, dim_size)) in enumerate(zip(partition_spec, shape)) + if pspec !== nothing + mesh_axis = findfirst(==(Symbol(pspec)), mesh.axis_names) + if mesh_axis !== nothing + device_index[mesh_axis] = actual_indices[tile_idx] + end + end + end - # Fast Path for replicating the input across all devices - if all(Base.Fix2(===, nothing), partition_spec) - data = map(mesh.device_ids) do device_id - return XLA.AsyncBuffer( - XLA.ArrayFromHostBuffer( - client, - x, - XLA.ClientGetAddressableDevice( - client, XLA.device_ordinal(client, device_id) - ), - ), - nothing, - ) + # Handle replication for unused axes + for (i, axis) in enumerate(unused_axes) + axis_idx = findfirst(==(axis), mesh.axis_names) + if axis_idx !== nothing + device_index[axis_idx] = replication_indices[repl_idx][i] + end end - device_to_array_slices = ntuple( - Returns(ntuple(i -> 1:size(x, i), ndims(x))), length(mesh) - ) - return data, ShardInfo(sharding, device_to_array_slices) - end - ndevices = map(Base.Fix1(size, mesh), partition_spec) - for (sz, ndevice) in zip(size(x), ndevices) - @assert sz % ndevice == 0 "$(size(x)) must be divisible by $(ndevices)" + # Assign device to tile + tile_assignment[indices] = devices[device_index...] end - strides = size(x) .÷ ndevices - slices = Array{NTuple{ndims(x),UnitRange{Int64}},ndims(x)}(undef, ndevices) - for idx in CartesianIndices(slices) - idx_tup = Tuple(idx) - slices[idx] = Tuple( - (i1 + 1):i2 for (i1, i2) in zip((idx_tup .- 1) .* strides, idx_tup .* strides) - ) - end + return XLA.OpSharding( + XLA.OpShardingType.Other, + Int64[], + Int64[], + replicate_on_last_tile_dim, + XLA.OpShardingType.T[], + collect(Int64, size(tile_assignment)), + vec(tile_assignment), + Int64[], + Int32[], + false, + -1, + XLA.ShardGroupType.As, + ) +end - device_to_array_slices = Array{eltype(slices),ndims(mesh)}(undef, size(mesh)) - for idx in CartesianIndices(device_to_array_slices) - idx_tup = Tuple(idx) - slice_idx = ones(Int, ndims(slices)) - for (axis_name, idxᵢ) in zip(mesh.axis_names, idx_tup) - dim = findfirst(==(axis_name), sharding.partition_spec) - dim !== nothing && (slice_idx[dim] = idxᵢ) - end - device_to_array_slices[idx] = slices[CartesianIndex(slice_idx...)] - end +function (sharding::NamedSharding)( + client::XLA.Client, ::Nothing, x::Union{AbstractArray,Number} +) + (; mesh, partition_spec) = sharding + @assert length(partition_spec) == ndims(x) + + opsharding = named_sharding_to_opsharding(sharding, size(x)) + device_to_array_slices, _ = XLA.compute_array_indices_and_partition_spec( + opsharding, size(x), mesh + ) data = ntuple(length(mesh)) do i XLA.AsyncBuffer( @@ -165,7 +199,7 @@ function (sharding::NamedSharding)(client::XLA.Client, ::Nothing, x::AbstractArr ) end - return data, ShardInfo(sharding, Tuple(vec(device_to_array_slices))) + return data, ShardInfo(sharding, device_to_array_slices) end # Given Sharding + Array --> ShardInfo diff --git a/src/xla/LoadedExecutable.jl b/src/xla/LoadedExecutable.jl index 4e23ee63dc..261f0fe549 100644 --- a/src/xla/LoadedExecutable.jl +++ b/src/xla/LoadedExecutable.jl @@ -213,9 +213,12 @@ end ntuple(Val(K)) do i Base.@_inline_meta idx = (i - 1) * n_outs + j - return AsyncBuffer( + zzz = AsyncBuffer( Buffer(outputs[idx]), future ? 
Future(future_res[idx]) : nothing ) + @show convert(Array{Float32}, zzz) + @show i, j, zzz + return zzz end end end @@ -277,9 +280,7 @@ function get_hlo_modules(exec::LoadedExecutable) nmodules = Ref{Int32}(0) GC.@preserve exec hlo_modules begin @ccall MLIR.API.mlir_c.PjRtLoadedExecutableGetHloModules( - exec.exec::Ptr{Cvoid}, - hlo_modules::Ptr{Ptr{Cvoid}}, - nmodules::Ptr{Cint}, + exec.exec::Ptr{Cvoid}, hlo_modules::Ptr{Ptr{Cvoid}}, nmodules::Ptr{Cint} )::Cvoid end return map(HloModule, hlo_modules[][1:Int(nmodules[])]) diff --git a/src/xla/Sharding.jl b/src/xla/Sharding.jl index 292bf80d7e..7495dd6daf 100644 --- a/src/xla/Sharding.jl +++ b/src/xla/Sharding.jl @@ -112,12 +112,14 @@ function generate_device_list(sharding::OpSharding) # Generate device IDs using iota num_devices = prod(sharding.iota_reshape_dims) iota_devices = collect( - Int64, reshape(0:(num_devices - 1), sharding.iota_reshape_dims...) + Int64, reshape(0:(num_devices - 1), reverse(sharding.iota_reshape_dims)...) ) # Permute the iota array if iota_transpose_perm is provided if !isempty(sharding.iota_transpose_perm) - iota_devices = permutedims(iota_devices, Tuple(sharding.iota_transpose_perm)) + iota_devices = permutedims( + iota_devices, reverse(Tuple(sharding.iota_transpose_perm)) + ) end # Flatten the permuted iota array to get tile_assignment_devices @@ -199,13 +201,12 @@ function compute_array_indices_and_partition_spec( device_to_array_indices = map(mesh.device_ids) do device_id tile_index = findfirst(==(device_id), tile_assignment) - @assert tile_index !== nothing "Device ID $device_id not found in tile \ - assignment $tile_assignment" - index_tup = if !sharding.replicate_on_last_tile_dim - Tuple(tile_index.I) - else - Tuple(tile_index.I[2:end]) + if tile_index === nothing + error("Device ID $device_id not found in tile assignment \ + $tile_assignment") end + index_tup = Tuple(tile_index.I) + sharding.replicate_on_last_tile_dim && (index_tup = index_tup[2:end]) tile_start = (index_tup .- 1) .* tile_sizes .+ 1 tile_end = index_tup .* tile_sizes return ntuple(i -> tile_start[i]:tile_end[i], N) diff --git a/test/sharding.jl b/test/sharding.jl index 0d235f9f65..6f6b498027 100644 --- a/test/sharding.jl +++ b/test/sharding.jl @@ -58,6 +58,8 @@ end predict(samples, w1, w2) = sin.(w2 * (w1 * tanh.(samples))) +fn_test2(x) = x .+ x' + @testset "Sharding Across 8 Devices" begin if length(addressable_devices) ≥ 8 mesh = Sharding.Mesh(reshape(collect(Int64, 0:7), (4, 2)), ("data", "model")) @@ -69,6 +71,13 @@ predict(samples, w1, w2) = sin.(w2 * (w1 * tanh.(samples))) fake_run = true end + x = reshape(collect(Float32, 1:16), 4, 4) + x_ra = Reactant.to_rarray(x; sharding=Sharding.NamedSharding(mesh, ("data", "model"))) + + if !fake_run + @test Array(@jit fn_test2(x_ra)) ≈ fn_test2(x) + end + samples_sharding = Sharding.NamedSharding(mesh, (nothing, "data")) w1_sharding = Sharding.NamedSharding(mesh, ("model", nothing)) From d887922a63a50253823ba1c8007d93ccb947f855 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Feb 2025 08:13:13 -0600 Subject: [PATCH 06/14] fix: minor patch --- deps/ReactantExtra/API.cpp | 7 ++++--- src/Sharding.jl | 3 +++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 4d09d10af4..09c66bbe6d 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -1045,8 +1045,9 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len, } // Handle returned futures - *futures = returned_futures.has_value(); 
- if (*futures) { + auto future_val = returned_future.has_value(); + *futures = future_val; + if (future_val) { if (returned_futures->size() != num_mesh_ids) { ReactantThrowError((" returned_futures->size()=" + std::to_string(returned_futures->size()) + @@ -1061,7 +1062,7 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len, int64_t mesh_id = mesh_ids[device_idx]; for (int result_idx = 0; result_idx < num_results; ++result_idx) { int flat_index = mesh_id * num_results + result_idx; - if (*futures) { + if (future_val) { future_results[flat_index] = new FutureType(std::move((*returned_futures)[mesh_id])); } diff --git a/src/Sharding.jl b/src/Sharding.jl index 3b72b52b03..b650992f91 100644 --- a/src/Sharding.jl +++ b/src/Sharding.jl @@ -65,6 +65,9 @@ function (::NoSharding)(client::XLA.Client, device, x::Union{AbstractArray,Numbe return (buffer,), ShardInfo(NoSharding(), nothing) end +# TODO: At the core create a DimSharding Type. We can convert the other sharding types to +# this type + # XXX: multiple axes partitioning -- supported by shardy (not in Jax I think) struct NamedSharding{D1,D2,P<:Tuple,D3} <: AbstractSharding mesh::Mesh{D1,D2} From cbc550d310ceaa219d5293d954a8a1b12554539f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Feb 2025 08:15:41 -0600 Subject: [PATCH 07/14] chore: bump jll --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4b76884b42..4c6062e94c 100644 --- a/Project.toml +++ b/Project.toml @@ -81,7 +81,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.5" -Reactant_jll = "0.0.67" +Reactant_jll = "0.0.68" Scratch = "1.2" Sockets = "1.10" SpecialFunctions = "2.4" From b3288dedbf34d7a43e23aa66acb5f67369852097 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Feb 2025 13:23:16 -0600 Subject: [PATCH 08/14] fix: more fixes --- deps/ReactantExtra/API.cpp | 6 +++--- src/Sharding.jl | 4 ++++ src/xla/LoadedExecutable.jl | 5 +---- src/xla/Sharding.jl | 19 +++++-------------- 4 files changed, 13 insertions(+), 21 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 09c66bbe6d..304c05458e 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -1045,7 +1045,7 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len, } // Handle returned futures - auto future_val = returned_future.has_value(); + auto future_val = returned_futures.has_value(); *futures = future_val; if (future_val) { if (returned_futures->size() != num_mesh_ids) { @@ -1062,11 +1062,11 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len, int64_t mesh_id = mesh_ids[device_idx]; for (int result_idx = 0; result_idx < num_results; ++result_idx) { int flat_index = mesh_id * num_results + result_idx; + op_results[flat_index] = results[mesh_id][result_idx].release(); if (future_val) { future_results[flat_index] = - new FutureType(std::move((*returned_futures)[mesh_id])); + new FutureType((*returned_futures)[mesh_id]); } - op_results[flat_index] = results[mesh_id][result_idx].release(); } } } diff --git a/src/Sharding.jl b/src/Sharding.jl index b650992f91..9ff764feb5 100644 --- a/src/Sharding.jl +++ b/src/Sharding.jl @@ -163,6 +163,10 @@ function named_sharding_to_opsharding(sharding::NamedSharding, shape::Dims) tile_assignment[indices] = devices[device_index...] 
end + tile_assignment = permutedims( + tile_assignment, reverse(collect(Int64, 1:ndims(tile_assignment))) + ) + return XLA.OpSharding( XLA.OpShardingType.Other, Int64[], diff --git a/src/xla/LoadedExecutable.jl b/src/xla/LoadedExecutable.jl index 261f0fe549..14a6de7569 100644 --- a/src/xla/LoadedExecutable.jl +++ b/src/xla/LoadedExecutable.jl @@ -213,12 +213,9 @@ end ntuple(Val(K)) do i Base.@_inline_meta idx = (i - 1) * n_outs + j - zzz = AsyncBuffer( + return AsyncBuffer( Buffer(outputs[idx]), future ? Future(future_res[idx]) : nothing ) - @show convert(Array{Float32}, zzz) - @show i, j, zzz - return zzz end end end diff --git a/src/xla/Sharding.jl b/src/xla/Sharding.jl index 7495dd6daf..6b40b2f8eb 100644 --- a/src/xla/Sharding.jl +++ b/src/xla/Sharding.jl @@ -154,23 +154,14 @@ function compute_array_indices_and_partition_spec( $(mesh.device_ids)" @assert isempty(sharding.tile_dimensions) "Tile dimensions are not supported yet! \ Open an issue with an MWE for this case." - # Handle layout transformation - dims_order = if !isempty(sharding.layout_minor_to_major) - sharding.layout_minor_to_major - else - collect(1:length(sharding.tile_assignment_dimensions)) - end + @assert isempty(sharding.layout_minor_to_major) "Layout transformation is not \ + supported yet!" - # Reshape considering column-major order and layout - tile_assignment = reshape( - device_list, reverse(sharding.tile_assignment_dimensions)... + tile_assignment = reshape(device_list, sharding.tile_assignment_dimensions...) + tile_assignment = permutedims( + tile_assignment, reverse(collect(Int64, 1:ndims(tile_assignment))) ) - # Apply layout transformation - if !isempty(dims_order) - tile_assignment = permutedims(tile_assignment, dims_order) - end - # Handle replication dimension tile_dims = size(tile_assignment)[(1 + sharding.replicate_on_last_tile_dim):end] From 7db9eaefb3ea2d1518671968ffc0e4288be642a5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Feb 2025 22:25:50 -0600 Subject: [PATCH 09/14] fix: more progress towards correct results --- deps/ReactantExtra/API.cpp | 8 +- src/Compiler.jl | 47 +++++---- src/Ops.jl | 2 +- src/Sharding.jl | 194 ++++++++++++++++++------------------ src/xla/LoadedExecutable.jl | 4 +- src/xla/Sharding.jl | 177 +++++++++++++++++++++++--------- test/sharding.jl | 37 +++++-- 7 files changed, 281 insertions(+), 188 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 304c05458e..bd114c376b 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -1003,11 +1003,11 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len, .c_str()); } - argument_handles[mesh_id].reserve(num_args); + argument_handles[device_idx].reserve(num_args); for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) { // Assuming op_args is a flat array of size num_devices * num_args // where arguments for each device are contiguous - argument_handles[mesh_id].push_back( + argument_handles[device_idx].push_back( op_args[mesh_id * num_args + arg_idx]); } } @@ -1062,10 +1062,10 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len, int64_t mesh_id = mesh_ids[device_idx]; for (int result_idx = 0; result_idx < num_results; ++result_idx) { int flat_index = mesh_id * num_results + result_idx; - op_results[flat_index] = results[mesh_id][result_idx].release(); + op_results[flat_index] = results[device_idx][result_idx].release(); if (future_val) { future_results[flat_index] = - new FutureType((*returned_futures)[mesh_id]); + new 
FutureType((*returned_futures)[device_idx]); } } } diff --git a/src/Compiler.jl b/src/Compiler.jl index 3fbd2e04fd..f21e689dfd 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1076,7 +1076,21 @@ function codegen_flatten!( if is_sharded carg = inv_seen_args[arg] + condensed_op_sharding = Reactant.Sharding.XLA.CondensedOpSharding( + linear_parameter_shardings[i] + ) if Reactant.Sharding.is_sharded(carg) + # Check if the sharding provided is same as the one we have + arg_condensed_op_sharding = Reactant.Sharding.XLA.CondensedOpSharding( + Reactant.Sharding.ShardingWithShape(carg.sharding, size(carg)) + ) + + # XXX: Change to error + if arg_condensed_op_sharding != condensed_op_sharding + @warn "Sharding provided by the user ($arg_condensed_op_sharding) does not match the sharding computed by XLA ($condensed_op_sharding). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE." + end + # @assert arg_condensed_op_sharding == condensed_op_sharding "Sharding provided by the user ($arg_condensed_op_sharding) does not match the sharding computed by XLA ($condensed_op_sharding). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE." + push!(flatten_code, :($usbuf = $flatcode.data)) for j in 1:length(mesh) sbuf = Symbol(:sbuf_, i, "_", j) @@ -1085,12 +1099,13 @@ function codegen_flatten!( end else push!(flatten_code, :($usbuf = $flatcode)) - device_to_array_slices, _ = XLA.compute_array_indices_and_partition_spec( - linear_parameter_shardings[i], size(carg), mesh + device_to_array_slices = XLA.sharding_to_concrete_array_indices( + condensed_op_sharding, size(carg), mesh ) + device_ids = vec(mesh) for j in 1:length(mesh) buf = Symbol(:buf_, i, :_, j) - device_id = mesh.device_ids[j] + device_id = device_ids[j] slice = device_to_array_slices[j] push!( flatten_code, @@ -1449,12 +1464,9 @@ function compile_xla(f, args; client=nothing, kwargs...) ) # compile MLIR module to XLA executable + device_ids = mlir_fn_res.is_sharded ? vec(mlir_fn_res.sharding_mesh) : Int64[] mlir_fn_res.is_sharded && (device = nothing) - mesh_ids = if mlir_fn_res.is_sharded - collect(Int64, mlir_fn_res.sharding_mesh.device_ids) - else - Int64[] - end + exec = XLA.Compile( client, device, @@ -1462,7 +1474,7 @@ function compile_xla(f, args; client=nothing, kwargs...) num_outputs=length(mlir_fn_res.linear_results), num_parameters=length(mlir_fn_res.linear_args), mlir_fn_res.is_sharded, - mesh_ids, + device_ids, ) return mod, exec, mlir_fn_res, device, client @@ -1487,14 +1499,6 @@ function compile(f, args; sync=false, kwargs...) result_stores = Dict{Tuple,Symbol}() path_to_shard_info = mlir_fn_res.is_sharded ? Dict{Tuple,Tuple}() : nothing - linear_parameter_shardings = if mlir_fn_res.is_sharded - # The sharding info here is exclusively used for the parameters that weren't sharded - # before hand - XLA.get_parameter_shardings(exec) - else - nothing - end - # generate Julia `Thunk` code flatten_arg_names, flatten_code = codegen_flatten!( linear_args, @@ -1502,7 +1506,7 @@ function compile(f, args; sync=false, kwargs...) result_stores, mlir_fn_res.is_sharded, mlir_fn_res.sharding_mesh, - linear_parameter_shardings, + XLA.get_parameter_shardings(exec), client, ) @@ -1513,15 +1517,10 @@ function compile(f, args; sync=false, kwargs...) 
donated_args_mask, length(linear_results), mlir_fn_res.is_sharded, - if mlir_fn_res.is_sharded - collect(Int64, mlir_fn_res.sharding_mesh.device_ids) - else - Int64[] - end, + mlir_fn_res.is_sharded ? vec(mlir_fn_res.sharding_mesh) : Int64[], ) linear_result_shard_info = if mlir_fn_res.is_sharded - # Generate a tuple of DeviceToArraySlices and PartitionSpecs output_shardings = XLA.get_output_shardings(exec) XLA.compute_array_indices_and_partition_spec.( output_shardings, diff --git a/src/Ops.jl b/src/Ops.jl index da79daab72..803732f62f 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2143,7 +2143,7 @@ end return mesh( mod, [k => Int64(v) for (k, v) in zip(m.axis_names, size(m))], - collect(Int64, m.device_ids); + vec(m); location, ) end diff --git a/src/Sharding.jl b/src/Sharding.jl index 9ff764feb5..cccd4e3b9a 100644 --- a/src/Sharding.jl +++ b/src/Sharding.jl @@ -2,11 +2,6 @@ module Sharding using ..Reactant: Reactant, XLA -# NOTE: PjRt doesn't provide a native sharding mechanism, so this file implements sharding -# at the julia level. With our migration to IFRt, we should be able to rewrite this -# logic to directly use the sharded arrays from IFRt. This would also simplify our -# logic of storing multiple arrays in ConcreteRArray struct - struct Mesh{D,ND} device_ids::NTuple{ND,Int} shape::Dims{D} @@ -36,6 +31,15 @@ struct Mesh{D,ND} end end +Base.vec(mesh::Mesh) = vec(device_ids(mesh)) + +function device_ids(mesh::Mesh) + # XXX: Do we need to permute the device ids? + return permutedims( + reshape(collect(Int64, mesh.device_ids), size(mesh)...), reverse(1:ndims(mesh)) + ) +end + Base.length(::Mesh{D,ND}) where {D,ND} = ND Base.ndims(::Mesh{D}) where {D} = D @@ -65,8 +69,8 @@ function (::NoSharding)(client::XLA.Client, device, x::Union{AbstractArray,Numbe return (buffer,), ShardInfo(NoSharding(), nothing) end -# TODO: At the core create a DimSharding Type. We can convert the other sharding types to -# this type +# TODO: At the core we should have an HloSharding Type that doesn't need to store the +# partition spec and other details # XXX: multiple axes partitioning -- supported by shardy (not in Jax I think) struct NamedSharding{D1,D2,P<:Tuple,D3} <: AbstractSharding @@ -94,105 +98,16 @@ struct NamedSharding{D1,D2,P<:Tuple,D3} <: AbstractSharding end end -function named_sharding_to_opsharding(sharding::NamedSharding, shape::Dims) - (; mesh, partition_spec) = sharding - @assert length(partition_spec) == length(shape) - - # Fast Path for replicating the input across all devices - if all(Base.Fix2(===, nothing), partition_spec) - return XLA.OpSharding( - XLA.OpShardingType.Replicated, - Int64[], - Int64[], - false, - XLA.OpShardingType.T[], - Int64[], - Int64[], - Int64[], - Int32[], - false, - -1, - XLA.ShardGroupType.As, - ) - end - - tile_dims = map(Base.Fix1(size, mesh), partition_spec) - num_tiles_before_replication = prod(tile_dims) - total_devices = length(mesh.device_ids) - replication_factor = cld(total_devices, num_tiles_before_replication) - replicate_on_last_tile_dim = replication_factor > 1 - replicate_on_last_tile_dim && (tile_dims = (replication_factor, tile_dims...)) - - # Create tile assignment array - tile_assignment = Array{Int}(undef, tile_dims...) 
- devices = reshape(collect(mesh.device_ids), size(mesh)) - - # Find axes not used in partition_spec for replication - unused_axes = filter(axis -> axis ∉ partition_spec, mesh.axis_names) - unused_dims = map(axis -> size(mesh, axis), unused_axes) - replication_indices = CartesianIndices(Tuple(unused_dims)) - - # Fill tile assignment array - for indices in CartesianIndices(tile_assignment) - index_tuple = Tuple(indices) - actual_indices = replicate_on_last_tile_dim ? index_tuple[2:end] : index_tuple - repl_idx = replicate_on_last_tile_dim ? index_tuple[1] : 1 - - # Initialize device index array - device_index = ones(Int, ndims(mesh)) - - # Map partition dimensions to device indices - for (tile_idx, (pspec, dim_size)) in enumerate(zip(partition_spec, shape)) - if pspec !== nothing - mesh_axis = findfirst(==(Symbol(pspec)), mesh.axis_names) - if mesh_axis !== nothing - device_index[mesh_axis] = actual_indices[tile_idx] - end - end - end - - # Handle replication for unused axes - for (i, axis) in enumerate(unused_axes) - axis_idx = findfirst(==(axis), mesh.axis_names) - if axis_idx !== nothing - device_index[axis_idx] = replication_indices[repl_idx][i] - end - end - - # Assign device to tile - tile_assignment[indices] = devices[device_index...] - end - - tile_assignment = permutedims( - tile_assignment, reverse(collect(Int64, 1:ndims(tile_assignment))) - ) - - return XLA.OpSharding( - XLA.OpShardingType.Other, - Int64[], - Int64[], - replicate_on_last_tile_dim, - XLA.OpShardingType.T[], - collect(Int64, size(tile_assignment)), - vec(tile_assignment), - Int64[], - Int32[], - false, - -1, - XLA.ShardGroupType.As, - ) -end - function (sharding::NamedSharding)( client::XLA.Client, ::Nothing, x::Union{AbstractArray,Number} ) (; mesh, partition_spec) = sharding @assert length(partition_spec) == ndims(x) - opsharding = named_sharding_to_opsharding(sharding, size(x)) device_to_array_slices, _ = XLA.compute_array_indices_and_partition_spec( - opsharding, size(x), mesh + XLA.CondensedOpSharding(ShardingWithShape(sharding, size(x))), size(x), mesh ) + devices_list = vec(mesh) data = ntuple(length(mesh)) do i XLA.AsyncBuffer( @@ -200,7 +115,7 @@ function (sharding::NamedSharding)( client, x[device_to_array_slices[i]...], XLA.ClientGetAddressableDevice( - client, XLA.device_ordinal(client, mesh.device_ids[i]) + client, XLA.device_ordinal(client, devices_list[i]) ), ), nothing, @@ -210,12 +125,81 @@ function (sharding::NamedSharding)( return data, ShardInfo(sharding, device_to_array_slices) end +struct ShardingWithShape{S,D} <: AbstractSharding + sharding::S + shape::D +end + +# XXX: we need to make the performance of this function better +function XLA.CondensedOpSharding(sharding_and_shape::ShardingWithShape{<:NamedSharding}) + (; sharding, shape) = sharding_and_shape + (; mesh, partition_spec) = sharding + @assert length(partition_spec) == length(shape) + + partition_spec = reverse(partition_spec) + shape = reverse(shape) + + array_mapping = __get_array_mapping(partition_spec) + mesh_axis_position = Dict(name => i for (i, name) in enumerate(mesh.axis_names)) + + replicated_mesh_axes = Tuple{Int64,Int64}[] + for (i, axis_name) in enumerate(mesh.axis_names) + if !haskey(array_mapping, axis_name) + push!(replicated_mesh_axes, (i, size(mesh, axis_name))) + end + end + + tile_assignment = device_ids(mesh) + + # Fast Path for replicating the input across all devices + if length(replicated_mesh_axes) == ndims(mesh) + return XLA.CondensedOpSharding{ndims(tile_assignment)}( + XLA.OpShardingType.Replicated, false, 
tile_assignment + ) + end + + # Calculate new mesh shape and permutation + mesh_permutation = Int[] + new_mesh_shape = ones(Int, length(shape)) + + # Sort array mapping by position to ensure consistent order + for (name, pos) in sort(collect(array_mapping); by=x -> x[2]) + new_mesh_shape[pos] *= size(mesh, name) + push!(mesh_permutation, mesh_axis_position[name]) + end + + # Handle replicated dimensions at the end + replicate_on_last_tile_dim = false + if !isempty(replicated_mesh_axes) + replicated_size = prod(last(axis) for axis in replicated_mesh_axes) + push!(new_mesh_shape, replicated_size) + append!(mesh_permutation, first.(replicated_mesh_axes)) + + tile_assignment = reshape(tile_assignment, new_mesh_shape...) + push!(mesh_permutation, length(mesh_permutation) + 1) + replicate_on_last_tile_dim = true + end + + permuted = permutedims(tile_assignment, mesh_permutation) + final_assignment = reshape(permuted, new_mesh_shape...) + + return XLA.CondensedOpSharding{ndims(final_assignment)}( + XLA.OpShardingType.Other, replicate_on_last_tile_dim, final_assignment + ) +end + # Given Sharding + Array --> ShardInfo struct ShardInfo{S,D} <: AbstractSharding sharding::S device_to_array_slices::D end +function XLA.CondensedOpSharding(sharding_and_shape::ShardingWithShape{<:ShardInfo}) + return XLA.CondensedOpSharding( + ShardingWithShape(sharding_and_shape.sharding.sharding, sharding_and_shape.shape) + ) +end + function Base.getproperty(sharding::ShardInfo, name::Symbol) name ∈ (:sharding, :device_to_array_slices) && return getfield(sharding, name) return getfield(sharding.sharding, name) @@ -249,4 +233,16 @@ function is_sharded(x::Number) return false end +function __get_array_mapping(partition_spec) + mapping = Dict{Symbol,Int64}() + for (i, axis) in enumerate(partition_spec) + axis === nothing && continue + axis isa Symbol && (axis = (axis,)) + for axis_name in axis + mapping[axis_name] = i + end + end + return mapping +end + end diff --git a/src/xla/LoadedExecutable.jl b/src/xla/LoadedExecutable.jl index 14a6de7569..55a27974cb 100644 --- a/src/xla/LoadedExecutable.jl +++ b/src/xla/LoadedExecutable.jl @@ -225,12 +225,12 @@ function Compile( device::Union{Device,Nothing}, mod::MLIR.IR.Module; is_sharded::Bool=false, - mesh_ids::Vector{Int64}=Int64[], + device_ids::Vector{Int64}=Int64[], num_outputs::Int64, num_parameters::Int64, ) device_id = is_sharded ? Int64(-1) : Int64(device_ordinal(client, device)) - mesh_ids = Int64.(device_ordinal.((client,), mesh_ids)) + mesh_ids = Int64.(device_ordinal.((client,), device_ids)) GC.@preserve client mod begin exec = @ccall MLIR.API.mlir_c.ClientCompile( client.client::Ptr{Cvoid}, diff --git a/src/xla/Sharding.jl b/src/xla/Sharding.jl index 6b40b2f8eb..bd1f0e9716 100644 --- a/src/xla/Sharding.jl +++ b/src/xla/Sharding.jl @@ -112,14 +112,12 @@ function generate_device_list(sharding::OpSharding) # Generate device IDs using iota num_devices = prod(sharding.iota_reshape_dims) iota_devices = collect( - Int64, reshape(0:(num_devices - 1), reverse(sharding.iota_reshape_dims)...) + Int64, reshape(0:(num_devices - 1), sharding.iota_reshape_dims...) 
) # Permute the iota array if iota_transpose_perm is provided if !isempty(sharding.iota_transpose_perm) - iota_devices = permutedims( - iota_devices, reverse(Tuple(sharding.iota_transpose_perm)) - ) + iota_devices = permutedims(iota_devices, Tuple(sharding.iota_transpose_perm)) end # Flatten the permuted iota array to get tile_assignment_devices @@ -128,53 +126,149 @@ function generate_device_list(sharding::OpSharding) return sharding.tile_assignment_devices end -# Function to compute array indices for each device +function get_number_of_ways_dim_sharded(op_sharding::OpSharding) + op_sharding.type == OpShardingType.Replicated && return Int64[], 1 + + if op_sharding.replicate_on_last_tile_dim + return ( + op_sharding.tile_assignment_dimensions[1:(end - 1)], + op_sharding.tile_assignment_dimensions[end], + ) + end + return op_sharding.tile_assignment_dimensions, 1 +end + +function sharding_to_concrete_array_indices( + sharding::OpSharding, shape::Dims{N}, mesh +) where {N} + return sharding_to_concrete_array_indices(CondensedOpSharding(sharding), shape, mesh) +end + function compute_array_indices_and_partition_spec( sharding::OpSharding, array_size::Dims{N}, mesh ) where {N} - if sharding.type == OpShardingType.Replicated - # Replicated: All devices have the entire array + return compute_array_indices_and_partition_spec( + CondensedOpSharding(sharding), array_size, mesh + ) +end + +# This only stores the data that we currently support, and is useful for checking equality +# We would want to extend support to more of the fields at a later time +struct CondensedOpSharding{N} + type::OpShardingType.T + replicate_on_last_tile_dim::Bool + tile_assignment::Array{Int64,N} +end + +function Base.:(==)(a::CondensedOpSharding, b::CondensedOpSharding) + return a.type == b.type && + a.replicate_on_last_tile_dim == b.replicate_on_last_tile_dim && + a.tile_assignment == b.tile_assignment +end + +function CondensedOpSharding(sharding::OpSharding) + @assert isempty(sharding.last_tile_dims) "Last Tile dimensions are not supported \ + yet!" + @assert isempty(sharding.tile_dimensions) "Tile dimensions are not supported yet! \ + Open an issue with an MWE for this case." + @assert isempty(sharding.layout_minor_to_major) "Layout transformation is not \ + supported yet!" + + if sharding.type == OpShardingType.Replicated || sharding.type == OpShardingType.Maximal + tile_assignment = generate_device_list(sharding) + elseif sharding.type == OpShardingType.Other + tile_assignment = reshape( + generate_device_list(sharding), sharding.tile_assignment_dimensions... 
+ ) + else + error("Invalid sharding type: $(sharding.type)") + end + + return CondensedOpSharding( + sharding.type, sharding.replicate_on_last_tile_dim, tile_assignment + ) +end + +function get_number_of_ways_dim_sharded(op_sharding::CondensedOpSharding{N}) where {N} + op_sharding.type == OpShardingType.Replicated && return Int64[], 1 + + if op_sharding.replicate_on_last_tile_dim + return ( + size(op_sharding.tile_assignment)[1:(N - 1)], + size(op_sharding.tile_assignment, N), + ) + end + return size(op_sharding.tile_assignment), 1 +end + +function sharding_to_concrete_array_indices( + sharding::CondensedOpSharding, shape::Dims{N}, mesh +) where {N} + if sharding.type == OpShardingType.Replicated || sharding.type == OpShardingType.Maximal + return ntuple(Returns(ntuple(i -> 1:shape[i], N)), length(mesh)) + elseif sharding.type == OpShardingType.Other + partitions, num_replicas = get_number_of_ways_dim_sharded(sharding) + @assert length(partitions) == length(shape) + shape = reverse(shape) + + # Calculate indices for each dimension + axis_indices = map(zip(shape, partitions)) do (dim, n_shards) + if n_shards == 1 + [Colon()] + elseif n_shards > 1 + shard_size, remainder = divrem(dim, n_shards) + @assert remainder == 0 "Dimension $dim not evenly divisible by $n_shards \ + shards" + [(i * shard_size + 1):((i + 1) * shard_size) for i in 0:(n_shards - 1)] + else + error("Invalid number of shards: $n_shards") + end + end + + # XXX: Fix performance of this + indices = Vector{NTuple{length(shape),Any}}(undef, length(mesh)) + tile_assignment = sharding.tile_assignment + device_iter = Iterators.Stateful(tile_assignment) + + for idx_tuple in Iterators.product(axis_indices...) + for _ in 1:num_replicas + device = popfirst!(device_iter) + # XXX: incorrect if devices are not contiguous + indices[device + 1] = reverse(idx_tuple) + end + end + + return Tuple(indices) + else + error("Unsupported sharding type: $(sharding.type)") + end +end + +# Function to compute array indices for each device +function compute_array_indices_and_partition_spec( + sharding::CondensedOpSharding, array_size::Dims{N}, mesh +) where {N} + if sharding.type == OpShardingType.Replicated # All devices have the entire array return ( ntuple(Returns(ntuple(i -> 1:array_size[i], N)), length(mesh)), ntuple(Returns(nothing), N), ) - elseif sharding.type == OpShardingType.Maximal - # Maximal: Only one device has the entire array + elseif sharding.type == OpShardingType.Maximal # Only one device has the entire array @assert length(mesh) == 1 return ( ntuple(Returns(ntuple(i -> 1:array_size[i], N)), length(mesh)), ntuple(Returns(nothing), N), ) - elseif sharding.type == OpShardingType.Other - # Other: Tiled sharding - device_list = generate_device_list(sharding) - sorted_mesh_devices = sort(collect(Int64, mesh.device_ids)) - @assert sort(device_list) == sorted_mesh_devices "Mismatched devices list: \ - $(device_list) vs \ - $(mesh.device_ids)" - @assert isempty(sharding.tile_dimensions) "Tile dimensions are not supported yet! \ - Open an issue with an MWE for this case." - @assert isempty(sharding.layout_minor_to_major) "Layout transformation is not \ - supported yet!" - - tile_assignment = reshape(device_list, sharding.tile_assignment_dimensions...) 
- tile_assignment = permutedims( - tile_assignment, reverse(collect(Int64, 1:ndims(tile_assignment))) - ) - - # Handle replication dimension - tile_dims = size(tile_assignment)[(1 + sharding.replicate_on_last_tile_dim):end] - - # Calculate tile sizes - tile_sizes = div.(array_size, tile_dims) - mesh_devices = reshape([mesh.device_ids...], mesh.shape) + elseif sharding.type == OpShardingType.Other # Tiled sharding + tile_dims, _ = get_number_of_ways_dim_sharded(sharding) + mesh_devices = Reactant.Sharding.device_ids(mesh) # Match dimensions to mesh axes used_axes = Set{Int}() partition_spec = ntuple(N) do dim if dim <= length(tile_dims) && tile_dims[dim] > 1 tile_seq = __get_device_sequence( - tile_assignment, dim + sharding.replicate_on_last_tile_dim + sharding.tile_assignment, dim + sharding.replicate_on_last_tile_dim ) for (axis_idx, axis_name) in enumerate(mesh.axis_names) @@ -190,18 +284,9 @@ function compute_array_indices_and_partition_spec( return nothing end - device_to_array_indices = map(mesh.device_ids) do device_id - tile_index = findfirst(==(device_id), tile_assignment) - if tile_index === nothing - error("Device ID $device_id not found in tile assignment \ - $tile_assignment") - end - index_tup = Tuple(tile_index.I) - sharding.replicate_on_last_tile_dim && (index_tup = index_tup[2:end]) - tile_start = (index_tup .- 1) .* tile_sizes .+ 1 - tile_end = index_tup .* tile_sizes - return ntuple(i -> tile_start[i]:tile_end[i], N) - end + device_to_array_indices = sharding_to_concrete_array_indices( + sharding, array_size, mesh + ) return device_to_array_indices, partition_spec else @@ -211,9 +296,7 @@ end # Helper function to get device sequence along a dimension function __get_device_sequence(arr, dim) - # Take first index for all other dimensions idx = ones(Int, ndims(arr)) - # Get sequence along target dimension sequence = Int[] for i in 1:size(arr, dim) idx[dim] = i diff --git a/test/sharding.jl b/test/sharding.jl index 6f6b498027..a05f6efbb9 100644 --- a/test/sharding.jl +++ b/test/sharding.jl @@ -78,18 +78,33 @@ fn_test2(x) = x .+ x' @test Array(@jit fn_test2(x_ra)) ≈ fn_test2(x) end - samples_sharding = Sharding.NamedSharding(mesh, (nothing, "data")) - w1_sharding = Sharding.NamedSharding(mesh, ("model", nothing)) - - samples = reshape(collect(Float32, 1:84), 7, 12) - w1 = reshape(collect(Float32, 1:28), 4, 7) + samples = reshape(collect(Float32, 1:48), 4, 12) + w1 = reshape(collect(Float32, 1:16), 4, 4) w2 = reshape(collect(Float32, 1:32), 8, 4) - samples_ra = Reactant.to_rarray(samples; sharding=samples_sharding) - w1_ra = Reactant.to_rarray(w1; sharding=w1_sharding) - w2_ra = Reactant.to_rarray(w2) - - if !fake_run - @test Array(@jit(predict(samples_ra, w1_ra, w2_ra))) ≈ predict(samples, w1, w2) + for (samples_sharding, w1_sharding, w2_sharding) in zip( + ( + Sharding.NamedSharding(mesh, ("model", "data")), + Sharding.NamedSharding(mesh, ("model", nothing)), + Sharding.NamedSharding(mesh, (nothing, "data")), + ), + ( + Sharding.NamedSharding(mesh, ("model", "data")), + Sharding.NamedSharding(mesh, (nothing, "data")), + Sharding.NoSharding(), + ), + ( + Sharding.NamedSharding(mesh, ("model", "data")), + Sharding.NoSharding(), + Sharding.NoSharding(), + ), + ) + samples_ra = Reactant.to_rarray(samples; sharding=samples_sharding) + w1_ra = Reactant.to_rarray(w1; sharding=w1_sharding) + w2_ra = Reactant.to_rarray(w2; sharding=w2_sharding) + + if !fake_run + @test Array(@jit(predict(samples_ra, w1_ra, w2_ra))) ≈ predict(samples, w1, w2) + end end end From 
d1a26d292b4b462d97f76a2a6b8e4fe3e072fe13 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 16 Feb 2025 10:53:45 -0600 Subject: [PATCH 10/14] refactor: roundtrip from XLA to get mhlo shardings --- src/Compiler.jl | 20 +++--- src/Ops.jl | 2 - src/Sharding.jl | 149 +++++++++++++++++++++++++++----------------- src/TracedUtils.jl | 99 +++++++---------------------- src/xla/Sharding.jl | 58 +++++++++-------- test/sharding.jl | 55 ++++++++++++++++ 6 files changed, 211 insertions(+), 172 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index f21e689dfd..aa970a6d48 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1080,15 +1080,12 @@ function codegen_flatten!( linear_parameter_shardings[i] ) if Reactant.Sharding.is_sharded(carg) - # Check if the sharding provided is same as the one we have - arg_condensed_op_sharding = Reactant.Sharding.XLA.CondensedOpSharding( - Reactant.Sharding.ShardingWithShape(carg.sharding, size(carg)) - ) - - # XXX: Change to error - if arg_condensed_op_sharding != condensed_op_sharding - @warn "Sharding provided by the user ($arg_condensed_op_sharding) does not match the sharding computed by XLA ($condensed_op_sharding). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE." - end + # Currently disabling the error since we roundtrip from MHLO to generate + # the shardings + # # Check if the sharding provided is same as the one we have + # arg_condensed_op_sharding = Reactant.Sharding.XLA.CondensedOpSharding( + # Reactant.Sharding.ShardingWithShape(carg.sharding, size(carg)) + # ) # @assert arg_condensed_op_sharding == condensed_op_sharding "Sharding provided by the user ($arg_condensed_op_sharding) does not match the sharding computed by XLA ($condensed_op_sharding). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE." push!(flatten_code, :($usbuf = $flatcode.data)) @@ -1522,6 +1519,11 @@ function compile(f, args; sync=false, kwargs...) linear_result_shard_info = if mlir_fn_res.is_sharded output_shardings = XLA.get_output_shardings(exec) + # XXX: remove + for (i, sd) in enumerate(output_shardings) + @info("i: $i \t", sd) + end + # XXX: remove XLA.compute_array_indices_and_partition_spec.( output_shardings, size.(mlir_fn_res.linear_results), diff --git a/src/Ops.jl b/src/Ops.jl index 803732f62f..c1b4ca8a6c 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1960,8 +1960,6 @@ end tr else if typeof(tr) != typeof(fr) - @show tr.mlir_data - @show fr.mlir_data @assert typeof(tr) == typeof(fr) "$(typeof(tr)) vs $(typeof(fr))" end tr diff --git a/src/Sharding.jl b/src/Sharding.jl index cccd4e3b9a..a9ec11eab9 100644 --- a/src/Sharding.jl +++ b/src/Sharding.jl @@ -1,6 +1,6 @@ module Sharding -using ..Reactant: Reactant, XLA +using ..Reactant: Reactant, XLA, MLIR struct Mesh{D,ND} device_ids::NTuple{ND,Int} @@ -33,12 +33,7 @@ end Base.vec(mesh::Mesh) = vec(device_ids(mesh)) -function device_ids(mesh::Mesh) - # XXX: Do we need to permute the device ids? - return permutedims( - reshape(collect(Int64, mesh.device_ids), size(mesh)...), reverse(1:ndims(mesh)) - ) -end +device_ids(mesh::Mesh) = reshape(collect(Int64, mesh.device_ids), size(mesh)...) 
Base.length(::Mesh{D,ND}) where {D,ND} = ND Base.ndims(::Mesh{D}) where {D} = D @@ -125,67 +120,92 @@ function (sharding::NamedSharding)( return data, ShardInfo(sharding, device_to_array_slices) end +function get_shardy_tensor_sharding_attribute( + ctx, N::Int, sharding::NamedSharding, mesh_name; do_transpose=true +) + dimension_sharding_attrs = Vector{MLIR.API.MlirAttribute}(undef, N) + for (j, name) in enumerate(sharding.partition_spec) + if name === nothing + axes = MLIR.IR.Attribute[] + else + @assert name isa Symbol + axes = [ + MLIR.API.sdyAxisRefAttrGet( + ctx, String(name), MLIR.API.MlirAttribute(C_NULL) + ), + ] + end + dimension_sharding_attrs[j] = MLIR.API.sdyDimensionShardingAttrGet( + ctx, length(axes), axes, sharding.is_closed[j], sharding.priority[j] + ) + end + + return MLIR.IR.Attribute( + MLIR.API.sdyTensorShardingAttrGet( + ctx, + mesh_name, + length(dimension_sharding_attrs), + do_transpose ? reverse(dimension_sharding_attrs) : dimension_sharding_attrs, + 0, + MLIR.API.MlirAttribute[], + ), + ) +end + +# An internal abstraction to allow defining `convert` to XLA sharding struct ShardingWithShape{S,D} <: AbstractSharding sharding::S shape::D end -# XXX: we need to make the performance of this function better -function XLA.CondensedOpSharding(sharding_and_shape::ShardingWithShape{<:NamedSharding}) - (; sharding, shape) = sharding_and_shape - (; mesh, partition_spec) = sharding - @assert length(partition_spec) == length(shape) - - partition_spec = reverse(partition_spec) - shape = reverse(shape) - - array_mapping = __get_array_mapping(partition_spec) - mesh_axis_position = Dict(name => i for (i, name) in enumerate(mesh.axis_names)) - - replicated_mesh_axes = Tuple{Int64,Int64}[] - for (i, axis_name) in enumerate(mesh.axis_names) - if !haskey(array_mapping, axis_name) - push!(replicated_mesh_axes, (i, size(mesh, axis_name))) - end - end - - tile_assignment = device_ids(mesh) - - # Fast Path for replicating the input across all devices - if length(replicated_mesh_axes) == ndims(mesh) - return XLA.CondensedOpSharding{ndims(tile_assignment)}( - XLA.OpShardingType.Replicated, false, tile_assignment - ) - end +internal_simple_op(x) = Reactant.Ops.negate(x) - # Calculate new mesh shape and permutation - mesh_permutation = Int[] - new_mesh_shape = ones(Int, length(shape)) +# XXX: We do a fake compile here to get the mhlo sharding. Ideally we should be able to use +# some API to convert shardy annotations to mhlo annotations. +# XXX: We should cache the CondensedOpSharding else we will end up calling this function +# multiple times. +function XLA.CondensedOpSharding(sharding_and_shape::ShardingWithShape{<:NamedSharding}) + tmp = Reactant.ConcreteRArray( + ones(sharding_and_shape.shape); sharding=LazySharding(sharding_and_shape.sharding) + ) + _, exec, _, _, _ = Reactant.Compiler.compile_xla(internal_simple_op, (tmp,)) + return XLA.CondensedOpSharding(only(XLA.get_parameter_shardings(exec))) +end - # Sort array mapping by position to ensure consistent order - for (name, pos) in sort(collect(array_mapping); by=x -> x[2]) - new_mesh_shape[pos] *= size(mesh, name) - push!(mesh_permutation, mesh_axis_position[name]) - end +# Lazy Sharding. ConcreteArrays with this annotation is not really sharded but we can use it +# to compile the executable. 
+struct LazySharding{S} <: AbstractSharding + sharding::S +end - # Handle replicated dimensions at the end - replicate_on_last_tile_dim = false - if !isempty(replicated_mesh_axes) - replicated_size = prod(last(axis) for axis in replicated_mesh_axes) - push!(new_mesh_shape, replicated_size) - append!(mesh_permutation, first.(replicated_mesh_axes)) +function get_shardy_tensor_sharding_attribute( + ctx, N::Int, sharding::LazySharding, mesh_name; do_transpose=true +) + return get_shardy_tensor_sharding_attribute( + ctx, N, sharding.sharding, mesh_name; do_transpose + ) +end - tile_assignment = reshape(tile_assignment, new_mesh_shape...) - push!(mesh_permutation, length(mesh_permutation) + 1) - replicate_on_last_tile_dim = true - end +function (sharding::LazySharding)( + client::XLA.Client, ::Nothing, x::Union{AbstractArray,Number} +) + data = XLA.AsyncBuffer( + XLA.ArrayFromHostBuffer( + client, + x, + XLA.ClientGetAddressableDevice( + client, XLA.device_ordinal(client, vec(sharding.sharding.mesh)[1]) + ), + ), + nothing, + ) - permuted = permutedims(tile_assignment, mesh_permutation) - final_assignment = reshape(permuted, new_mesh_shape...) + return (data,), ShardInfo(sharding, (ntuple(i -> 1:size(x, i), ndims(x)),)) +end - return XLA.CondensedOpSharding{ndims(final_assignment)}( - XLA.OpShardingType.Other, replicate_on_last_tile_dim, final_assignment - ) +function Base.getproperty(sharding::LazySharding, name::Symbol) + name ∈ (:sharding, :device_to_array_slices) && return getfield(sharding, name) + return getproperty(sharding.sharding, name) end # Given Sharding + Array --> ShardInfo @@ -202,7 +222,19 @@ end function Base.getproperty(sharding::ShardInfo, name::Symbol) name ∈ (:sharding, :device_to_array_slices) && return getfield(sharding, name) - return getfield(sharding.sharding, name) + return getproperty(sharding.sharding, name) +end + +function get_shardy_tensor_sharding_attribute( + ctx, sharding::ShardInfo, mesh_name; do_transpose=true +) + return get_shardy_tensor_sharding_attribute( + ctx, + length(first(sharding.device_to_array_slices)), + sharding.sharding, + mesh_name; + do_transpose, + ) end function (sharding::ShardInfo)(client::XLA.Client, device, x::Union{AbstractArray,Number}) @@ -221,6 +253,7 @@ Checks whether the given sharding refers to no sharding. 
""" is_sharded(::NoSharding) = false is_sharded(::NamedSharding) = true +is_sharded(s::LazySharding) = is_sharded(s.sharding) is_sharded(s::ShardInfo) = is_sharded(s.sharding) function is_sharded(x::AbstractArray) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 4f0e20a449..b22e61dc30 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -229,27 +229,12 @@ function make_mlir_fn( is_sharded = true traced_args_to_shardings[v] = k.sharding if !haskey(mesh_cache, k.sharding.mesh) - mesh_op_attrs = Reactant.Ops.mesh(mod, k.sharding.mesh) - mesh_cache[k.sharding.mesh] = mesh_op_attrs + mesh_cache[k.sharding.mesh] = Reactant.Ops.mesh(mod, k.sharding.mesh) end end end end - if is_sharded - unique_meshes = unique([m.mesh for (k, m) in traced_args_to_shardings]) - # TODO: support multiple meshes - @assert length(unique_meshes) == 1 "Currently we support using a single mesh" - # sorted_devices = [sort(vec(m.device_ids)) for m in unique_meshes] - # @assert allequal(sorted_devices) "All meshes must have the same device ids" - # num_partitions = length(first(sorted_devices)) - sharding_mesh = first(unique_meshes) - mesh_op_attrs = mesh_cache[sharding_mesh] - num_partitions = length(sharding_mesh) - else - sharding_mesh = nothing - end - func = MLIR.IR.block!(MLIR.IR.body(mod)) do return MLIR.Dialects.func.func_(; sym_name=name * "_tmp", @@ -261,59 +246,6 @@ function make_mlir_fn( fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args]) push!(MLIR.IR.region(func, 1), fnbody) - if is_sharded - # Here we construct tensor sharding annotations for the function arguments - linear_arg_shardings = Vector{MLIR.IR.Attribute}(undef, length(linear_args)) - for (i, arg) in enumerate(linear_args) - if haskey(traced_args_to_shardings, arg) - if ndims(arg) == 0 - throw( - ErrorException( - "Sharding annotations are not supported for scalar arguments" - ), - ) - end - sharding = traced_args_to_shardings[arg] - mesh_op_attrs = mesh_cache[sharding.mesh] - @assert length(sharding.partition_spec) == ndims(arg) - - dimension_sharding_attrs = Vector{MLIR.API.MlirAttribute}(undef, ndims(arg)) - for (j, name) in enumerate(sharding.partition_spec) - if name === nothing - axes = MLIR.IR.Attribute[] - else - @assert name isa Symbol - axes = [ - MLIR.API.sdyAxisRefAttrGet( - ctx, String(name), MLIR.API.MlirAttribute(C_NULL) - ), - ] - end - dimension_sharding_attrs[j] = MLIR.API.sdyDimensionShardingAttrGet( - ctx, length(axes), axes, sharding.is_closed[j], sharding.priority[j] - ) - end - - # Currently we don't support replicated axes from user input, we do - # implicitly via shardy - linear_arg_shardings[i] = MLIR.IR.Attribute( - MLIR.API.sdyTensorShardingAttrGet( - ctx, - mesh_op_attrs.sym_name, - length(dimension_sharding_attrs), - if do_transpose - reverse(dimension_sharding_attrs) - else - dimension_sharding_attrs - end, - 0, - MLIR.API.MlirAttribute[], - ), - ) - end - end - end - @assert MLIR.IR._has_block() # Explicitly don't use block! 
to avoid creating a closure, which creates @@ -412,9 +344,26 @@ function make_mlir_fn( MLIR.API.mlirRegionTakeBody(MLIR.IR.region(func2, 1), MLIR.IR.region(func, 1)) if is_sharded + unique_meshes = unique([m.mesh for (k, m) in traced_args_to_shardings]) + + # TODO: support multiple meshes + if length(unique_meshes) > 1 + error("Currently we support using a single mesh") + sorted_devices = [sort(vec(m)) for m in unique_meshes] + @assert allequal(sorted_devices) "All meshes must have the same device ids" + end + sharding_mesh = first(unique_meshes) + num_partitions = length(sharding_mesh) + + linear_arg_shardings = Vector{MLIR.IR.Attribute}(undef, length(linear_args)) + # Attach `sdy.sharding` attribute to the argument for (i, arg) in enumerate(linear_args) if haskey(traced_args_to_shardings, arg) + sharding = traced_args_to_shardings[arg] + linear_arg_shardings[i] = Reactant.Sharding.get_shardy_tensor_sharding_attribute( + ctx, sharding, mesh_cache[sharding.mesh].sym_name; do_transpose + ) MLIR.API.mlirFuncSetArgAttr( func2, i - 1, "sdy.sharding", linear_arg_shardings[i] ) @@ -426,20 +375,16 @@ function make_mlir_fn( for i in mutated_args arg = linear_args[i] if has_residx(arg) && haskey(traced_args_to_shardings, arg) - residx = -1 - for (j, res) in enumerate(linear_results) - if res === arg - residx = j - break - end - end - @assert residx > 0 + residx = findfirst(Base.Fix1(===, arg), linear_results) + @assert residx !== nothing result_not_replicated[residx] = true MLIR.API.mlirFuncSetResultAttr( func2, residx - 1, "sdy.sharding", linear_arg_shardings[i] ) end end + else + sharding_mesh = nothing end MLIR.API.mlirOperationDestroy(func.operation) diff --git a/src/xla/Sharding.jl b/src/xla/Sharding.jl index bd1f0e9716..1426c595d2 100644 --- a/src/xla/Sharding.jl +++ b/src/xla/Sharding.jl @@ -110,18 +110,26 @@ end function generate_device_list(sharding::OpSharding) if !isempty(sharding.iota_reshape_dims) # Generate device IDs using iota - num_devices = prod(sharding.iota_reshape_dims) - iota_devices = collect( - Int64, reshape(0:(num_devices - 1), sharding.iota_reshape_dims...) - ) + num_devices = prod(sharding.tile_assignment_dimensions) # Permute the iota array if iota_transpose_perm is provided + # We need to ensure that we account for the col-major ordering in julia. See the + # unit tests for examples. if !isempty(sharding.iota_transpose_perm) - iota_devices = permutedims(iota_devices, Tuple(sharding.iota_transpose_perm)) + # XXX: Simplify the permutedims + iota_devices = collect( + Int64, reshape(0:(num_devices - 1), reverse(sharding.iota_reshape_dims)...) + ) + + iota_devices = permutedims(iota_devices, reverse(1:ndims(iota_devices))) + iota_devices = permutedims(iota_devices, sharding.iota_transpose_perm) + iota_devices = permutedims(iota_devices, reverse(1:ndims(iota_devices))) + + return vec(iota_devices) + else + @assert num_devices == prod(sharding.iota_reshape_dims) + return collect(0:(num_devices - 1)) end - - # Flatten the permuted iota array to get tile_assignment_devices - return vec(iota_devices) end return sharding.tile_assignment_devices end @@ -167,6 +175,8 @@ function Base.:(==)(a::CondensedOpSharding, b::CondensedOpSharding) end function CondensedOpSharding(sharding::OpSharding) + @show sharding + @assert isempty(sharding.last_tile_dims) "Last Tile dimensions are not supported \ yet!" @assert isempty(sharding.tile_dimensions) "Tile dimensions are not supported yet! 
\ @@ -213,32 +223,28 @@ function sharding_to_concrete_array_indices( # Calculate indices for each dimension axis_indices = map(zip(shape, partitions)) do (dim, n_shards) - if n_shards == 1 - [Colon()] - elseif n_shards > 1 - shard_size, remainder = divrem(dim, n_shards) - @assert remainder == 0 "Dimension $dim not evenly divisible by $n_shards \ - shards" - [(i * shard_size + 1):((i + 1) * shard_size) for i in 0:(n_shards - 1)] - else - error("Invalid number of shards: $n_shards") - end + @assert dim > 0 "Invalid dimension: $dim" + @assert n_shards > 0 "Invalid number of shards: $n_shards" + n_shards == 1 && return [1:dim] + shard_size, remainder = divrem(dim, n_shards) + @assert remainder == 0 "Dimension $dim not evenly divisible by $n_shards shards" + return [(i * shard_size + 1):((i + 1) * shard_size) for i in 0:(n_shards - 1)] end - # XXX: Fix performance of this - indices = Vector{NTuple{length(shape),Any}}(undef, length(mesh)) - tile_assignment = sharding.tile_assignment - device_iter = Iterators.Stateful(tile_assignment) + @show vec(sharding.tile_assignment) + indices = Dict{Int,NTuple{N,UnitRange{Int}}}() + device_idx = 1 for idx_tuple in Iterators.product(axis_indices...) for _ in 1:num_replicas - device = popfirst!(device_iter) - # XXX: incorrect if devices are not contiguous - indices[device + 1] = reverse(idx_tuple) + indices[sharding.tile_assignment[device_idx]] = reverse(idx_tuple) + device_idx += 1 end end - return Tuple(indices) + @show sort(collect(indices); by=x -> x[1]) + + return map(Base.Fix1(getindex, indices), mesh.device_ids) else error("Unsupported sharding type: $(sharding.type)") end diff --git a/test/sharding.jl b/test/sharding.jl index a05f6efbb9..50fb37012c 100644 --- a/test/sharding.jl +++ b/test/sharding.jl @@ -108,3 +108,58 @@ fn_test2(x) = x .+ x' end end end + +# Tests from the examples in +# https://github.com/openxla/xla/blob/96d6678053d867099a42be9001c49b2ed7111afd/xla/hlo/ir/tile_assignment.h#L53-L68 +@testset "Device List from Iota Tile" begin + @test Reactant.XLA.generate_device_list( + Reactant.XLA.OpSharding( + Reactant.XLA.OpShardingType.Other, + Int64[], + Int64[], + true, + Reactant.XLA.OpShardingType.T[], + [4, 4, 1], + Int64[], + [4, 2, 2], + Int32[1, 2, 3], + false, + 0, + Reactant.XLA.ShardGroupType.As, + ), + ) == collect(0:15) + + @test Reactant.XLA.generate_device_list( + Reactant.XLA.OpSharding( + Reactant.XLA.OpShardingType.Other, + Int64[], + Int64[], + true, + Reactant.XLA.OpShardingType.T[], + [4, 4, 1], + Int64[], + [4, 2, 2], + Int32[2, 1, 3], + false, + 0, + Reactant.XLA.ShardGroupType.As, + ), + ) == [0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15] + + @test Reactant.XLA.generate_device_list( + Reactant.XLA.OpSharding( + Reactant.XLA.OpShardingType.Other, + Int64[], + Int64[], + true, + Reactant.XLA.OpShardingType.T[], + [2, 4], + Int64[], + [4, 2], + Int32[2, 1], + false, + 0, + Reactant.XLA.ShardGroupType.As, + ), + ) == [0, 2, 4, 6, 1, 3, 5, 7] +end From 1e95760d57594e13083e898c3b1dffc3a6b9836e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 16 Feb 2025 11:40:08 -0600 Subject: [PATCH 11/14] fix: ordering of the tile_assignment --- src/Compiler.jl | 5 ----- src/Ops.jl | 5 +---- src/xla/Sharding.jl | 18 ++++++++---------- test/sharding.jl | 4 ++++ 4 files changed, 13 insertions(+), 19 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index aa970a6d48..a9931fa475 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1519,11 +1519,6 @@ function compile(f, args; sync=false, kwargs...) 
linear_result_shard_info = if mlir_fn_res.is_sharded output_shardings = XLA.get_output_shardings(exec) - # XXX: remove - for (i, sd) in enumerate(output_shardings) - @info("i: $i \t", sd) - end - # XXX: remove XLA.compute_array_indices_and_partition_spec.( output_shardings, size.(mlir_fn_res.linear_results), diff --git a/src/Ops.jl b/src/Ops.jl index c1b4ca8a6c..11c75820a8 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2139,10 +2139,7 @@ end location=mlir_stacktrace("mesh", @__FILE__, @__LINE__), ) return mesh( - mod, - [k => Int64(v) for (k, v) in zip(m.axis_names, size(m))], - vec(m); - location, + mod, [k => Int64(v) for (k, v) in zip(m.axis_names, size(m))], vec(m); location ) end diff --git a/src/xla/Sharding.jl b/src/xla/Sharding.jl index 1426c595d2..d222e8b92d 100644 --- a/src/xla/Sharding.jl +++ b/src/xla/Sharding.jl @@ -175,8 +175,6 @@ function Base.:(==)(a::CondensedOpSharding, b::CondensedOpSharding) end function CondensedOpSharding(sharding::OpSharding) - @show sharding - @assert isempty(sharding.last_tile_dims) "Last Tile dimensions are not supported \ yet!" @assert isempty(sharding.tile_dimensions) "Tile dimensions are not supported yet! \ @@ -187,8 +185,12 @@ function CondensedOpSharding(sharding::OpSharding) if sharding.type == OpShardingType.Replicated || sharding.type == OpShardingType.Maximal tile_assignment = generate_device_list(sharding) elseif sharding.type == OpShardingType.Other - tile_assignment = reshape( - generate_device_list(sharding), sharding.tile_assignment_dimensions... + tile_assignment = permutedims( + reshape( + generate_device_list(sharding), + reverse(sharding.tile_assignment_dimensions)..., + ), + reverse(1:length(sharding.tile_assignment_dimensions)), ) else error("Invalid sharding type: $(sharding.type)") @@ -231,19 +233,15 @@ function sharding_to_concrete_array_indices( return [(i * shard_size + 1):((i + 1) * shard_size) for i in 0:(n_shards - 1)] end - @show vec(sharding.tile_assignment) - indices = Dict{Int,NTuple{N,UnitRange{Int}}}() device_idx = 1 - for idx_tuple in Iterators.product(axis_indices...) - for _ in 1:num_replicas + for _ in 1:num_replicas + for idx_tuple in Iterators.product(axis_indices...) 
indices[sharding.tile_assignment[device_idx]] = reverse(idx_tuple) device_idx += 1 end end - @show sort(collect(indices); by=x -> x[1]) - return map(Base.Fix1(getindex, indices), mesh.device_ids) else error("Unsupported sharding type: $(sharding.type)") diff --git a/test/sharding.jl b/test/sharding.jl index 50fb37012c..473b4d0e58 100644 --- a/test/sharding.jl +++ b/test/sharding.jl @@ -60,6 +60,8 @@ predict(samples, w1, w2) = sin.(w2 * (w1 * tanh.(samples))) fn_test2(x) = x .+ x' +fn_test3(x) = sum(x; dims=1) + @testset "Sharding Across 8 Devices" begin if length(addressable_devices) ≥ 8 mesh = Sharding.Mesh(reshape(collect(Int64, 0:7), (4, 2)), ("data", "model")) @@ -76,6 +78,8 @@ fn_test2(x) = x .+ x' if !fake_run @test Array(@jit fn_test2(x_ra)) ≈ fn_test2(x) + + @test Array(@jit fn_test3(x_ra)) ≈ fn_test3(x) end samples = reshape(collect(Float32, 1:48), 4, 12) From 3dbb28251f31e7d4e5611a87531a8279fa95cdf8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 16 Feb 2025 11:43:50 -0600 Subject: [PATCH 12/14] chore: remove unused function --- src/Sharding.jl | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/Sharding.jl b/src/Sharding.jl index a9ec11eab9..6f9bd3c746 100644 --- a/src/Sharding.jl +++ b/src/Sharding.jl @@ -266,16 +266,4 @@ function is_sharded(x::Number) return false end -function __get_array_mapping(partition_spec) - mapping = Dict{Symbol,Int64}() - for (i, axis) in enumerate(partition_spec) - axis === nothing && continue - axis isa Symbol && (axis = (axis,)) - for axis_name in axis - mapping[axis_name] = i - end - end - return mapping -end - end From f0d7485a7d04fad2cc8625d9244e929f1e8f1a4f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 16 Feb 2025 14:26:11 -0600 Subject: [PATCH 13/14] chore: bump jll --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4c6062e94c..8d11b02ad3 100644 --- a/Project.toml +++ b/Project.toml @@ -81,7 +81,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.5" -Reactant_jll = "0.0.68" +Reactant_jll = "0.0.69" Scratch = "1.2" Sockets = "1.10" SpecialFunctions = "2.4" From cdbc466b526f4ad93f933a2fa5a7ba1a747712e6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 16 Feb 2025 14:45:11 -0600 Subject: [PATCH 14/14] test: only run tests if correct number of devices are present --- .github/workflows/downgrade.yml | 1 + test/sharding.jl | 108 ++++++++++++++------------------ 2 files changed, 48 insertions(+), 61 deletions(-) diff --git a/.github/workflows/downgrade.yml b/.github/workflows/downgrade.yml index 49cda9dbaf..458e50d0cf 100644 --- a/.github/workflows/downgrade.yml +++ b/.github/workflows/downgrade.yml @@ -62,6 +62,7 @@ jobs: env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager REACTANT_TEST_GROUP: ${{ matrix.test_group }} + XLA_FLAGS: "--xla_force_host_platform_device_count=8" - uses: julia-actions/julia-processcoverage@v1 if: steps.run_tests.outcome == 'success' - uses: codecov/codecov-action@v5 diff --git a/test/sharding.jl b/test/sharding.jl index 473b4d0e58..a400af6d65 100644 --- a/test/sharding.jl +++ b/test/sharding.jl @@ -13,37 +13,29 @@ end @testset "Sharding Across 2 Devices" begin if length(addressable_devices) ≥ 2 mesh = Sharding.Mesh([0 1;], ("x", "y")) - fake_run = false - else - @warn "Not enough addressable devices to run sharding tests; we are running a \ - pretend test for testing purposes" - mesh = Sharding.Mesh(reshape([0], 1, 1), ("x", "y")) - fake_run = true - end - data_sharding = Sharding.NamedSharding(mesh, ("y", 
nothing, "x")) - data_sharding2 = Sharding.NamedSharding(mesh, (nothing, "x", nothing)) - data_sharding3 = Sharding.NamedSharding(mesh, (nothing, nothing, nothing)) # fully replicated data + data_sharding = Sharding.NamedSharding(mesh, ("y", nothing, "x")) + data_sharding2 = Sharding.NamedSharding(mesh, (nothing, "x", nothing)) + data_sharding3 = Sharding.NamedSharding(mesh, (nothing, nothing, nothing)) # fully replicated data - data = reshape(collect(1:(16 * 4 * 12)) ./ (16 * 4 * 12), 16, 4, 12) + data = reshape(collect(1:(16 * 4 * 12)) ./ (16 * 4 * 12), 16, 4, 12) - cdata = Reactant.to_rarray(data) - cdata_sharded = Reactant.to_rarray(data; sharding=data_sharding) - cdata_sharded2 = Reactant.to_rarray(data; sharding=data_sharding2) - cdata_sharded3 = Reactant.to_rarray(data; sharding=data_sharding3) + cdata = Reactant.to_rarray(data) + cdata_sharded = Reactant.to_rarray(data; sharding=data_sharding) + cdata_sharded2 = Reactant.to_rarray(data; sharding=data_sharding2) + cdata_sharded3 = Reactant.to_rarray(data; sharding=data_sharding3) - @test data ≈ - Array(cdata) ≈ - Array(cdata_sharded) ≈ - Array(cdata_sharded2) ≈ - Array(cdata_sharded3) + @test data ≈ + Array(cdata) ≈ + Array(cdata_sharded) ≈ + Array(cdata_sharded2) ≈ + Array(cdata_sharded3) - @test cdata_sharded.sharding isa Sharding.ShardInfo{<:Sharding.NamedSharding} - @test cdata_sharded2.sharding isa Sharding.ShardInfo{<:Sharding.NamedSharding} - @test cdata_sharded3.sharding isa Sharding.ShardInfo{<:Sharding.NamedSharding} - @test cdata.sharding isa Sharding.NoShardInfo + @test cdata_sharded.sharding isa Sharding.ShardInfo{<:Sharding.NamedSharding} + @test cdata_sharded2.sharding isa Sharding.ShardInfo{<:Sharding.NamedSharding} + @test cdata_sharded3.sharding isa Sharding.ShardInfo{<:Sharding.NamedSharding} + @test cdata.sharding isa Sharding.NoShardInfo - if !fake_run true_res_y, true_res_x, true_res_z = fn_test1(data) for cd in (cdata, cdata_sharded, cdata_sharded2, cdata_sharded3) @@ -53,6 +45,8 @@ end @test Array(res_z) ≈ true_res_z @test Array(res_x) ≈ true_res_x end + else + @warn "Not enough addressable devices to run sharding tests" end end @@ -65,49 +59,41 @@ fn_test3(x) = sum(x; dims=1) @testset "Sharding Across 8 Devices" begin if length(addressable_devices) ≥ 8 mesh = Sharding.Mesh(reshape(collect(Int64, 0:7), (4, 2)), ("data", "model")) - fake_run = false - else - @warn "Not enough addressable devices to run sharding tests; we are running a \ - pretend test for testing purposes" - mesh = Sharding.Mesh(reshape([0], 1, 1), ("data", "model")) - fake_run = true - end - x = reshape(collect(Float32, 1:16), 4, 4) - x_ra = Reactant.to_rarray(x; sharding=Sharding.NamedSharding(mesh, ("data", "model"))) + x = reshape(collect(Float32, 1:16), 4, 4) + x_ra = Reactant.to_rarray( + x; sharding=Sharding.NamedSharding(mesh, ("data", "model")) + ) - if !fake_run @test Array(@jit fn_test2(x_ra)) ≈ fn_test2(x) @test Array(@jit fn_test3(x_ra)) ≈ fn_test3(x) - end - - samples = reshape(collect(Float32, 1:48), 4, 12) - w1 = reshape(collect(Float32, 1:16), 4, 4) - w2 = reshape(collect(Float32, 1:32), 8, 4) - for (samples_sharding, w1_sharding, w2_sharding) in zip( - ( - Sharding.NamedSharding(mesh, ("model", "data")), - Sharding.NamedSharding(mesh, ("model", nothing)), - Sharding.NamedSharding(mesh, (nothing, "data")), - ), - ( - Sharding.NamedSharding(mesh, ("model", "data")), - Sharding.NamedSharding(mesh, (nothing, "data")), - Sharding.NoSharding(), - ), - ( - Sharding.NamedSharding(mesh, ("model", "data")), - Sharding.NoSharding(), - 
Sharding.NoSharding(), - ), - ) - samples_ra = Reactant.to_rarray(samples; sharding=samples_sharding) - w1_ra = Reactant.to_rarray(w1; sharding=w1_sharding) - w2_ra = Reactant.to_rarray(w2; sharding=w2_sharding) + samples = reshape(collect(Float32, 1:48), 4, 12) + w1 = reshape(collect(Float32, 1:16), 4, 4) + w2 = reshape(collect(Float32, 1:32), 8, 4) + + for (samples_sharding, w1_sharding, w2_sharding) in zip( + ( + Sharding.NamedSharding(mesh, ("model", "data")), + Sharding.NamedSharding(mesh, ("model", nothing)), + Sharding.NamedSharding(mesh, (nothing, "data")), + ), + ( + Sharding.NamedSharding(mesh, ("model", "data")), + Sharding.NamedSharding(mesh, (nothing, "data")), + Sharding.NoSharding(), + ), + ( + Sharding.NamedSharding(mesh, ("model", "data")), + Sharding.NoSharding(), + Sharding.NoSharding(), + ), + ) + samples_ra = Reactant.to_rarray(samples; sharding=samples_sharding) + w1_ra = Reactant.to_rarray(w1; sharding=w1_sharding) + w2_ra = Reactant.to_rarray(w2; sharding=w2_sharding) - if !fake_run @test Array(@jit(predict(samples_ra, w1_ra, w2_ra))) ≈ predict(samples, w1, w2) end end
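
As a worked reference for the iota tile-assignment handling above: the device lists asserted in the new generate_device_list tests can be reproduced with plain Julia array operations, remembering that XLA flattens tile assignments in row-major order while Julia arrays are column-major. The sketch below is standalone and illustrative only (it does not construct an XLA.OpSharding, and the helper name iota_device_list is invented for this note); the expected result comes from the third test case above.

    # Standalone sketch of the iota -> device-list computation, assuming XLA's
    # row-major TileAssignment semantics. `reshape_dims` and `transpose_perm`
    # mirror the `iota_reshape_dims` / `iota_transpose_perm` fields used above.
    function iota_device_list(reshape_dims, transpose_perm)
        n = prod(reshape_dims)
        # Build the row-major iota in Julia's column-major layout by reshaping
        # with the dimensions reversed and then flipping the axes.
        devices = reshape(collect(Int64, 0:(n - 1)), reverse(reshape_dims)...)
        devices = permutedims(devices, reverse(1:ndims(devices)))
        # Apply the transpose permutation, then flatten in row-major order.
        devices = permutedims(devices, transpose_perm)
        return vec(permutedims(devices, reverse(1:ndims(devices))))
    end

    # Matches the expected device list of the third test case above
    # (iota_reshape_dims = [4, 2], iota_transpose_perm = [2, 1]).
    @assert iota_device_list((4, 2), (2, 1)) == [0, 2, 4, 6, 1, 3, 5, 7]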
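
In the same spirit, the per-device index ranges computed by sharding_to_concrete_array_indices for an evenly divisible tiled sharding reduce to simple block arithmetic. The following simplified sketch ignores replication, the shape reversal, and the tile-assignment ordering that the real function handles; block_ranges is a hypothetical helper written only for this note.

    # Simplified sketch: split each dimension of `shape` into `nshards[d]` equal
    # blocks, the way an evenly divisible tiled sharding slices an array.
    function block_ranges(shape::Dims{N}, nshards::Dims{N}) where {N}
        return map(Iterators.product((1:n for n in nshards)...)) do tile
            ntuple(N) do d
                len, r = divrem(shape[d], nshards[d])
                @assert r == 0 "dimension $d is not evenly divisible"
                ((tile[d] - 1) * len + 1):(tile[d] * len)
            end
        end
    end

    # A (4, 12) array over a 2x4 tiling: each of the 8 tiles is a 2x3 block.
    block_ranges((4, 12), (2, 4))[1, 1]  # (1:2, 1:3)
    block_ranges((4, 12), (2, 4))[2, 4]  # (3:4, 10:12)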
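
Finally, to exercise the 8-device test sets above on a CPU-only machine, the CI changes rely on XLA_FLAGS="--xla_force_host_platform_device_count=8". A rough local equivalent is sketched below; it assumes the flag is honored when Reactant first initializes its XLA client, and that Reactant.addressable_devices() is the helper behind the addressable_devices check in test/sharding.jl (both assumptions, not verified against this patch series).

    # Assumption: must be set before `using Reactant` so the XLA client sees it.
    ENV["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

    using Reactant

    # Assumption: Reactant.addressable_devices() lists the addressable devices;
    # with the flag above, a CPU-only host should report 8 of them, so the
    # sharding test sets above no longer skip themselves.
    @show length(Reactant.addressable_devices())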