diff --git a/Project.toml b/Project.toml index b852cc5f37..aca1e78d99 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.2.26" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -61,6 +62,7 @@ ArrayInterface = "7.17.1" CEnum = "0.5" CUDA = "5.6" Downloads = "1.6" +EnumX = "1" Enzyme = "0.13.28" EnzymeCore = "0.8.8" Functors = "0.5" @@ -79,7 +81,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.5" -Reactant_jll = "0.0.64" +Reactant_jll = "0.0.66" Scratch = "1.2" Sockets = "1.10" SpecialFunctions = "2.4" diff --git a/docs/src/api/api.md b/docs/src/api/api.md index b7befc9df0..ce61e62322 100644 --- a/docs/src/api/api.md +++ b/docs/src/api/api.md @@ -25,6 +25,7 @@ within_compile ```@docs @code_hlo +@code_mhlo ``` ## Profile XLA diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 1a8de8086d..0af6676efc 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -1067,9 +1067,8 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( T = eltype(A) N = ndims(A) if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive - if sharding isa Reactant.Sharding.NoSharding || - sharding isa Reactant.Sharding.FinalizedNoSharding - return Reactant.ConcreteRArray{T,N,1,Reactant.Sharding.FinalizedNoSharding} + if !Sharding.is_sharded(sharding) + return Reactant.ConcreteRArray{T,N,1,Reactant.Sharding.NoShardInfo} else error("TODO: implement sharding") end diff --git a/src/Compiler.jl b/src/Compiler.jl index fa69d58060..adad293643 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -36,16 +36,7 @@ end @nospecialize(obj::AbstractArray{T}), field, val ) where {T} ancestor_obj = ancestor(obj) - if isbitstype(T) || ancestor_obj isa RArray - if val isa XLA.AsyncBuffer - if Reactant.Sharding.is_sharded(ancestor_obj) - error("`val` can't be a buffer if `obj` is sharded") - else - return Base.setfield!(obj, field, (val,)) - end - end - return Base.setfield!(obj, field, val) - end + (isbitstype(T) || ancestor_obj isa RArray) && return Base.setfield!(obj, field, val) return Base.setindex!(obj, val, field) end @@ -75,29 +66,39 @@ function create_result( return Expr(:new, T, elems...) end +function __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh) + device_to_array_slices, partition_spec = path_to_shard_info[path] + delete!(path_to_shard_info, path) + sharding = Reactant.Sharding.NamedSharding(sharding_mesh, partition_spec) + return Reactant.Sharding.ShardInfo(sharding, device_to_array_slices) +end + function create_result( - tocopy::ConcreteRNumber{T}, path, result_stores, path_to_shard_info, sharding_mesh -) where {T} + tocopy::ConcreteRNumber{T,D,S}, path, result_stores, path_to_shard_info, sharding_mesh +) where {T,D,S} if haskey(result_stores, path) restore = result_stores[path] delete!(result_stores, path) - return :(ConcreteRNumber{$T}($restore)) + if path_to_shard_info !== nothing # restore sharding + sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh) + return :(ConcreteRNumber{$T,length($(restore)),$(typeof(sharding))}( + ($(restore)...,), $sharding + )) + else + return :(ConcreteRNumber{$T}($restore)) + end + end + + if path_to_shard_info !== nothing # restore sharding + sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh) + return :(ConcreteRNumber{$T,length($(tocopy.data)),$(typeof(sharding))}( + ($(tocopy.data...,)), $sharding + )) end # We will set the data for this later return :(ConcreteRNumber{$T}($(tocopy.data))) end -function __construct_sharding_for_carray( - ::ConcreteRArray{T,N,D,S}, path, _, path_to_shard_info, sharding_mesh -) where {T,N,D,S} - device_to_array_slices, partition_spec = path_to_shard_info[path] - delete!(path_to_shard_info, path) - sharding = Reactant.Sharding.NamedSharding(sharding_mesh, partition_spec) - return Reactant.Sharding.FinalizedNamedSharding{typeof(sharding),ndims(sharding_mesh)}( - sharding, device_to_array_slices - ) -end - function create_result( tocopy::ConcreteRArray{T,N,D,S}, path, result_stores, path_to_shard_info, sharding_mesh ) where {T,N,D,S} @@ -105,10 +106,8 @@ function create_result( restore = result_stores[path] delete!(result_stores, path) if path_to_shard_info !== nothing # restore sharding - sharding = __construct_sharding_for_carray( - tocopy, path, result_stores, path_to_shard_info, sharding_mesh - ) - return :(ConcreteRArray{$T,$N,$(ndims(sharding_mesh)),$(typeof(sharding))}( + sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh) + return :(ConcreteRArray{$T,$N,length($(restore)),$(typeof(sharding))}( ($(restore)...,), $(tocopy.shape), $sharding )) else @@ -117,10 +116,8 @@ function create_result( end if path_to_shard_info !== nothing # restore sharding - sharding = __construct_sharding_for_carray( - tocopy, path, result_stores, path_to_shard_info, sharding_mesh - ) - return :(ConcreteRArray{$T,$N,$(ndims(sharding_mesh)),$(typeof(sharding))}( + sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh) + return :(ConcreteRArray{$T,$N,length($(tocopy.data)),$(typeof(sharding))}( ($(tocopy.data)...,), $(tocopy.shape), $sharding )) end @@ -365,6 +362,7 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Boo "binary_op_transpose_simplify_or", "binary_op_transpose_simplify_and", "binary_op_transpose_simplify_xor", + "associative_binary_op_reordering<1>", "transpose_unary_transpose_abs", "transpose_unary_transpose_neg", "transpose_unary_transpose_sqrt", @@ -380,12 +378,15 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Boo "transpose_unary_transpose_sine", "transpose_unary_transpose_tanh", "transpose_broadcast_in_dim_to_broadcast_in_dim<16>", + "scatter_indices_are_unique", + "transpose_reduce_simplify", "replace_neg_add_with_subtract", "log_const_prop<1>", "log_plus_one_const_prop<1>", "binop_const_simplify", "transpose_broadcast_in_dim_to_broadcast_in_dim", "not_select_simplify", + "scatter_update_computation_const_prop", "common_compare_expression_rewrite", "compare_select_simplify", "while_simplify<1>", @@ -794,10 +795,12 @@ function compile_mlir!( results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)] nresults = MLIR.IR.Value[] linear_results2 = TracedType[] + results_mask = falses(length(results)) for (i, op) in enumerate(results) if !MLIR.IR.is_block_arg(op) push!(nresults, op) push!(linear_results2, linear_results[i]) + results_mask[i] = true continue end push!(preserved_args, (linear_results[i], MLIR.IR.block_arg_num(op))) @@ -812,11 +815,18 @@ function compile_mlir!( out_tys2 = [MLIR.IR.type(a) for a in nresults] + res_attrs = MLIR.IR.attr(compiled_f, "res_attrs") + if res_attrs isa MLIR.IR.Attribute + res_attrs = [ + res_attrs[i - 1] for (i, present) in enumerate(results_mask) if present + ] + end + func3 = MLIR.Dialects.func.func_(; sym_name="main", function_type=MLIR.IR.FunctionType(in_tys, out_tys2), arg_attrs=MLIR.IR.attr(compiled_f, "arg_attrs"), - res_attrs=MLIR.IR.attr(compiled_f, "res_attrs"), + res_attrs, no_inline=MLIR.IR.attr(compiled_f, "no_inline"), body=MLIR.IR.Region(), ) @@ -837,7 +847,6 @@ function compile_mlir!( linear_args, in_tys, linear_results2, - mlir_fn_res.linear_result_shard_info, mlir_fn_res.num_partitions, mlir_fn_res.num_replicas, mlir_fn_res.is_sharded, @@ -862,6 +871,22 @@ macro code_hlo(args...) $(first)($(compiled)))) end +""" + @code_mhlo [optimize = ...] [no_nan = ] f(args...) + +Similar to `@code_hlo`, but prints the module after running the XLA compiler. +""" +macro code_mhlo(args...) + default_options = Dict{Symbol,Any}( + :optimize => true, :no_nan => false, :client => nothing + ) + compile_expr, (; compiled) = compile_call_expr( + __module__, compile_xla, default_options, args... + ) + return esc(:($(compile_expr); + $(first)($(compiled)))) +end + """ @compile [optimize = ...] [no_nan = ] [sync = ] f(args...) """ @@ -998,7 +1023,7 @@ function codegen_flatten!( if is_sharded carg = inv_seen_args[arg] - if carg isa ConcreteRArray && Reactant.Sharding.is_sharded(carg) + if Reactant.Sharding.is_sharded(carg) for j in 1:length(mesh) sbuf = Symbol(:sbuf_, i, "_", j) push!(flatten_names, sbuf) @@ -1007,17 +1032,11 @@ function codegen_flatten!( else # Warn here first and then replicate the input across all devices on the # mesh - if carg isa ConcreteRArray - @warn "Input $carg is not sharded, replicating across all devices. It \ - is recommended to replicate the input across all devices on the \ - mesh manually using `Reactant.Sharding.NamedSharding`" maxlog = 1 - end + @warn "Input $carg is not sharded, replicating across all devices. It \ + is recommended to replicate the input across all devices on the \ + mesh manually using `Reactant.Sharding.NamedSharding`" maxlog = 1 buf = Symbol(:buf_, i) - if carg isa ConcreteRArray - push!(flatten_code, :($buf = XLA.synced_buffer(only($usbuf)))) - else - push!(flatten_code, :($buf = XLA.synced_buffer($usbuf))) - end + push!(flatten_code, :($buf = XLA.synced_buffer(only($usbuf)))) for j in 1:length(mesh) device_id = mesh.device_ids[j] device_ordinal = XLA.device_ordinal(client, device_id) @@ -1030,9 +1049,7 @@ function codegen_flatten!( else sbuf = Symbol(:sbuf_, i) push!(flatten_names, sbuf) - if arg isa TracedRNumber - push!(flatten_code, :($sbuf = XLA.synced_buffer($usbuf))) - elseif arg isa TracedRArray + if arg isa TracedRArray || arg isa TracedRNumber push!(flatten_code, :($sbuf = only(XLA.synced_buffer($usbuf)))) else error("Unsupported type $(typeof(arg))") @@ -1061,7 +1078,6 @@ function codegen_unflatten!( concrete_result, result_stores, path_to_shard_info, - is_sharded::Bool, linear_result_shard_info, sharding_mesh, ) @@ -1369,26 +1385,28 @@ function compile_xla(f, args; client=nothing, kwargs...) mlir_fn_res.is_sharded, ) - mlir_fn_res.num_partitions > 1 && (device = nothing) - # Attach a name, and partitioning attributes to the module __add_mhlo_attributes_and_name!( mod, f; mlir_fn_res.num_partitions, mlir_fn_res.num_replicas ) # compile MLIR module to XLA executable - is_sharded = mlir_fn_res.num_partitions > 1 - if is_sharded - # mesh_shape = collect(Int64, size(mlir_fn_res.sharding_mesh)) - mesh_ids = collect(Int64, vec(mlir_fn_res.sharding_mesh.device_ids)) + mlir_fn_res.is_sharded && (device = nothing) + mesh_ids = if mlir_fn_res.is_sharded + collect(Int64, mlir_fn_res.sharding_mesh.device_ids) else - # mesh_shape = Int64[] - mesh_ids = Int64[] + Int64[] end - # exec = XLA.Compile(client, device, mod; is_sharded, mesh_ids, mesh_shape) - exec = XLA.Compile(client, device, mod; is_sharded, mesh_ids) + exec = XLA.Compile( + client, + device, + mod; + num_results=length(mlir_fn_res.linear_results), + mlir_fn_res.is_sharded, + mesh_ids, + ) - return exec, mlir_fn_res, device, client + return mod, exec, mlir_fn_res, device, client finally MLIR.IR.deactivate!(ctx) end @@ -1398,7 +1416,7 @@ function compile_xla(f, args; client=nothing, kwargs...) end function compile(f, args; sync=false, kwargs...) - exec, mlir_fn_res, device, client = compile_xla(f, args; kwargs...) + _, exec, mlir_fn_res, device, client = compile_xla(f, args; kwargs...) (; linear_args, seen_args, linear_results, preserved_args, concrete_result) = mlir_fn_res @@ -1408,11 +1426,7 @@ function compile(f, args; sync=false, kwargs...) end result_stores = Dict{Tuple,Symbol}() - path_to_shard_info = if mlir_fn_res.is_sharded - Dict{Tuple,Tuple{Array{Vector{UnitRange{Int}}},Tuple}}() - else - nothing - end + path_to_shard_info = mlir_fn_res.is_sharded ? Dict{Tuple,Tuple}() : nothing # generate Julia `Thunk` code flatten_arg_names, flatten_code = codegen_flatten!( @@ -1431,9 +1445,25 @@ function compile(f, args; sync=false, kwargs...) donated_args_mask, length(linear_results), mlir_fn_res.is_sharded, - mlir_fn_res.is_sharded ? vec(mlir_fn_res.sharding_mesh.device_ids) : Int64[], + if mlir_fn_res.is_sharded + collect(Int64, mlir_fn_res.sharding_mesh.device_ids) + else + Int64[] + end, ) + linear_result_shard_info = if mlir_fn_res.is_sharded + # Generate a tuple of DeviceToArraySlices and PartitionSpecs + output_shardings = XLA.get_output_shardings(exec) + XLA.compute_array_indices_and_partition_spec.( + output_shardings, + size.(mlir_fn_res.linear_results), + (mlir_fn_res.sharding_mesh,), + ) + else + ntuple(Returns(nothing), length(linear_results)) + end + unflatten_code = codegen_unflatten!( linear_args, preserved_args, @@ -1442,8 +1472,7 @@ function compile(f, args; sync=false, kwargs...) concrete_result, result_stores, path_to_shard_info, - mlir_fn_res.is_sharded, - mlir_fn_res.linear_result_shard_info, + linear_result_shard_info, mlir_fn_res.sharding_mesh, ) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 9da8d745d8..e5bc30e1ce 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -1,7 +1,7 @@ -get_buffer(x::ConcreteRNumber) = x.data.buffer -get_buffer(x::ConcreteRArray{T,0}) where {T} = only(x.data).buffer -function get_buffer(x::ConcreteRArray{T,N}) where {T,N} +function get_buffer(x::Union{ConcreteRArray,ConcreteRNumber}; no_error_for_scalar=false) if Sharding.is_sharded(x.sharding) + # For scalars this is mostly replicated + no_error_for_scalar && return first(x.data).buffer error("`x` is sharded, so `get_buffer` is not defined") end return only(x.data).buffer @@ -21,9 +21,9 @@ Base.strides(x::ConcreteRArray) = Base.size_to_strides(1, size(x)...) # Ensure the device and client are the same as the input function Base.float(x::ConcreteRNumber{T}) where {T} - client = XLA.client(x.data) - device = XLA.device(x.data) - return ConcreteRNumber(float(T)(to_number(x)); client, device) + return ConcreteRNumber( + float(T)(to_number(x)); client=XLA.client(x), device=XLA.device(x), x.sharding + ) end # written like this to avoid ambiguity errors @@ -37,34 +37,34 @@ Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x Base.size(x::ConcreteRArray) = x.shape -Base.isempty(x::ConcreteRNumber) = x.data == XLA.AsyncEmptyBuffer -Base.isempty(x::ConcreteRArray) = any(==(XLA.AsyncEmptyBuffer), x.data) +function Base.isempty(x::Union{ConcreteRArray,ConcreteRNumber}) + return any(==(XLA.AsyncEmptyBuffer), x.data) +end Base.isempty(x::WrappedConcreteRArray) = isempty(ancestor(x)) function Base.convert(::Type{<:Array}, X::ConcreteRArray{T,N}) where {T,N} data = Array{T,N}(undef, size(X)...) XLA.await(X) - if X.sharding isa Sharding.FinalizedNoSharding - buf = only(X.data).buffer - GC.@preserve data buf begin - XLA.BufferToHost(buf, pointer(data)) - end - elseif X.sharding isa Sharding.FinalizedNamedSharding + if Sharding.is_sharded(X) # TODO: We can we much more efficient here and only move data from the minimal # slices that populates the array. for idx in 1:length(X.data) buffer = X.data[idx].buffer # We can't use a pointer to a subarray since BufferToHost expects the data to # be contiguous. - data_slice = data[X.sharding.device_to_array_slices[idx]...] + slice = X.sharding.device_to_array_slices[idx] + data_slice = data[slice...] GC.@preserve data_slice buffer begin XLA.BufferToHost(buffer, pointer(data_slice)) end - data[X.sharding.device_to_array_slices[idx]...] = data_slice + data[slice...] = data_slice end else - error("Unknown sharding type: $(typeof(X.sharding))") + buf = only(X.data).buffer + GC.@preserve data buf begin + XLA.BufferToHost(buf, pointer(data)) + end end return data @@ -80,17 +80,20 @@ function synchronize(x::Union{ConcreteRArray,ConcreteRNumber}) return nothing end +to_number(x::Number) = x function to_number(X::ConcreteRScalar{T}) where {T} data = Ref{T}() XLA.await(X) - buf = get_buffer(X) + buf = get_buffer(X; no_error_for_scalar=true) GC.@preserve data buf begin XLA.BufferToHost(buf, data) end return data[] end -Base.convert(::Type{T}, x::ConcreteRScalar{T}) where {T} = to_number(x) +function Base.convert(::Type{T}, x::ConcreteRScalar{T}) where {T} + return to_number(x; no_error_for_scalar=true) +end for jlop in (:(Base.abs),), T in (ConcreteRNumber,) @eval $(jlop)(x::$(T)) = $(jlop)(to_number(x)) @@ -121,35 +124,31 @@ for jlop in (:(Base.isnan), :(Base.isfinite)), end for T in (ConcreteRNumber, ConcreteRArray{<:Any,0}) - @eval begin - function Base.isapprox(x::$(T), y::Number; kwargs...) - return Base.isapprox(to_number(x), y; kwargs...) - end - - function Base.isapprox(x::Number, y::$(T); kwargs...) - return Base.isapprox(x, to_number(y); kwargs...) - end - - function Base.isapprox(x::$(T), y::$(T); kwargs...) - return Base.isapprox(to_number(x), to_number(y); kwargs...) + for (T1, T2) in ((T, Number), (Number, T), (T, T)) + @eval begin + function Base.isapprox(x::$(T1), y::$(T2); kwargs...) + return Base.isapprox(to_number(x), to_number(y); kwargs...) + end + function Base.isapprox( + x::AbstractArray{<:$(T1)}, y::AbstractArray{<:$(T2)}; kwargs... + ) + return Base.isapprox(to_number.(x), to_number.(y); kwargs...) + end end end end -function Base.isapprox(x::AnyConcreteRArray, y::AbstractArray; kwargs...) - return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) -end -function Base.isapprox(x::AbstractArray, y::AnyConcreteRArray; kwargs...) - return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) -end -function Base.isapprox(x::AnyConcreteRArray, y::AnyConcreteRArray; kwargs...) - return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) -end - -Base.:(==)(x::AnyConcreteRArray, y::AbstractArray) = convert(Array, x) == convert(Array, y) -Base.:(==)(x::AbstractArray, y::AnyConcreteRArray) = convert(Array, x) == convert(Array, y) -function Base.:(==)(x::AnyConcreteRArray, y::AnyConcreteRArray) - return convert(Array, x) == convert(Array, y) +for (T1, T2) in ( + (AnyConcreteRArray, AbstractArray), + (AbstractArray, AnyConcreteRArray), + (AnyConcreteRArray, AnyConcreteRArray), +) + @eval begin + function Base.isapprox(x::$(T1), y::$(T2); kwargs...) + return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) + end + Base.:(==)(x::$(T1), y::$(T2)) = convert(Array, x) == convert(Array, y) + end end function Base.show(io::IO, X::ConcreteRScalar{T}) where {T} @@ -186,7 +185,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N} isempty(a) && throw("Cannot getindex from empty buffer") XLA.await(a) - if buffer_on_cpu(a) && a.sharding isa Sharding.FinalizedNoSharding + if buffer_on_cpu(a) && !Sharding.is_sharded(a) buf = get_buffer(a) GC.@preserve buf begin ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf)) @@ -213,7 +212,7 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N isempty(a) && throw("Cannot setindex! to empty buffer") XLA.await(a) - if buffer_on_cpu(a) && a.sharding isa Sharding.FinalizedNoSharding + if buffer_on_cpu(a) && !Sharding.is_sharded(a) buf = get_buffer(a) GC.@preserve buf begin ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf)) @@ -261,9 +260,7 @@ function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteR end if all(buffer_on_cpu, bc.args) && all( - x -> - !(x isa ConcreteRArray) || - (x isa ConcreteRArray && x.sharding isa Sharding.FinalizedNoSharding), + x -> !(x isa ConcreteRArray) || (x isa ConcreteRArray && !Sharding.is_sharded(x)), bc.args, ) ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args) @@ -332,7 +329,7 @@ function Base.fill!(a::ConcreteRArray{T,N}, val) where {T,N} isempty(a) && throw("Cannot setindex! to empty buffer") XLA.await(a) - if buffer_on_cpu(a) && a.sharding isa Sharding.FinalizedNoSharding + if buffer_on_cpu(a) && !Sharding.is_sharded(a) buf = get_buffer(a) GC.@preserve buf begin ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf)) diff --git a/src/Ops.jl b/src/Ops.jl index 328825441b..da79daab72 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2142,16 +2142,15 @@ end ) return mesh( mod, - # Don't use `name_to_size` here, we need correct ordering - [k => Int64(v) for (k, v) in zip(m.axis_names, size(m.device_ids))], - collect(Int64, vec(m.device_ids)); + [k => Int64(v) for (k, v) in zip(m.axis_names, size(m))], + collect(Int64, m.device_ids); location, ) end @noinline function mesh( mod::MLIR.IR.Module, - mesh_axes::Vector{Pair{String,Int64}}, + mesh_axes::Vector{<:Pair{<:Union{String,Symbol},Int64}}, device_ids::Vector{Int64}; sym_name=nothing, location=mlir_stacktrace("mesh", @__FILE__, @__LINE__), @@ -2171,7 +2170,7 @@ end ctx = MLIR.IR.context() mesh_axis_attrs = [ - MLIR.API.sdyMeshAxisAttrGet(ctx, name, size) for (name, size) in mesh_axes + MLIR.API.sdyMeshAxisAttrGet(ctx, String(name), size) for (name, size) in mesh_axes ] mesh_attr = MLIR.API.sdyMeshAttrGet( ctx, diff --git a/src/Reactant.jl b/src/Reactant.jl index 4c6f106e9e..cef6bc9331 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -82,7 +82,7 @@ unwrapped_eltype(::AnyTracedRArray{T,N}) where {T,N} = T aos_to_soa(x::AbstractArray) = x aos_to_soa(x::AnyTracedRArray) = x -function aos_to_soa(x::AbstractArray{ConcreteRNumber{T}}) where {T} +function aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where {T} x_c = ConcreteRArray(zeros(T, size(x))) x_c .= x return x_c @@ -148,8 +148,10 @@ function Enzyme.make_zero( return res end -using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile -export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace, within_compile +using .Compiler: + @compile, @code_hlo, @code_mhlo, @jit, traced_getfield, create_result, compile +export ConcreteRArray, + ConcreteRNumber, @compile, @code_hlo, @code_mhlo, @jit, @trace, within_compile const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}() diff --git a/src/Sharding.jl b/src/Sharding.jl index 8c1a81c3c4..2c36ae5f52 100644 --- a/src/Sharding.jl +++ b/src/Sharding.jl @@ -7,105 +7,110 @@ using ..Reactant: Reactant, XLA # logic to directly use the sharded arrays from IFRt. This would also simplify our # logic of storing multiple arrays in ConcreteRArray struct -struct Mesh{D} - device_ids::Array{Int,D} - axis_names::NTuple{D,String} - name_to_size::Dict{String,Int} - name_to_dim::Dict{String,Int} +struct Mesh{D,ND} + device_ids::NTuple{ND,Int} + shape::Dims{D} + axis_names::NTuple{D,Symbol} - function Mesh(devices::AbstractArray{<:XLA.Device}, axis_names) + function Mesh(devices::AbstractArray{XLA.Device}, axis_names) return Mesh(XLA.DeviceGetLocalDeviceId.(devices), axis_names) end + function Mesh(devices::NTuple{D,XLA.Device}, shape::Dims{D}, axis_names) where {D} + return Mesh(XLA.DeviceGetLocalDeviceId.(devices), shape, axis_names) + end + function Mesh( - device_ids::AbstractArray{<:Integer,D}, axis_names::NTuple{D,String} + device_ids::AbstractArray{<:Integer,D}, axis_names::NTuple{D,Union{String,Symbol}} ) where {D} + return Mesh(Tuple(vec(device_ids)), size(device_ids), axis_names) + end + + function Mesh( + device_ids::NTuple{D1,Int}, + shape::Dims{D}, + axis_names::NTuple{D,Union{String,Symbol}}, + ) where {D,D1} @assert allunique(device_ids) - name_to_size = Dict( - name => Int64(size(device_ids, i)) for (i, name) in enumerate(axis_names) - ) - name_to_dim = Dict(name => i for (i, name) in enumerate(axis_names)) - return new{D}(Int64.(device_ids), axis_names, name_to_size, name_to_dim) + return new{D,D1}(device_ids, shape, Symbol.(axis_names)) end end -Base.length(mesh::Mesh) = length(mesh.device_ids) +Base.length(::Mesh{D,ND}) where {D,ND} = ND Base.ndims(::Mesh{D}) where {D} = D -Base.size(mesh::Mesh) = size(mesh.device_ids) + +Base.size(mesh::Mesh) = mesh.shape +Base.size(mesh::Mesh, axis::Int) = mesh.shape[axis] +function Base.size(mesh::Mesh, axis::Union{String,Symbol}) + return size(mesh, findfirst(==(Symbol(axis)), mesh.axis_names)) +end +Base.size(mesh::Mesh, ::Nothing) = 1 + +Base.in(axis::Union{String,Symbol}, mesh::Mesh) = Symbol(axis) ∈ mesh.axis_names abstract type AbstractSharding end -function (T::AbstractSharding)(::XLA.Client, device, ::AbstractArray) +function (T::AbstractSharding)(::XLA.Client, device, ::Union{AbstractArray,Number}) return error("(::$(T))(::XLA.Client, ::AbstractArray) is not implemented") end struct NoSharding <: AbstractSharding end -finalized_sharding(::Type{NoSharding}) = FinalizedNoSharding - # This allows us to mark entire branches as NoSharding Base.getproperty(::NoSharding, x) = NoSharding() -function (::NoSharding)(client::XLA.Client, device, x::AbstractArray) +function (::NoSharding)(client::XLA.Client, device, x::Union{AbstractArray,Number}) buffer = XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, x, device), nothing) - return (buffer,), FinalizedNoSharding() + return (buffer,), ShardInfo(NoSharding(), nothing) end -struct NamedSharding{D1,P<:Tuple,D2} <: AbstractSharding - mesh::Mesh{D1} +# XXX: multiple axes partitioning -- supported by shardy (not in Jax I think) +struct NamedSharding{D1,D2,P<:Tuple,D3} <: AbstractSharding + mesh::Mesh{D1,D2} partition_spec::P - present_axes::Vector{String} - is_closed::NTuple{D2,Bool} - priority::NTuple{D2,Int} + is_closed::NTuple{D3,Bool} + priority::NTuple{D3,Int} function NamedSharding( - mesh::Mesh{D1}, + mesh::Mesh{D1,D2}, partition_spec::P; - is_closed::NTuple{D2,Bool}=ntuple(Returns(true), length(partition_spec)), + is_closed::NTuple{D3,Bool}=ntuple(Returns(true), length(partition_spec)), # negative priority means that priority is not considered by shardy - priority::NTuple{D2,Int}=ntuple(i -> -1, length(partition_spec)), - ) where {D1,P<:Tuple,D2} + priority::NTuple{D3,Int}=ntuple(i -> -1, length(partition_spec)), + ) where {D1,D2,P<:Tuple,D3} # TODO: we need to check how open sharding works in XLA, i.e. how to specify inputs @assert all(is_closed) "All partitions must be closed for now." - present_axes = String[] - for p in partition_spec - if p !== nothing - if p isa String - @assert p ∈ mesh.axis_names - push!(present_axes, p) - elseif p isa Tuple - for pᵢ in p - @assert pᵢ isa String && pᵢ ∈ mesh.axis_names - push!(present_axes, pᵢ) - end - else - error("Invalid partition spec $p") - end - end - end - @assert length(unique(present_axes)) == length(present_axes) "Duplicate axis names!" - return new{D1,P,D2}(mesh, partition_spec, present_axes, is_closed, priority) + @assert all(p -> p === nothing || p isa String || p isa Symbol, partition_spec) + partition_spec = map(x -> x isa String ? Symbol(x) : x, partition_spec) + non_replicated_axes = filter(x -> x !== nothing, partition_spec) + @assert length(unique(non_replicated_axes)) == length(non_replicated_axes) "Duplicate axis names!" + return new{D1,D2,typeof(partition_spec),D3}( + mesh, partition_spec, is_closed, priority + ) end end -finalized_sharding(::Type{<:NamedSharding}) = FinalizedNamedSharding +function (sharding::NamedSharding)(client::XLA.Client, device, x::Number) + (; mesh, partition_spec) = sharding + @assert length(partition_spec) == 0 + + data = map(mesh.device_ids) do device_id + return XLA.AsyncBuffer( + XLA.ArrayFromHostBuffer(client, fill(x), XLA.device_ordinal(client, device_id)), + nothing, + ) + end + return data, ShardInfo(sharding, ntuple(Returns(()), length(mesh))) +end -# XXX: multiple axes partitioning function (sharding::NamedSharding)(client::XLA.Client, ::Nothing, x::AbstractArray) (; mesh, partition_spec) = sharding @assert length(partition_spec) == ndims(x) # Fast Path for replicating the input across all devices if all(Base.Fix2(===, nothing), partition_spec) - data = Array{XLA.AsyncBuffer,ndims(mesh)}(undef, size(mesh)) - device_to_array_slices = Array{Vector{UnitRange{Int64}},ndims(mesh)}( - undef, size(mesh) - ) - - for idx in CartesianIndices(data) - device_id = mesh.device_ids[idx] - device_to_array_slices[idx] = [1:size(x, i) for i in 1:ndims(x)] - data[idx] = XLA.AsyncBuffer( + data = map(mesh.device_ids) do device_id + return XLA.AsyncBuffer( XLA.ArrayFromHostBuffer( client, x, @@ -116,107 +121,71 @@ function (sharding::NamedSharding)(client::XLA.Client, ::Nothing, x::AbstractArr nothing, ) end - - return ( - Tuple(vec(data)), - FinalizedNamedSharding{typeof(sharding),ndims(mesh)}( - sharding, device_to_array_slices - ), + device_to_array_slices = ntuple( + Returns(ntuple(i -> 1:size(x, i), ndims(x))), length(mesh) ) + return data, ShardInfo(sharding, device_to_array_slices) end - ndevices = Vector{Int}(undef, ndims(x)) - axis_name_to_dim_and_offset = Dict{String,Tuple{Int,Int}}() - for i in 1:ndims(x) - p = partition_spec[i] - if p === nothing - ndevices[i] = 1 - else - if p isa Tuple - offset = 0 - for pᵢ in p - axis_name_to_dim_and_offset[pᵢ] = (i, offset) - offset += mesh.name_to_size[pᵢ] - end - ndevices[i] = offset - else - axis_name_to_dim_and_offset[p] = (i, 0) - ndevices[i] = mesh.name_to_size[p] - end - end - end - + ndevices = map(Base.Fix1(size, mesh), partition_spec) for (sz, ndevice) in zip(size(x), ndevices) @assert sz % ndevice == 0 "$(size(x)) must be divisible by $(ndevices)" end strides = size(x) .÷ ndevices - slices = Array{Vector{UnitRange{Int64}},ndims(x)}(undef, Tuple(ndevices)) - + slices = Array{NTuple{ndims(x),UnitRange{Int64}},ndims(x)}(undef, ndevices) for idx in CartesianIndices(slices) idx_tup = Tuple(idx) - slices[idx] = [ + slices[idx] = Tuple( (i1 + 1):i2 for (i1, i2) in zip((idx_tup .- 1) .* strides, idx_tup .* strides) - ] + ) end - data = Array{XLA.AsyncBuffer,ndims(mesh)}(undef, size(mesh)) - device_to_array_slices = Array{Vector{UnitRange{Int64}},ndims(mesh)}(undef, size(mesh)) - - for idx in CartesianIndices(data) - device_id = mesh.device_ids[idx] + device_to_array_slices = Array{eltype(slices),ndims(mesh)}(undef, size(mesh)) + for idx in CartesianIndices(device_to_array_slices) idx_tup = Tuple(idx) slice_idx = ones(Int, ndims(slices)) for (axis_name, idxᵢ) in zip(mesh.axis_names, idx_tup) - if haskey(axis_name_to_dim_and_offset, axis_name) - dim, offset = axis_name_to_dim_and_offset[axis_name] - slice_idx[dim] = idxᵢ + offset - end + dim = findfirst(==(axis_name), sharding.partition_spec) + dim !== nothing && (slice_idx[dim] = idxᵢ) end device_to_array_slices[idx] = slices[CartesianIndex(slice_idx...)] - data[idx] = XLA.AsyncBuffer( + end + + data = ntuple(length(mesh)) do i + XLA.AsyncBuffer( XLA.ArrayFromHostBuffer( client, - x[device_to_array_slices[idx]...], + x[device_to_array_slices[i]...], XLA.ClientGetAddressableDevice( - client, XLA.device_ordinal(client, device_id) + client, XLA.device_ordinal(client, mesh.device_ids[i]) ), ), nothing, ) end - return ( - Tuple(vec(data)), - FinalizedNamedSharding{typeof(sharding),ndims(mesh)}( - sharding, device_to_array_slices - ), - ) + return data, ShardInfo(sharding, Tuple(vec(device_to_array_slices))) end -# Internal Type that mimics XYZSharding but contains mapping from device to array slices -abstract type AbstractFinalizedSharding <: AbstractSharding end +# Given Sharding + Array --> ShardInfo +struct ShardInfo{S,D} <: AbstractSharding + sharding::S + device_to_array_slices::D +end -function Base.getproperty(sharding::AbstractFinalizedSharding, name::Symbol) +function Base.getproperty(sharding::ShardInfo, name::Symbol) name ∈ (:sharding, :device_to_array_slices) && return getfield(sharding, name) return getfield(sharding.sharding, name) end -function (sharding::AbstractFinalizedSharding)(client::XLA.Client, device, x::AbstractArray) +function (sharding::ShardInfo)(client::XLA.Client, device, x::Union{AbstractArray,Number}) return (sharding.sharding)(client, device, x) end -struct FinalizedNoSharding <: AbstractFinalizedSharding end +const NoShardInfo = ShardInfo{NoSharding,Nothing} -function Base.getproperty(::FinalizedNoSharding, name::Symbol) - @assert name === :sharding - return NoSharding() -end - -struct FinalizedNamedSharding{S<:NamedSharding,D} <: AbstractFinalizedSharding - sharding::S - device_to_array_slices::Array{Vector{UnitRange{Int64}},D} -end +ShardInfo{NoSharding,Nothing}() = ShardInfo(NoSharding(), nothing) """ is_sharded(sharding) @@ -225,15 +194,16 @@ end Checks whether the given sharding refers to no sharding. """ is_sharded(::NoSharding) = false -is_sharded(::FinalizedNoSharding) = false is_sharded(::NamedSharding) = true -is_sharded(::FinalizedNamedSharding) = true +is_sharded(s::ShardInfo) = is_sharded(s.sharding) -function Sharding.is_sharded(x::AbstractArray) +function is_sharded(x::AbstractArray) ancestor_x = Reactant.ancestor(x) - if hasfield(typeof(ancestor_x), :sharding) - return is_sharded(ancestor_x.sharding) - end + hasfield(typeof(ancestor_x), :sharding) && return is_sharded(ancestor_x.sharding) + return false +end +function is_sharded(x::Number) + hasfield(typeof(x), :sharding) && return is_sharded(x.sharding) return false end diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 47e2b3cccc..66d7ac6656 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -128,7 +128,7 @@ function transpose_val(val) end mutable struct CompiledMlirFnResult{ - F,TR,Re,Rt,LA,LR,LRS,PA,CR,M<:Union{Nothing,Reactant.Sharding.Mesh},MA + F,TR,Re,Rt,LA,LR,PA,CR,M<:Union{Nothing,Reactant.Sharding.Mesh},MA } fnwrapped::Bool f::F @@ -139,7 +139,6 @@ mutable struct CompiledMlirFnResult{ linear_args::Vector{LA} in_tys::Vector{MLIR.IR.Type} linear_results::Vector{LR} - linear_result_shard_info::LRS num_partitions::Int num_replicas::Int is_sharded::Bool @@ -224,7 +223,7 @@ function make_mlir_fn( is_sharded = false for (k, v) in seen_args if k isa Reactant.ConcreteRArray - if !(k.sharding isa Reactant.Sharding.FinalizedNoSharding) + if Reactant.Sharding.is_sharded(k) is_sharded = true traced_args_to_shardings[v] = k.sharding if !haskey(mesh_cache, k.sharding.mesh) @@ -239,11 +238,12 @@ function make_mlir_fn( unique_meshes = unique([m.mesh for (k, m) in traced_args_to_shardings]) # TODO: support multiple meshes @assert length(unique_meshes) == 1 "Currently we support using a single mesh" - sorted_devices = [sort(vec(m.device_ids)) for m in unique_meshes] - @assert allequal(sorted_devices) "All meshes must have the same device ids" - num_partitions = length(first(sorted_devices)) + # sorted_devices = [sort(vec(m.device_ids)) for m in unique_meshes] + # @assert allequal(sorted_devices) "All meshes must have the same device ids" + # num_partitions = length(first(sorted_devices)) sharding_mesh = first(unique_meshes) mesh_op_attrs = mesh_cache[sharding_mesh] + num_partitions = length(sharding_mesh) else sharding_mesh = nothing end @@ -279,18 +279,13 @@ function make_mlir_fn( for (j, name) in enumerate(sharding.partition_spec) if name === nothing axes = MLIR.IR.Attribute[] - elseif name isa String + else + @assert name isa Symbol axes = [ MLIR.API.sdyAxisRefAttrGet( - ctx, name, MLIR.API.MlirAttribute(C_NULL) + ctx, String(name), MLIR.API.MlirAttribute(C_NULL) ), ] - elseif name isa Tuple - axes = [ - MLIR.API.sdyAxisRefAttrGet( - ctx, nameᵢ, MLIR.API.MlirAttribute(C_NULL) - ) for nameᵢ in name - ] end dimension_sharding_attrs[j] = MLIR.API.sdyDimensionShardingAttrGet( ctx, length(axes), axes, sharding.is_closed[j], sharding.priority[j] @@ -438,57 +433,11 @@ function make_mlir_fn( end @assert residx > 0 result_not_replicated[residx] = true - MLIR.API.ReactantFuncSetResultAttr( + MLIR.API.mlirFuncSetResultAttr( func2, residx - 1, "sdy.sharding", linear_arg_shardings[i] ) end end - - # TODO: Introduce OpSharding in API.cpp and use it here - # XLA gives us an API to query the final result sharding. However, currently we - # don't expose OpSharding from XLA, so we manually replicate the outputs to all the - # mesh elements manually. - for (idx, already_sharded) in enumerate(result_not_replicated) - already_sharded && continue - - replicated_axes = [ - MLIR.API.sdyAxisRefAttrGet(ctx, name, MLIR.API.MlirAttribute(C_NULL)) for - name in sharding_mesh.axis_names - ] - local result = linear_results[idx] - sharding = MLIR.IR.Attribute( - MLIR.API.sdyTensorShardingAttrGet( - ctx, - mesh_op_attrs.sym_name, - ndims(result), - MLIR.API.MlirAttribute[ - MLIR.API.sdyDimensionShardingAttrGet( - ctx, 0, MLIR.API.MlirAttribute[], true, -1 - ) for _ in 1:ndims(result) - ], - length(replicated_axes), - replicated_axes, - ), - ) - MLIR.API.ReactantFuncSetResultAttr(func2, idx - 1, "sdy.sharding", sharding) - end - - linear_result_shard_info = ntuple(length(linear_results)) do i - arg = linear_results[i] - if !result_not_replicated[i] || !haskey(traced_args_to_shardings, arg) - return ( - map( - Returns([1:size(arg, i) for i in 1:ndims(arg)]), - sharding_mesh.device_ids, - ), - ntuple(Returns(nothing), ndims(arg)), - ) - end - local sharding = traced_args_to_shardings[arg] - return (sharding.device_to_array_slices, sharding.partition_spec) - end - else - linear_result_shard_info = ntuple(Returns(nothing), length(linear_results)) end MLIR.API.mlirOperationDestroy(func.operation) @@ -504,7 +453,6 @@ function make_mlir_fn( linear_args, in_tys, linear_results, - linear_result_shard_info, num_partitions, num_replicas, is_sharded, diff --git a/src/Tracing.jl b/src/Tracing.jl index 94ee29efa4..f48b7734b5 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -55,7 +55,11 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(sharding) ) if Mode == ArrayToConcrete && T <: track_numbers - return ConcreteRNumber{T} + if !Sharding.is_sharded(sharding) + return ConcreteRNumber{T,1,Sharding.NoShardInfo} + else + error("TODO: implement sharding") + end elseif (mode == NoStopTracedTrack || mode == TracedTrack) && T <: track_numbers return TracedRNumber{T} end @@ -236,7 +240,7 @@ Base.@nospecializeinfer function traced_type_inner( if mode == ConcreteToTraced return TracedRNumber{T} elseif mode == TracedToConcrete - return ConcreteRNumber{T} + return T0 else throw("Abstract RNumber cannot be made concrete") end @@ -295,9 +299,7 @@ Base.@nospecializeinfer function traced_type_inner( throw("TracedRArray cannot be traced") elseif mode == TracedToConcrete if !Sharding.is_sharded(sharding) - return ConcreteRArray{ - T.parameters[1],T.parameters[2],1,Sharding.FinalizedNoSharding - } + return ConcreteRArray{T.parameters[1],T.parameters[2],1,Sharding.NoShardInfo} else error("TODO: implement sharding") end @@ -318,7 +320,11 @@ Base.@nospecializeinfer function traced_type_inner( if mode == ConcreteToTraced throw("TracedRNumber cannot be traced") elseif mode == TracedToConcrete - return ConcreteRNumber{T.parameters[1]} + if !Sharding.is_sharded(sharding) + return ConcreteRNumber{T.parameters[1],1,Sharding.NoShardInfo} + else + error("TODO: implement sharding") + end elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath return T else @@ -337,7 +343,7 @@ Base.@nospecializeinfer function traced_type_inner( throw("TracedRNG cannot be traced") elseif mode == TracedToConcrete if !Sharding.is_sharded(sharding) - return ConcreteRNG{1,Sharding.FinalizedNoSharding} + return ConcreteRNG{1,Sharding.NoShardInfo} else error("TODO: implement sharding") end @@ -403,7 +409,7 @@ Base.@nospecializeinfer function traced_type_inner( N = ndims(A) if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive if !Sharding.is_sharded(sharding) - return ConcreteRArray{T,N,1,Sharding.FinalizedNoSharding} + return ConcreteRArray{T,N,1,Sharding.NoShardInfo} else error("TODO: implement sharding") end @@ -549,8 +555,13 @@ Base.@nospecializeinfer function traced_type_inner( subParms = [] for (i, SST) in enumerate(T.parameters) if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive + # XXX: Sharding??? TrT = traced_type_inner( - ConcreteRNumber{SST}, seen, mode, track_numbers, sharding + ConcreteRNumber{SST,1,Sharding.ShardInfo}, + seen, + mode, + track_numbers, + sharding, ) push!(subParms, TrT) elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber @@ -864,7 +875,7 @@ function make_tracer( throw("Cannot have ConcreteRArray as function call argument.") end if mode == ArrayToConcrete - if prev.sharding isa Sharding.finalized_sharding(typeof(sharding)) + if prev.sharding isa Sharding.ShardInfo{typeof(sharding)} return prev end error( @@ -895,9 +906,11 @@ function make_tracer( throw("Cannot have ConcreteRNumber as function call argument.") end if mode == ArrayToConcrete - Sharding.is_sharded(sharding) && - error("Cannot specify sharding for ConcreteRNumber") - return prev + if !Sharding.is_sharded(sharding) + return prev + else + error("TODO: implement sharding") + end end if mode != ConcreteToTraced throw("Cannot trace existing trace type") @@ -961,8 +974,8 @@ function make_tracer( return seen[prev]::ConcreteRArray{T,N} end if !Sharding.is_sharded(sharding) - res = ConcreteRArray{T,N,1,Sharding.FinalizedNoSharding}( - (XLA.AsyncEmptyBuffer,), size(prev), Sharding.FinalizedNoSharding() + res = ConcreteRArray{T,N,1,Sharding.NoShardInfo}( + (XLA.AsyncEmptyBuffer,), size(prev), Sharding.NoShardInfo() ) else error("TODO: implement sharding") @@ -1024,9 +1037,13 @@ function make_tracer( if haskey(seen, prev) return seen[prev]::ConcreteRNumber{T} end - Sharding.is_sharded(sharding) && - error("Cannot specify sharding for ConcreteRNumber") - res = ConcreteRNumber{T}(XLA.AsyncEmptyBuffer) + if !Sharding.is_sharded(sharding) + res = ConcreteRNumber{T,1,Sharding.NoShardInfo}( + (XLA.AsyncEmptyBuffer,), Sharding.NoShardInfo() + ) + else + error("TODO: implement sharding") + end seen[prev] = res return res end @@ -1087,7 +1104,7 @@ function make_tracer( Sharding.is_sharded(sharding) && error("Cannot specify sharding for Numbers") if RT <: track_numbers if mode == ArrayToConcrete - return ConcreteRNumber(prev) + return ConcreteRNumber(prev; sharding) else if mode == TracedTrack || mode == NoStopTracedTrack res = TracedRNumber{RT}( @@ -1348,7 +1365,7 @@ end @nospecialize(track_numbers::Type), @nospecialize(sharding) ) - if x.sharding isa Sharding.finalized_sharding(typeof(sharding)) + if x.sharding isa Sharding.ShardInfo{typeof(sharding)} return x end return error( @@ -1378,8 +1395,12 @@ end @nospecialize(track_numbers::Type), @nospecialize(sharding) ) - Sharding.is_sharded(sharding) && error("Cannot specify sharding for ConcreteRNumber") - return x + if x.sharding isa Sharding.ShardInfo{typeof(sharding)} + return x + end + return error( + "Mismatched sharding. Input has sharding $(x.sharding), but requested sharding is $(typeof(sharding))", + ) end @inline function to_rarray_internal( @@ -1387,8 +1408,7 @@ end @nospecialize(track_numbers::Type), @nospecialize(sharding) ) - Sharding.is_sharded(sharding) && error("Cannot specify sharding for Numbers") - typeof(x) <: track_numbers && return ConcreteRNumber(x) + typeof(x) <: track_numbers && return ConcreteRNumber(x; sharding) return x end diff --git a/src/Types.jl b/src/Types.jl index 500eb8eade..5d73d21713 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -64,45 +64,49 @@ end # Concrete Types ## ConcreteRNumber -mutable struct ConcreteRNumber{T} <: RNumber{T} - data::XLA.AsyncBuffer +mutable struct ConcreteRNumber{T,D,S<:Sharding.ShardInfo} <: RNumber{T} + data::NTuple{D,XLA.AsyncBuffer} + sharding::S end -@leaf ConcreteRNumber +ConcreteRNumber{T,1,Sharding.NoShardInfo}(x::Number) where {T} = ConcreteRNumber{T}(x) + +function ConcreteRNumber{T}(data::Tuple{XLA.AsyncBuffer}) where {T} + return ConcreteRNumber{T,1,Sharding.NoShardInfo}(data, Sharding.NoShardInfo()) +end -XLA.await(x::ConcreteRNumber) = XLA.await(x.data) -XLA.client(x::ConcreteRNumber) = XLA.client(x.data) -XLA.device(x::ConcreteRNumber) = XLA.device(x.data) +@leaf ConcreteRNumber function ConcreteRNumber{T}(data::T2; kwargs...) where {T<:Number,T2<:Number} - return ConcreteRNumber{T}(ConcreteRArray(fill(convert(T, data)); kwargs...).data[1]) + carray = ConcreteRArray(fill(convert(T, data)); kwargs...) + if !Sharding.is_sharded(carray.sharding) + return ConcreteRNumber{T,1,typeof(carray.sharding)}( + (carray.data[1],), carray.sharding + ) + end + @assert all(isnothing, carray.sharding.partition_spec) "ConcreteRNumber cannot be \ + sharded" + return ConcreteRNumber{T,length(carray.data),typeof(carray.sharding)}( + carray.data, carray.sharding + ) end ConcreteRNumber(data::T; kwargs...) where {T<:Number} = ConcreteRNumber{T}(data; kwargs...) ## ConcreteRArray -# XXX: make data into a tuple of arrays -mutable struct ConcreteRArray{T,N,D,S<:Sharding.AbstractFinalizedSharding} <: RArray{T,N} +mutable struct ConcreteRArray{T,N,D,S<:Sharding.ShardInfo} <: RArray{T,N} data::NTuple{D,XLA.AsyncBuffer} shape::NTuple{N,Int} sharding::S end -# This dispatch is needed when converting a ConcreteRNumber to a 0D ConcreteRArray -function Base.setproperty!(x::ConcreteRArray, f::Symbol, val::XLA.AsyncBuffer) - @assert f === :data - return setproperty!(x, :data, (val,)) -end - @leaf ConcreteRArray Adapt.parent_type(::Type{<:ConcreteRArray{T,N}}) where {T,N} = ConcreteRArray{T,N} Adapt.parent_type(::Type{ConcreteRArray{T,N,D,S}}) where {T,N,D,S} = ConcreteRArray{T,N,D,S} Base.@deprecate ConcreteRArray(data::Number; kwargs...) ConcreteRNumber(data; kwargs...) -function ConcreteRArray{T,N}(data::XLA.AsyncBuffer, shape::NTuple{N,Int}) where {T,N} - return ConcreteRArray{T,N,1,Sharding.FinalizedNoSharding}( - (data,), shape, Sharding.FinalizedNoSharding() - ) +function ConcreteRArray{T,N}(data::Tuple{XLA.AsyncBuffer}, shape::NTuple{N,Int}) where {T,N} + return ConcreteRArray{T,N,1,Sharding.NoShardInfo}(data, shape, Sharding.NoShardInfo()) end function ConcreteRArray( @@ -134,16 +138,14 @@ function ConcreteRArray( ) end -XLA.await(x::ConcreteRArray) = foreach(XLA.await, x.data) -XLA.client(x::ConcreteRArray) = XLA.client(x.data) -function XLA.device(x::ConcreteRArray) - x.sharding isa Sharding.FinalizedNoSharding && return XLA.device(only(x.data)) +XLA.await(x::Union{ConcreteRArray,ConcreteRNumber}) = foreach(XLA.await, x.data) +XLA.client(x::Union{ConcreteRArray,ConcreteRNumber}) = XLA.client(x.data) +function XLA.device(x::Union{ConcreteRArray,ConcreteRNumber}) + x.sharding isa Sharding.NoShardInfo && return XLA.device(only(x.data)) return nothing # This is intentional to make constructing ConcreteRArrays easier end -const ConcreteRScalar{T} = Union{ - ConcreteRArray{T,0,1,Sharding.FinalizedNoSharding},ConcreteRNumber{T} -} +const ConcreteRScalar{T} = Union{ConcreteRArray{T,0},ConcreteRNumber{T}} const WrappedConcreteRArray{T,N,D,S} = WrappedArray{ T,N,ConcreteRArray,ConcreteRArray{T,N,D,S} } @@ -151,14 +153,6 @@ const AnyConcreteRArray{T,N,D,S} = Union{ ConcreteRArray{T,N,D,S},WrappedConcreteRArray{T,N,D,S} } -const UnshardedConcreteRArray{T,N} = ConcreteRArray{T,N,1,Sharding.FinalizedNoSharding} -const UnshardedWrappedConcreteRArray{T,N} = WrappedConcreteRArray{ - T,N,1,Sharding.FinalizedNoSharding -} -const AnyUnshardedConcreteRArray{T,N} = AnyConcreteRArray{ - T,N,1,Sharding.FinalizedNoSharding -} - ConcreteRArray(x::AnyConcreteRArray) = ConcreteRArray{eltype(x),ndims(x)}(x) ConcreteRArray{T}(x::AnyConcreteRArray) where {T} = ConcreteRArray{T,ndims(x)}(x) ConcreteRArray{T,N}(x::ConcreteRArray{T,N}) where {T,N} = x diff --git a/src/xla/LoadedExecutable.jl b/src/xla/LoadedExecutable.jl index 1e52b4a56b..7ec4801ff7 100644 --- a/src/xla/LoadedExecutable.jl +++ b/src/xla/LoadedExecutable.jl @@ -4,10 +4,12 @@ end mutable struct LoadedExecutable exec::Ptr{Cvoid} + num_results::Int64 + is_sharded::Bool - function LoadedExecutable(exec::Ptr{Cvoid}) + function LoadedExecutable(exec::Ptr{Cvoid}, num_results::Int64, is_sharded::Bool) @assert exec != C_NULL - return finalizer(free_exec, new(exec)) + return finalizer(free_exec, new(exec, num_results, is_sharded)) end end @@ -85,7 +87,9 @@ end for i in 1:n_outs push!( results, - :(AsyncBuffer(Buffer(outputs[$i]), future ? Future(future_res[$i]) : nothing)), + :(( + AsyncBuffer(Buffer(outputs[$i]), future ? Future(future_res[$i]) : nothing), + )), ) end @@ -209,22 +213,41 @@ function Compile( is_sharded::Bool=false, mesh_ids::Vector{Int64}=Int64[], # mesh_shape::Vector{Int64}=Int64[], + num_results::Int64, ) device_id = is_sharded ? Int64(-1) : Int64(device_ordinal(client, device)) mesh_ids = Int64.(device_ordinal.((client,), mesh_ids)) GC.@preserve client mod begin - return LoadedExecutable( - @ccall MLIR.API.mlir_c.ClientCompile( - client.client::Ptr{Cvoid}, - mod.module_::MLIR.API.MlirModule, - device_id::Clong, - is_sharded::Bool, - # mesh_shape::Ptr{Clong}, - # length(mesh_shape)::Clong, - mesh_ids::Ptr{Clong}, - length(mesh_ids)::Clong, - CUDA_DATA_DIR[]::Cstring, - )::Ptr{Cvoid} - ) + exec = @ccall MLIR.API.mlir_c.ClientCompile( + client.client::Ptr{Cvoid}, + mod.module_::MLIR.API.MlirModule, + device_id::Clong, + is_sharded::Bool, + # mesh_shape::Ptr{Clong}, + # length(mesh_shape)::Clong, + mesh_ids::Ptr{Clong}, + length(mesh_ids)::Clong, + CUDA_DATA_DIR[]::Cstring, + )::Ptr{Cvoid} end + return LoadedExecutable(exec, num_results, is_sharded) +end + +function get_output_shardings(exec::LoadedExecutable) + exec.is_sharded || return OpSharding[] + + jl_op_shardings = [Ref{JLOpSharding}() for _ in 1:(exec.num_results)] + jl_op_shardings_ptr = [ + Base.unsafe_convert(Ptr{JLOpSharding}, sharding) for sharding in jl_op_shardings + ] + + GC.@preserve jl_op_shardings begin + @ccall MLIR.API.mlir_c.PjRtLoadedExecutableGetOuputShardings( + exec.exec::Ptr{Cvoid}, + jl_op_shardings_ptr::Ptr{Ptr{JLOpSharding}}, + exec.num_results::Int32, + )::Cvoid + end + + return map(OpSharding ∘ getindex, jl_op_shardings) end diff --git a/src/xla/Sharding.jl b/src/xla/Sharding.jl new file mode 100644 index 0000000000..1dc20c4183 --- /dev/null +++ b/src/xla/Sharding.jl @@ -0,0 +1,213 @@ +@enumx OpShardingType begin + Replicated + Maximal + Tuple + Other + Manual + Unknown +end + +@enumx ShardGroupType begin + As + Like +end + +# TODO: tuple sharding / op metadata +struct JLOpSharding + type::Int32 + n_tile_dimensions::Int32 + tile_dimensions::Ptr{Int64} + n_layout_minor_to_major::Int32 + layout_minor_to_major::Ptr{Int64} + replicate_on_last_tile_dim::Bool + n_last_tile_dims::Int32 + last_tile_dims::Ptr{Int32} + n_tile_assignment_dimensions::Int32 + tile_assignment_dimensions::Ptr{Int64} + n_tile_assignment_devices::Int32 + tile_assignment_devices::Ptr{Int64} + n_iota_reshape_dims::Int32 + iota_reshape_dims::Ptr{Int64} + n_iota_transpose_perm::Int32 + iota_transpose_perm::Ptr{Int32} + is_shard_group::Bool + shard_group_id::Int64 + shard_group_type::Int32 +end + +struct OpSharding + type::OpShardingType.T + tile_dimensions::Vector{Int64} + layout_minor_to_major::Vector{Int64} + replicate_on_last_tile_dim::Bool + last_tile_dims::Vector{OpShardingType.T} + tile_assignment_dimensions::Vector{Int64} + tile_assignment_devices::Vector{Int64} + iota_reshape_dims::Vector{Int64} + iota_transpose_perm::Vector{Int32} + is_shard_group::Bool + shard_group_id::Int64 + shard_group_type::ShardGroupType.T +end + +function OpSharding(sharding::JLOpSharding) + @assert sharding.type != 2 "Tuple sharding is not supported yet!" + + last_tile_dims = unsafe_wrap(Array, sharding.last_tile_dims, sharding.n_last_tile_dims) + tile_assignment_dimensions = unsafe_wrap( + Array, sharding.tile_assignment_dimensions, sharding.n_tile_assignment_dimensions + ) + tile_assignment_devices = unsafe_wrap( + Array, sharding.tile_assignment_devices, sharding.n_tile_assignment_devices + ) + iota_reshape_dims = unsafe_wrap( + Array, sharding.iota_reshape_dims, sharding.n_iota_reshape_dims + ) + iota_transpose_perm = unsafe_wrap( + Array, sharding.iota_transpose_perm, sharding.n_iota_transpose_perm + ) + iota_transpose_perm .+= 1 + + tile_dimensions = unsafe_wrap( + Array, sharding.tile_dimensions, sharding.n_tile_dimensions + ) + layout_minor_to_major = unsafe_wrap( + Array, sharding.layout_minor_to_major, sharding.n_layout_minor_to_major + ) + + return OpSharding( + int_to_op_sharding_type(sharding.type), + reverse(tile_dimensions), + layout_minor_to_major, + sharding.replicate_on_last_tile_dim, + reverse(last_tile_dims), + reverse(tile_assignment_dimensions), + reverse(tile_assignment_devices), + reverse(iota_reshape_dims), + reverse(iota_transpose_perm), + sharding.is_shard_group, + sharding.shard_group_id, + int_to_shard_group_type(sharding.shard_group_type), + ) +end + +function int_to_op_sharding_type(i::Int32) + i == 0 && return OpShardingType.Replicated + i == 1 && return OpShardingType.Maximal + i == 2 && return OpShardingType.Tuple + i == 3 && return OpShardingType.Other + i == 4 && return OpShardingType.Manual + i == 5 && return OpShardingType.Unknown + return error("Invalid OpShardingType $i") +end + +function int_to_shard_group_type(i::Int32) + i == 0 && return ShardGroupType.As + i == 1 && return ShardGroupType.Like + return error("Invalid ShardGroupType $i") +end + +function generate_device_list(sharding::OpSharding) + if !isempty(sharding.iota_reshape_dims) + # Generate device IDs using iota + num_devices = prod(sharding.iota_reshape_dims) + iota_devices = collect( + Int64, reshape(0:(num_devices - 1), sharding.iota_reshape_dims...) + ) + + # Permute the iota array if iota_transpose_perm is provided + if !isempty(sharding.iota_transpose_perm) + iota_devices = permutedims(iota_devices, Tuple(sharding.iota_transpose_perm)) + end + + # Flatten the permuted iota array to get tile_assignment_devices + return vec(iota_devices) + end + return sharding.tile_assignment_devices +end + +# Function to compute array indices for each device +function compute_array_indices_and_partition_spec( + sharding::OpSharding, array_size::Dims{N}, mesh +) where {N} + if sharding.type == OpShardingType.Replicated + # Replicated: All devices have the entire array + return ( + ntuple(Returns(ntuple(i -> 1:array_size[i], N)), length(mesh)), + ntuple(Returns(nothing), N), + ) + elseif sharding.type == OpShardingType.Maximal + # Maximal: Only one device has the entire array + @assert length(mesh) == 1 + return ( + ntuple(Returns(ntuple(i -> 1:array_size[i], N)), length(mesh)), + ntuple(Returns(nothing), N), + ) + elseif sharding.type == OpShardingType.Other + # Other: Tiled sharding + # Reshape tile_assignment_devices into tile_assignment_dimensions + device_list = generate_device_list(sharding) + sorted_mesh_devices = sort(collect(Int64, mesh.device_ids)) + @assert sort(device_list) == sorted_mesh_devices "Mismatched devices list: \ + $(device_list) vs \ + $(mesh.device_ids)" + @assert isempty(sharding.tile_dimensions) "Tile dimensions are not supported yet! \ + Open an issue with an MWE for this case." + @assert !sharding.replicate_on_last_tile_dim "Replication on the last tile \ + dimension is not supported yet! Open \ + an issue with an MWE for this case." + + tile_assignment = reshape(device_list, sharding.tile_assignment_dimensions...) + tile_dimensions = div.(array_size, sharding.tile_assignment_dimensions) + + mesh_devices = reshape([mesh.device_ids...], mesh.shape) + + # Match array dimensions to mesh axes by comparing device sequences + used_axes = Set{Int}() + partition_spec = ntuple(N) do dim + if dim <= length(sharding.tile_assignment_dimensions) && + sharding.tile_assignment_dimensions[dim] > 1 + tile_seq = __get_device_sequence(tile_assignment, dim) + + # For each unused mesh axis with matching size + for (axis_idx, axis_name) in enumerate(mesh.axis_names) + if axis_idx ∉ used_axes && size(mesh_devices, axis_idx) == length(tile_seq) + mesh_seq = __get_device_sequence(mesh_devices, axis_idx) + + # Check if sequences match (allowing for reversal) + if tile_seq == mesh_seq || tile_seq == reverse(mesh_seq) + push!(used_axes, axis_idx) + return axis_name + end + end + end + end + return nothing + end + + device_to_array_indices = map(mesh.device_ids) do device_id + tile_index = findfirst(==(device_id), tile_assignment) + @assert tile_index !== nothing "Device ID $device_id not found in tile assignment $tile_assignment" + tile_start = (tile_index.I .- 1) .* tile_dimensions .+ 1 + tile_end = tile_index.I .* tile_dimensions + ntuple(i -> tile_start[i]:tile_end[i], N) + end + + return device_to_array_indices, partition_spec + else + error("Unsupported sharding type: $(sharding.type)") + end +end + +# Helper function to get device sequence along a dimension +function __get_device_sequence(arr, dim) + # Take first index for all other dimensions + idx = ones(Int, ndims(arr)) + # Get sequence along target dimension + sequence = Int[] + for i in 1:size(arr, dim) + idx[dim] = i + push!(sequence, arr[idx...]) + end + return sequence +end diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index ba6586cd60..2973398431 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -4,6 +4,7 @@ using ..Reactant: Reactant, MLIR using Reactant_jll using Libdl using Scratch, Downloads +using EnumX: @enumx const XLA_REACTANT_GPU_MEM_FRACTION = Ref{Float64}(0.75) const XLA_REACTANT_GPU_PREALLOCATE = Ref{Bool}(true) @@ -21,6 +22,7 @@ end include("Client.jl") include("Device.jl") +include("Sharding.jl") include("LoadedExecutable.jl") include("Future.jl") include("Buffer.jl") diff --git a/test/autodiff.jl b/test/autodiff.jl index 1dea727bfe..26dd8b7d50 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -18,7 +18,7 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y)) ) ) - @test typeof(res1) == Tuple{ConcreteRArray{Float64,2,1,Sharding.FinalizedNoSharding}} + @test typeof(res1) == Tuple{ConcreteRArray{Float64,2,1,Sharding.NoShardInfo}} @test res1[1] ≈ ores1[1] ores1 = fwd(ForwardWithPrimal, Duplicated, ones(3, 2), 3.1 * ones(3, 2)) @@ -35,8 +35,8 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y)) ) @test typeof(res1) == Tuple{ - ConcreteRArray{Float64,2,1,Sharding.FinalizedNoSharding}, - ConcreteRArray{Float64,2,1,Sharding.FinalizedNoSharding}, + ConcreteRArray{Float64,2,1,Sharding.NoShardInfo}, + ConcreteRArray{Float64,2,1,Sharding.NoShardInfo}, } @test res1[1] ≈ ores1[1] @test res1[2] ≈ ores1[2] @@ -62,7 +62,7 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y)) ) ) - @test typeof(res1) == Tuple{ConcreteRArray{Float64,2,1,Sharding.FinalizedNoSharding}} + @test typeof(res1) == Tuple{ConcreteRArray{Float64,2,1,Sharding.NoShardInfo}} @test res1[1] ≈ ores1[1] end @@ -75,7 +75,9 @@ end res = @test_warn r"`Adapt.parent_type` is not implemented for" @jit gw(x) # TODO we should probably override https://github.com/EnzymeAD/Enzyme.jl/blob/5e6a82dd08e74666822b9d7b2b46c36b075668ca/src/Enzyme.jl#L2132 # to make sure this gets merged as a tracedrarray - @test typeof(res) == Tuple{Enzyme.TupleArray{ConcreteRNumber{Float64},(2, 2),4,2}} + @test typeof(res) == Tuple{ + Enzyme.TupleArray{ConcreteRNumber{Float64,1,Sharding.NoShardInfo},(2, 2),4,2} + } @test res[1] ≈ ones(2, 2) end diff --git a/test/basic.jl b/test/basic.jl index d8f4424b41..18fd60c35a 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -373,7 +373,7 @@ end b = Reactant.to_rarray(_b) c = Reactant.to_rarray(_c) - # vcat test + # vcat test y = @jit vcat(a, b) @test y == vcat(a, _b) @test y isa ConcreteRArray{typeof_a,1} diff --git a/test/buffer_donation.jl b/test/buffer_donation.jl index 93bfb245fa..0ffd1b513d 100644 --- a/test/buffer_donation.jl +++ b/test/buffer_donation.jl @@ -18,7 +18,7 @@ end @jit(donate_fill_x_with_2(a, b)) @test convert(Array, a) == 2 * ones(2, 2) - (; preserved_args) = Reactant.Compiler.compile_xla(donate_fill_x_with_2, (a, b))[2] + (; preserved_args) = Reactant.Compiler.compile_xla(donate_fill_x_with_2, (a, b))[3] preserved_args_idx = last.(preserved_args) @test preserved_args_idx == [1] # only `y`(i.e. `b`) is preserved @@ -27,7 +27,7 @@ end @jit(donate_inplace_mul(a, b)) @test convert(Array, a) == 6 * ones(2, 2) - (; preserved_args) = Reactant.Compiler.compile_xla(donate_inplace_mul, (a, b))[2] + (; preserved_args) = Reactant.Compiler.compile_xla(donate_inplace_mul, (a, b))[3] preserved_args_idx = last.(preserved_args) @test preserved_args_idx == [1] # only `y`(i.e. `b`) is preserved end diff --git a/test/compile.jl b/test/compile.jl index 40711d62b6..b8dda28bd2 100644 --- a/test/compile.jl +++ b/test/compile.jl @@ -10,7 +10,8 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a= x2 = Reactant.to_rarray(x) res = @jit sum(x2) - @test res isa @NamedTuple{a::Reactant.ConcreteRNumber{Float64}} + @test res isa + @NamedTuple{a::Reactant.ConcreteRNumber{Float64,1,Sharding.NoShardInfo}} @test isapprox(res.a, sum(x.a)) end diff --git a/test/integration/python.jl b/test/integration/python.jl index 128f11950c..ecc5f76d39 100644 --- a/test/integration/python.jl +++ b/test/integration/python.jl @@ -11,7 +11,7 @@ using Test jax = pyimport("jax") result = @jit jax.numpy.sum(Reactant.to_rarray(Float32[1, 2, 3])) - @test typeof(result) == ConcreteRNumber{Float32} + @test typeof(result) == ConcreteRNumber{Float32,1,Sharding.NoShardInfo} @test result ≈ 6 end end diff --git a/test/sharding.jl b/test/sharding.jl index 711927f07a..1afd03dec1 100644 --- a/test/sharding.jl +++ b/test/sharding.jl @@ -13,29 +13,37 @@ end @testset "Sharding Across 2 Devices" begin if length(addressable_devices) ≥ 2 mesh = Sharding.Mesh([0 1;], ("x", "y")) + fake_run = false + else + @warn "Not enough addressable devices to run sharding tests; we are running a \ + pretend test for testing purposes" + mesh = Sharding.Mesh(reshape([0], 1, 1), ("x", "y")) + fake_run = true + end - data_sharding = Sharding.NamedSharding(mesh, ("y", nothing, "x")) - data_sharding2 = Sharding.NamedSharding(mesh, (nothing, "x", nothing)) - data_sharding3 = Sharding.NamedSharding(mesh, (nothing, nothing, nothing)) # fully replicated data + data_sharding = Sharding.NamedSharding(mesh, ("y", nothing, "x")) + data_sharding2 = Sharding.NamedSharding(mesh, (nothing, "x", nothing)) + data_sharding3 = Sharding.NamedSharding(mesh, (nothing, nothing, nothing)) # fully replicated data - data = reshape(collect(1:(16 * 4 * 12)) ./ (16 * 4 * 12), 16, 4, 12) + data = reshape(collect(1:(16 * 4 * 12)) ./ (16 * 4 * 12), 16, 4, 12) - cdata = Reactant.to_rarray(data) - cdata_sharded = Reactant.to_rarray(data; sharding=data_sharding) - cdata_sharded2 = Reactant.to_rarray(data; sharding=data_sharding2) - cdata_sharded3 = Reactant.to_rarray(data; sharding=data_sharding3) + cdata = Reactant.to_rarray(data) + cdata_sharded = Reactant.to_rarray(data; sharding=data_sharding) + cdata_sharded2 = Reactant.to_rarray(data; sharding=data_sharding2) + cdata_sharded3 = Reactant.to_rarray(data; sharding=data_sharding3) - @test data ≈ - Array(cdata) ≈ - Array(cdata_sharded) ≈ - Array(cdata_sharded2) ≈ - Array(cdata_sharded3) + @test data ≈ + Array(cdata) ≈ + Array(cdata_sharded) ≈ + Array(cdata_sharded2) ≈ + Array(cdata_sharded3) - @test cdata_sharded.sharding isa Sharding.FinalizedNamedSharding - @test cdata_sharded2.sharding isa Sharding.FinalizedNamedSharding - @test cdata_sharded3.sharding isa Sharding.FinalizedNamedSharding - @test cdata.sharding isa Sharding.FinalizedNoSharding + @test cdata_sharded.sharding isa Sharding.ShardInfo{<:Sharding.NamedSharding} + @test cdata_sharded2.sharding isa Sharding.ShardInfo{<:Sharding.NamedSharding} + @test cdata_sharded3.sharding isa Sharding.ShardInfo{<:Sharding.NamedSharding} + @test cdata.sharding isa Sharding.NoShardInfo + if !fake_run true_res_y, true_res_x, true_res_z = fn_test1(data) for cd in (cdata, cdata_sharded, cdata_sharded2, cdata_sharded3) @@ -45,7 +53,31 @@ end @test Array(res_z) ≈ true_res_z @test Array(res_x) ≈ true_res_x end + end +end + +predict(samples, w1, w2) = sin.(w2 * (w1 * tanh.(samples))) + +@testset "Sharding Across 8 Devices" begin + if length(addressable_devices) ≥ 8 + mesh = Sharding.Mesh(reshape(collect(Int64, 0:7), (4, 2)), ("data", "model")) + fake_run = false else - @warn "Not enough addressable devices to run sharding tests" + @warn "Not enough addressable devices to run sharding tests; we are running a \ + pretend test for testing purposes" + mesh = Sharding.Mesh(reshape([0], 1, 1), ("data", "model")) + fake_run = true + end + + samples_sharding = Sharding.NamedSharding(mesh, (nothing, "data")) + w1_sharding = Sharding.NamedSharding(mesh, ("model", nothing)) + + samples = ConcreteRArray(rand(Float32, 3, 12); sharding=samples_sharding) + w1 = ConcreteRArray(rand(Float32, 4, 3); sharding=w1_sharding) + w2 = ConcreteRArray(rand(Float32, 2, 4)) + + if !fake_run + @test Array(@jit(predict(samples, w1, w2))) ≈ + predict(Array(samples), Array(w1), Array(w2)) end end diff --git a/test/struct.jl b/test/struct.jl index 398696309a..e536a74fae 100644 --- a/test/struct.jl +++ b/test/struct.jl @@ -79,7 +79,7 @@ end y = @jit(bcast_cos(x2)) @test y isa MockTensor{ - Float64,2,Reactant.ConcreteRArray{Float64,2,1,Sharding.FinalizedNoSharding} + Float64,2,Reactant.ConcreteRArray{Float64,2,1,Sharding.NoShardInfo} } @test size(y) == (4, 4) @test isapprox(parent(y), bcast_cos(parent(x))) @@ -93,7 +93,7 @@ end y = @jit(bcast_cos(x2)) @test y isa MutableMockTensor{ - Float64,2,Reactant.ConcreteRArray{Float64,2,1,Sharding.FinalizedNoSharding} + Float64,2,Reactant.ConcreteRArray{Float64,2,1,Sharding.NoShardInfo} } @test size(y) == (4, 4) @test isapprox(parent(y), bcast_cos(parent(x))) diff --git a/test/tracing.jl b/test/tracing.jl index 6faee5b4dd..9bf81daa12 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -58,22 +58,22 @@ end # RArray types ( - ConcreteRArray{Float64,0,1,Sharding.FinalizedNoSharding}, + ConcreteRArray{Float64,0,1,Sharding.NoShardInfo}, TracedRArray{Float64,0}, TracedRArray{Float64,0}, ), ( - ConcreteRArray{Float64,1,1,Sharding.FinalizedNoSharding}, + ConcreteRArray{Float64,1,1,Sharding.NoShardInfo}, TracedRArray{Float64,1}, TracedRArray{Float64,1}, ), ( - ConcreteRArray{Float64,2,1,Sharding.FinalizedNoSharding}, + ConcreteRArray{Float64,2,1,Sharding.NoShardInfo}, TracedRArray{Float64,2}, TracedRArray{Float64,2}, ), ( - ConcreteRArray{Float64,3,1,Sharding.FinalizedNoSharding}, + ConcreteRArray{Float64,3,1,Sharding.NoShardInfo}, TracedRArray{Float64,3}, TracedRArray{Float64,3}, ), @@ -81,7 +81,7 @@ end # Array types (Array{Float64,1}, Array{Float64,1}, Array{TracedRNumber{Float64},1}), ( - Array{ConcreteRArray{Float64,2,1,Sharding.FinalizedNoSharding},1}, + Array{ConcreteRArray{Float64,2,1,Sharding.NoShardInfo},1}, Array{TracedRArray{Float64,2},1}, Array{TracedRArray{Float64,2},1}, ), @@ -89,7 +89,7 @@ end # Union types (Union{Nothing,Int}, Union{Nothing,Int}, Union{Nothing,TracedRNumber{Int}}), ( - Union{Nothing,ConcreteRArray{Float64,1,1,Sharding.FinalizedNoSharding}}, + Union{Nothing,ConcreteRArray{Float64,1,1,Sharding.NoShardInfo}}, Union{Nothing,TracedRArray{Float64,1}}, Union{Nothing,TracedRArray{Float64,1}}, ), @@ -97,7 +97,7 @@ end # Ptr types (Ptr{Float64}, Ptr{Float64}, Ptr{TracedRNumber{Float64}}), ( - Ptr{ConcreteRArray{Float64,1,1,Sharding.FinalizedNoSharding}}, + Ptr{ConcreteRArray{Float64,1,1,Sharding.NoShardInfo}}, Ptr{TracedRArray{Float64,1}}, Ptr{TracedRArray{Float64,1}}, ), @@ -107,7 +107,7 @@ end Core.LLVMPtr{TracedRNumber{Float64}}, ), ( - Core.LLVMPtr{ConcreteRArray{Float64,1,1,Sharding.FinalizedNoSharding}}, + Core.LLVMPtr{ConcreteRArray{Float64,1,1,Sharding.NoShardInfo}}, Core.LLVMPtr{TracedRArray{Float64,1}}, Core.LLVMPtr{TracedRArray{Float64,1}}, ), @@ -117,7 +117,7 @@ end Base.RefValue{TracedRNumber{Float64}}, ), ( - Base.RefValue{ConcreteRArray{Float64,1,1,Sharding.FinalizedNoSharding}}, + Base.RefValue{ConcreteRArray{Float64,1,1,Sharding.NoShardInfo}}, Base.RefValue{TracedRArray{Float64,1}}, Base.RefValue{TracedRArray{Float64,1}}, ), @@ -127,18 +127,14 @@ end (Val{0.5}, Val{0.5}, Val{0.5}), (Val{:x}, Val{:x}, Val{:x}), ( - Dict{Int,ConcreteRArray{Float64,0,1,Sharding.FinalizedNoSharding}}, + Dict{Int,ConcreteRArray{Float64,0,1,Sharding.NoShardInfo}}, Dict{Int,TracedRArray{Float64,0}}, Dict{Int,TracedRArray{Float64,0}}, ), (Dict{Int}, Dict{Int}, Dict{Int}), (Dict, Dict, Dict), ( - ( - Dict{ - A,ConcreteRArray{Float64,0,1,Sharding.FinalizedNoSharding} - } where {A} - ), + (Dict{A,ConcreteRArray{Float64,0,1,Sharding.NoShardInfo}} where {A}), (Dict{A,TracedRArray{Float64,0}} where {A}), (Dict{A,TracedRArray{Float64,0}} where {A}), ), @@ -181,16 +177,14 @@ end Wrapper{TracedRNumber{Float64},Vector{Float64}}, ), ( - Wrapper{ - Float64,ConcreteRArray{Float64,1,1,Sharding.FinalizedNoSharding} - }, + Wrapper{Float64,ConcreteRArray{Float64,1,1,Sharding.NoShardInfo}}, Wrapper{Float64,TracedRArray{Float64,1}}, Wrapper{TracedRNumber{Float64},TracedRArray{Float64,1}}, ), (Wrapper{Symbol}, Wrapper{Symbol}, Wrapper{Symbol}), (Wrapper{Float64}, Wrapper{Float64}, Wrapper{TracedRNumber{Float64}}), ( - Wrapper{ConcreteRArray{Float64,1,1,Sharding.FinalizedNoSharding}}, + Wrapper{ConcreteRArray{Float64,1,1,Sharding.NoShardInfo}}, Wrapper{TracedRArray{Float64,1}}, Wrapper{TracedRArray{Float64,1}}, ),