From fe20d5c3a6b85bbf85b2a672ea1085a54a99740f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 25 Jan 2025 22:11:33 -0500 Subject: [PATCH 1/2] feat: support more set indexing --- src/Ops.jl | 6 +++-- src/TracedRArray.jl | 55 +++++++++++++++++++++++++++++++++++++++++++++ test/indexing.jl | 44 ++++++++++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+), 2 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 03f4789851..4cd41c3baf 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1479,11 +1479,13 @@ instead. @noinline function scatter_setindex( dest::TracedRArray{T,N}, scatter_indices::TracedRArray{Int64,2}, - updates::TracedRArray{T,1}, -) where {T,N} + updates::TracedRArray{T2,1}, +) where {T,N,T2} @assert length(updates) == size(scatter_indices, 1) @assert size(scatter_indices, 2) == N + updates = convert(TracedRArray{T,1}, updates) + update_computation = MLIR.IR.Region() block = MLIR.IR.Block( [mlir_type(TracedRNumber{T}), mlir_type(TracedRNumber{T})], diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 358d0c8642..de8f5f5564 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -216,6 +216,61 @@ end maybe_assert_scalar_setindexing(args...) = nothing +function Base.setindex!( + a::TracedRArray{T,N}, v, indices::Union{Int,TracedRNumber{Int}} +) where {T,N} + GPUArraysCore.assertscalar( + "setindex!(::TracedRArray, v, ::Union{Int, TracedRNumber{Int}})" + ) + if indices isa Int + indices = TracedUtils.promote_to(TracedRNumber{Int}, indices) + end + indices = scalar_index_to_cartesian( + TracedUtils.broadcast_to_size(indices, (1,)), size(a) + ) + v = v isa Number ? v : vec(v) + res = Ops.scatter_setindex(a, indices, TracedUtils.broadcast_to_size(v, (1,))) + set_mlir_data!(a, get_mlir_data(res)) + return a +end + +function Base.setindex!(a::TracedRArray{T,N}, v, indices) where {T,N} + if !(indices isa TracedRArray) + indices = collect(indices) + eltype(indices) <: CartesianIndex && (indices = LinearIndices(size(a))[indices]) + indices = TracedUtils.promote_to(TracedRArray{Int,ndims(indices)}, indices) + end + v = v isa Number ? v : vec(v) + v = materialize_traced_array(TracedUtils.broadcast_to_size(v, size(a))) + res = Ops.scatter_setindex(a, scalar_index_to_cartesian(vec(indices), size(a)), v) + set_mlir_data!(a, get_mlir_data(res)) + return a +end + +function Base.setindex!(a::TracedRArray{T,N}, v, ::Colon) where {T,N} + v = TracedUtils.broadcast_to_size(v, size(a)) + set_mlir_data!(a, get_mlir_data(v)) + return a +end + +function Base.setindex!(a::TracedRArray{T,N}, v, indices::CartesianIndex{N}) where {T,N} + GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::CartesianIndex{N})") + indices = + materialize_traced_array( + reshape( + TracedUtils.promote_to( + TracedRArray{Int,1}, collect(Int64, vcat(Tuple(indices)...)) + ), + 1, + N, + ), + ) .- 1 + v = v isa Number ? v : vec(v) + res = Ops.scatter_setindex(a, indices, TracedUtils.broadcast_to_size(v, (1,))) + set_mlir_data!(a, get_mlir_data(res)) + return a +end + function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N} maybe_assert_scalar_setindexing(a, indices...) diff --git a/test/indexing.jl b/test/indexing.jl index ce4f84ae7f..9a34145049 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -233,3 +233,47 @@ end @test @jit(fn(x_ra, idx1_ra, idx2_ra, idx3)) ≈ fn(Array(x_ra), Array(idx1_ra), Array(idx2_ra), idx3) end + +function issue_617(outf, fr, pr, I) + tmp = fr .* reshape(pr, size(fr)) + outv = @view outf[I] + vtmp = vec(tmp) + outv .= vtmp + return outf +end + +@testset "issue #617" begin + N, M = 4, 6 + + f = rand(ComplexF64, N, N) + p = rand(ComplexF64, N * N) + I = 1:(N^2) + out = rand(ComplexF64, M, M) + + fr = Reactant.to_rarray(f) + pr = Reactant.to_rarray(p) + outr = Reactant.to_rarray(out) + Ir = Reactant.to_rarray(I) + + @test @jit(issue_617(outr, fr, pr, Ir)) ≈ issue_617(out, f, p, I) +end + +function scalar_setindex(x, idx, val) + @allowscalar x[idx] = val + return x +end + +@testset "scalar setindex" begin + x = zeros(4, 4) + x_ra = Reactant.to_rarray(x) + + @test @jit(scalar_setindex(x_ra, 1, 1)) ≈ scalar_setindex(x, 1, 1) + @test @allowscalar x_ra[1] == 1 + + x = zeros(4, 4) + x_ra = Reactant.to_rarray(x) + + @test @jit(scalar_setindex(x_ra, ConcreteRNumber(1), 1)) ≈ scalar_setindex(x, 1, 1) + @test @allowscalar x_ra[1] == 1 +end + From ca54c571fcb95c85554342682b4ac883cc7c751e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 25 Jan 2025 23:54:27 -0500 Subject: [PATCH 2/2] fix: tests --- src/TracedRArray.jl | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index de8f5f5564..f78b0ba8f9 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -234,15 +234,36 @@ function Base.setindex!( return a end +# Avoid ambiguity +function Base.setindex!( + a::TracedRArray{T,1}, v, indices::Union{Int,TracedRNumber{Int}} +) where {T} + GPUArraysCore.assertscalar( + "setindex!(::TracedRArray, v, ::Union{Int, TracedRNumber{Int}})" + ) + if indices isa Int + indices = TracedUtils.promote_to(TracedRNumber{Int}, indices) + end + indices = scalar_index_to_cartesian( + TracedUtils.broadcast_to_size(indices, (1,)), size(a) + ) + v = v isa Number ? v : vec(v) + res = Ops.scatter_setindex(a, indices, TracedUtils.broadcast_to_size(v, (1,))) + set_mlir_data!(a, get_mlir_data(res)) + return a +end + function Base.setindex!(a::TracedRArray{T,N}, v, indices) where {T,N} if !(indices isa TracedRArray) indices = collect(indices) eltype(indices) <: CartesianIndex && (indices = LinearIndices(size(a))[indices]) indices = TracedUtils.promote_to(TracedRArray{Int,ndims(indices)}, indices) end - v = v isa Number ? v : vec(v) - v = materialize_traced_array(TracedUtils.broadcast_to_size(v, size(a))) - res = Ops.scatter_setindex(a, scalar_index_to_cartesian(vec(indices), size(a)), v) + res = Ops.scatter_setindex( + a, + scalar_index_to_cartesian(vec(indices), size(a)), + materialize_traced_array(vec(v)), + ) set_mlir_data!(a, get_mlir_data(res)) return a end @@ -299,7 +320,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where { indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices) indices_list = generate_index_list(indices_list...) res = Ops.scatter_setindex(a, indices_list, Ops.reshape(v, length(v))) - a.mlir_data = res.mlir_data + set_mlir_data!(a, get_mlir_data(res)) return v end @@ -330,7 +351,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where { ), 1, ) - a.mlir_data = res + set_mlir_data!(a, res) return v end