diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 7896325a5c..c896a8a415 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -154,12 +154,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::CartesianIndex{N}) where { end function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} - indices = map(enumerate(indices)) do (idx, i) - i isa Colon && return 1:size(a, idx) - i isa CartesianIndex && return Tuple(i) - i isa AbstractArray{<:Bool} && return findall(i) - return i - end + indices = TracedUtils.normalize_indices(a, indices...) use_gather_getindex = false for idxs in indices @@ -168,7 +163,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} use_gather_getindex = true break end - contiguous = all(isone, diff(idxs)) + contiguous = all(isone, diff(vec(idxs))) if typeof(contiguous) <: Bool && !contiguous use_gather_getindex = true break @@ -181,19 +176,11 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} if any(i -> unwrapped_eltype(i) <: Bool, indices) error("Boolean indexing with TracedRArrays isn't fully supported yet.") end - idxs = map(indices) do i - i isa Number && return fill(i, 1) - return i - end - indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), idxs) - indices_list = generate_index_list(indices_list...) - res = Ops.gather_getindex(a, indices_list) - res = Ops.reshape(res, length.(idxs)...) - ddims = findall(indices) do idx - return idx isa Integer || idx isa TracedRNumber{<:Integer} - end - isempty(ddims) || return materialize_traced_array(dropdims(res; dims=Tuple(ddims))) - return res + indices, integer_indices, result_size, _ = TracedUtils.traced_indices(indices...) + res = Ops.gather_getindex(a, generate_index_list(indices...)) + isempty(integer_indices) || + (res = materialize_traced_array(dropdims(res; dims=integer_indices))) + return Ops.reshape(res, result_size) end start_indices = map(indices) do i @@ -233,12 +220,7 @@ maybe_assert_scalar_setindexing(args...) = nothing function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N} maybe_assert_scalar_setindexing(a, indices...) - indices = map(enumerate(indices)) do (idx, i) - i isa Colon && return 1:size(a, idx) - i isa CartesianIndex && return Tuple(i) - i isa AbstractArray{<:Bool} && return findall(i) - return i - end + indices = TracedUtils.normalize_indices(a, indices...) use_scatter_setindex = false for idxs in indices diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index e42e2d2875..d114f7c54f 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -64,30 +64,26 @@ function get_ancestor_indices( x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}}, indices... ) where {T,N,M} @assert length(indices) == N "Expected $N indices, got $(length(indices))" + indices = normalize_indices(x, indices...) if any(is_traced, indices) - final_size = Vector{Int64}(undef, N) - ddims = Int64[] - for (i, idx) in enumerate(indices) - @assert ndims(idx) == 1 || ndims(idx) == 0 "Unsupported feature. Please file an issue." - ndims(idx) == 0 && push!(ddims, i) - final_size[i] = length(idx) - end + indices, integer_indices, result_size, flattened_size = traced_indices(indices...) linear_indices = mapreduce(+, enumerate(indices)) do (i, idx) bcasted_idxs = Ops.broadcast_in_dim( - idx, ndims(idx) == 0 ? Int64[] : Int64[i], final_size + idx, ndims(idx) == 0 ? Int64[] : Int64[i], flattened_size ) Base.stride(x, i) .* (bcasted_idxs .- 1) end linear_indices = linear_indices .+ 1 parent_linear_indices_all = collect(LinearIndices(size(parent(x)))) - parent_linear_indices = TracedUtils.promote_to( + parent_linear_indices = promote_to( TracedRArray{Int64,ndims(parent_linear_indices_all)}, parent_linear_indices_all )[linear_indices] - isempty(ddims) || ( + isempty(integer_indices) || ( parent_linear_indices = materialize_traced_array( - dropdims(parent_linear_indices; dims=Tuple(ddims)) + dropdims(parent_linear_indices; dims=integer_indices) ) ) + parent_linear_indices = Ops.reshape(parent_linear_indices, result_size) return (parent_linear_indices,) else # Have this as a separate code-path since we can generate non-dynamic indexing @@ -106,7 +102,7 @@ function set_mlir_data!( end function set_mlir_data!(x::AnyTracedRArray{T}, data) where {T} - ancestor_indices = TracedUtils.get_ancestor_indices(x, axes(x)...) + ancestor_indices = get_ancestor_indices(x, axes(x)...) setindex!(Reactant.ancestor(x), TracedRArray{T}(data), ancestor_indices...) return x end @@ -317,7 +313,7 @@ elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x struct TypeCast{T<:ReactantPrimitive} <: Function end function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} - return TracedUtils.promote_to(TracedRNumber{T}, x) + return promote_to(TracedRNumber{T}, x) end function elem_apply(::Type{T}, x::TracedRArray) where {T<:ReactantPrimitive} @@ -434,7 +430,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} batch_inputs = MLIR.IR.Value[] for a in linear_args - idx, path = TracedUtils.get_argidx(a) + idx, path = get_argidx(a) if idx == 1 && fnwrap push_val!(batch_inputs, f, path[3:end]) else @@ -455,20 +451,20 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} residx = 1 for a in linear_results - if TracedUtils.has_residx(a) - path = TracedUtils.get_residx(a) - TracedUtils.set!(result, path[2:end], MLIR.IR.result(res, residx)) + if has_residx(a) + path = get_residx(a) + set!(result, path[2:end], MLIR.IR.result(res, residx)) residx += 1 else - idx, path = TracedUtils.get_argidx(a) + idx, path = get_argidx(a) if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], MLIR.IR.result(res, residx)) + set!(f, path[3:end], MLIR.IR.result(res, residx)) residx += 1 else if fnwrap idx -= 1 end - TracedUtils.set!(args[idx], path[3:end], MLIR.IR.result(res, residx)) + set!(args[idx], path[3:end], MLIR.IR.result(res, residx)) residx += 1 end end @@ -523,4 +519,30 @@ end return Ops.broadcast_in_dim(x, collect(Int64, 1:ndims(x)), collect(Int64, rsize)) end +function normalize_indices(a::AbstractArray, indices...) + return map(enumerate(indices)) do (i, idx) + idx isa Colon && return collect(Int64, 1:size(a, i)) + idx isa CartesianIndex && return Tuple(idx) + idx isa AbstractArray{Bool} && return findall(idx) + return idx + end +end + +function traced_indices(indices...) + integer_indices = Int64[] + result_size = Int64[] + flattened_size = Int64[length(idx) for idx in indices] + new_indices = map(enumerate(indices)) do (i, idx) + if idx isa Number + push!(integer_indices, i) + idx isa TracedRNumber && return idx + return promote_to(TracedRNumber{Int}, idx) + end + append!(result_size, [size(idx)...]) + idx isa TracedRArray && return materialize_traced_array(vec(idx)) + return promote_to(TracedRArray{Int,1}, vec(idx)) + end + return new_indices, Tuple(integer_indices), result_size, flattened_size +end + end diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 7b2c5bbf5d..19e94b2054 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -6,6 +6,7 @@ using ..Reactant: AnyTracedRArray, AnyTracedRMatrix, AnyTracedRVector, + AnyTracedRVecOrMat, unwrapped_eltype, Ops, MLIR @@ -347,4 +348,55 @@ function LinearAlgebra.ldiv!( return B end +# Kronecker Product +function LinearAlgebra.kron( + x::AnyTracedRVecOrMat{T1}, y::AnyTracedRVecOrMat{T2} +) where {T1,T2} + x = materialize_traced_array(x) + y = materialize_traced_array(y) + z = similar(x, Base.promote_op(*, T1, T2), LinearAlgebra._kronsize(x, y)) + LinearAlgebra.kron!(z, x, y) + return z +end + +function LinearAlgebra.kron(x::AnyTracedRVector{T1}, y::AnyTracedRVector{T2}) where {T1,T2} + x = materialize_traced_array(x) + y = materialize_traced_array(y) + z = similar(x, Base.promote_op(*, T1, T2), length(x) * length(y)) + LinearAlgebra.kron!(z, x, y) + return z +end + +function LinearAlgebra.kron!(C::AnyTracedRVector, A::AnyTracedRVector, B::AnyTracedRVector) + LinearAlgebra.kron!( + reshape(C, length(B), length(A)), reshape(A, 1, length(A)), reshape(B, length(B), 1) + ) + return C +end + +function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTracedRMatrix) + A = materialize_traced_array(A) + B = materialize_traced_array(B) + + final_shape = Int64[size(B, 1), size(A, 1), size(B, 2), size(A, 2)] + + A = Ops.broadcast_in_dim(A, Int64[2, 4], final_shape) + B = Ops.broadcast_in_dim(B, Int64[1, 3], final_shape) + + C_tmp = Ops.reshape(Ops.multiply(A, B), size(C)...) + set_mlir_data!(C, get_mlir_data(C_tmp)) + + return C +end + +function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRVector, B::AnyTracedRMatrix) + LinearAlgebra._kron!(C, reshape(A, length(A), 1), B) + return C +end + +function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTracedRVector) + LinearAlgebra._kron!(C, A, reshape(B, length(B), 1)) + return C +end + end diff --git a/test/basic.jl b/test/basic.jl index c3952549a6..c0b8363450 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -419,37 +419,6 @@ end end end -function update_on_copy(x) - y = x[1:2, 2:4, :] - y[1:1, 1:1, :] = ones(1, 1, 3) - return y -end - -@testset "view / setindex" begin - x = rand(2, 4, 3) - y = copy(x) - x_concrete = Reactant.to_rarray(x) - y_concrete = Reactant.to_rarray(y) - - y1 = update_on_copy(x) - y2 = @jit update_on_copy(x_concrete) - @test x == y - @test x_concrete == y_concrete - @test y1 == y2 - - # function update_inplace(x) - # y = view(x, 1:2, 1:2, :) - # y[1, 1, :] .= 1 - # return y - # end - - # get_indices(x) = x[1:2, 1:2, :] - # get_view(x) = view(x, 1:2, 1:2, :) - - # get_indices_compiled = @compile get_indices(x_concrete) - # get_view_compiled = @compile get_view(x_concrete) -end - function write_with_broadcast1!(x, y) x[1, :, :] .= reshape(y, 4, 3) return x @@ -483,63 +452,6 @@ end @test res[:, 1, :] ≈ view(y, :, 1:3) end -function masking(x) - y = similar(x) - y[1:2, :] .= 0 - y[3:4, :] .= 1 - return y -end - -function masking!(x) - x[1:2, :] .= 0 - x[3:4, :] .= 1 - return x -end - -@testset "setindex! with views" begin - x = rand(4, 4) .+ 2.0 - x_ra = Reactant.to_rarray(x) - - y = masking(x) - y_ra = @jit(masking(x_ra)) - @test y ≈ y_ra - - x_ra_array = Array(x_ra) - @test !(any(iszero, x_ra_array[1, :])) - @test !(any(iszero, x_ra_array[2, :])) - @test !(any(isone, x_ra_array[3, :])) - @test !(any(isone, x_ra_array[4, :])) - - y_ra = @jit(masking!(x_ra)) - @test y ≈ y_ra - - x_ra_array = Array(x_ra) - @test @allowscalar all(iszero, x_ra_array[1, :]) - @test @allowscalar all(iszero, x_ra_array[2, :]) - @test @allowscalar all(isone, x_ra_array[3, :]) - @test @allowscalar all(isone, x_ra_array[4, :]) -end - -function non_contiguous_setindex!(x) - x[[1, 3, 2], [1, 2, 3, 4]] .= 1.0 - return x -end - -@testset "non-contiguous setindex!" begin - x = rand(6, 6) - x_ra = Reactant.to_rarray(x) - - y = @jit(non_contiguous_setindex!(x_ra)) - y = Array(y) - x_ra = Array(x_ra) - @test all(isone, y[1:3, 1:4]) - @test all(isone, x_ra[1:3, 1:4]) - @test !all(isone, y[4:end, :]) - @test !all(isone, x_ra[4:end, :]) - @test !all(isone, y[:, 5:end]) - @test !all(isone, x_ra[:, 5:end]) -end - tuple_byref(x) = (; a=(; b=x)) tuple_byref2(x) = abs2.(x), tuple_byref2(x) @@ -681,19 +593,6 @@ end end end -@testset "dynamic indexing" begin - x = randn(5, 3) - x_ra = Reactant.to_rarray(x) - - idx = [1, 2, 3] - idx_ra = Reactant.to_rarray(idx) - - fn(x, idx) = @allowscalar x[idx, :] - - y = @jit(fn(x_ra, idx_ra)) - @test y ≈ x[idx, :] -end - @testset "aos_to_soa" begin using ArrayInterface @@ -822,102 +721,6 @@ end @test res[2] isa ConcreteRNumber{Float64} end -@testset "non-contiguous indexing" begin - x = rand(4, 4, 3) - x_ra = Reactant.to_rarray(x) - - non_contiguous_indexing1(x) = x[[1, 3, 2], :, :] - non_contiguous_indexing2(x) = x[:, [1, 2, 1, 3], [1, 3]] - - @test @jit(non_contiguous_indexing1(x_ra)) ≈ non_contiguous_indexing1(x) - @test @jit(non_contiguous_indexing2(x_ra)) ≈ non_contiguous_indexing2(x) - - x = rand(4, 2) - x_ra = Reactant.to_rarray(x) - - non_contiguous_indexing3(x) = x[[1, 3, 2], :] - non_contiguous_indexing4(x) = x[:, [1, 2, 2]] - - @test @jit(non_contiguous_indexing3(x_ra)) ≈ non_contiguous_indexing3(x) - @test @jit(non_contiguous_indexing4(x_ra)) ≈ non_contiguous_indexing4(x) - - x = rand(4, 4, 3) - x_ra = Reactant.to_rarray(x) - - non_contiguous_indexing1!(x) = x[[1, 3, 2], :, :] .= 2 - non_contiguous_indexing2!(x) = x[:, [1, 2, 1, 3], [1, 3]] .= 2 - - @jit(non_contiguous_indexing1!(x_ra)) - non_contiguous_indexing1!(x) - @test x_ra ≈ x - - x = rand(4, 4, 3) - x_ra = Reactant.to_rarray(x) - - @jit(non_contiguous_indexing2!(x_ra)) - non_contiguous_indexing2!(x) - @test x_ra ≈ x - - x = rand(4, 2) - x_ra = Reactant.to_rarray(x) - - non_contiguous_indexing3!(x) = x[[1, 3, 2], :] .= 2 - non_contiguous_indexing4!(x) = x[:, [1, 2, 2]] .= 2 - - @jit(non_contiguous_indexing3!(x_ra)) - non_contiguous_indexing3!(x) - @test x_ra ≈ x - - x = rand(4, 2) - x_ra = Reactant.to_rarray(x) - - @jit(non_contiguous_indexing4!(x_ra)) - non_contiguous_indexing4!(x) - @test x_ra ≈ x -end - -@testset "indexing with traced arrays" begin - x = rand(4, 4, 3) - idx1 = [1, 3, 2] - idx3 = [1, 2, 1, 3] - - x_ra = Reactant.to_rarray(x) - idx1_ra = Reactant.to_rarray(idx1) - idx3_ra = Reactant.to_rarray(idx3) - - getindex1(x, idx1) = x[idx1, :, :] - getindex2(x, idx1) = x[:, idx1, :] - getindex3(x, idx3) = x[:, :, idx3] - getindex4(x, idx1, idx3) = x[idx1, :, idx3] - - @test @jit(getindex1(x_ra, idx1_ra)) ≈ getindex1(x, idx1) - @test @jit(getindex2(x_ra, idx1_ra)) ≈ getindex2(x, idx1) - @test @jit(getindex3(x_ra, idx3_ra)) ≈ getindex3(x, idx3) - @test @jit(getindex4(x_ra, idx1_ra, idx3_ra)) ≈ getindex4(x, idx1, idx3) -end - -@testset "linear indexing" begin - x = rand(4, 4, 3) - x_ra = Reactant.to_rarray(x) - - getindex_linear_scalar(x, idx) = @allowscalar x[idx] - - @testset for i in 1:length(x) - @test @jit(getindex_linear_scalar(x_ra, i)) ≈ getindex_linear_scalar(x, i) - @test @jit( - getindex_linear_scalar(x_ra, Reactant.to_rarray(i; track_numbers=Number)) - ) ≈ getindex_linear_scalar(x, i) - end - - idx = rand(1:length(x), 8) - idx_ra = Reactant.to_rarray(idx) - - getindex_linear_vector(x, idx) = x[idx] - - @test @jit(getindex_linear_vector(x_ra, idx_ra)) ≈ getindex_linear_vector(x, idx) - @test @jit(getindex_linear_vector(x_ra, idx)) ≈ getindex_linear_vector(x, idx) -end - @testset "stack" begin x = rand(4, 4) y = rand(4, 4) @@ -985,18 +788,6 @@ end @test @jit(s4(x, y)) isa Any end -@testset "Boolean Indexing" begin - x_ra = Reactant.to_rarray(rand(Float32, 4, 16)) - idxs_ra = Reactant.to_rarray(rand(Bool, 16)) - - fn(x, idxs) = x[:, idxs] - - @test_throws ErrorException @jit(fn(x_ra, idxs_ra)) - - res = @jit fn(x_ra, Array(idxs_ra)) - @test res ≈ fn(Array(x_ra), Array(idxs_ra)) -end - @testset "duplicate args (#226)" begin first_arg(x, y) = x x_ra = Reactant.to_rarray(rand(2, 2)) @@ -1052,25 +843,6 @@ end @test !isfinite(Reactant.to_rarray(Inf; track_numbers=Number)) end -@testset "inconsistent indexing" begin - x_ra = Reactant.to_rarray(rand(3, 4, 3)) - idx_ra = Reactant.to_rarray(1; track_numbers=Number) - - fn1(x) = x[:, :, 1] - fn2(x, idx) = x[:, :, idx] - fn3(x, idx) = x[idx, :, 1] - - @test ndims(@jit(fn1(x_ra))) == 2 - @test ndims(@jit(fn2(x_ra, idx_ra))) == 2 - @test ndims(@jit(fn3(x_ra, idx_ra))) == 1 -end - -@testset "reshaped subarray indexing" begin - fn(x) = view(x, 1:2) .+ 1 - x_ra = Reactant.to_rarray(rand(3, 4, 3)) - @test @jit(fn(x_ra)) == fn(Array(x_ra)) -end - @testset "reduce integers" begin x = rand(Bool, 100) x_ra = Reactant.to_rarray(x) diff --git a/test/indexing.jl b/test/indexing.jl new file mode 100644 index 0000000000..ce4f84ae7f --- /dev/null +++ b/test/indexing.jl @@ -0,0 +1,235 @@ +using LinearAlgebra, Reactant, Test + +function update_on_copy(x) + y = x[1:2, 2:4, :] + y[1:1, 1:1, :] = ones(1, 1, 3) + return y +end + +@testset "view / setindex" begin + x = rand(2, 4, 3) + y = copy(x) + x_concrete = Reactant.to_rarray(x) + y_concrete = Reactant.to_rarray(y) + + y1 = update_on_copy(x) + y2 = @jit update_on_copy(x_concrete) + @test x == y + @test x_concrete == y_concrete + @test y1 == y2 + + # function update_inplace(x) + # y = view(x, 1:2, 1:2, :) + # y[1, 1, :] .= 1 + # return y + # end + + # get_indices(x) = x[1:2, 1:2, :] + # get_view(x) = view(x, 1:2, 1:2, :) + + # get_indices_compiled = @compile get_indices(x_concrete) + # get_view_compiled = @compile get_view(x_concrete) +end + +function masking(x) + y = similar(x) + y[1:2, :] .= 0 + y[3:4, :] .= 1 + return y +end + +function masking!(x) + x[1:2, :] .= 0 + x[3:4, :] .= 1 + return x +end + +@testset "setindex! with views" begin + x = rand(4, 4) .+ 2.0 + x_ra = Reactant.to_rarray(x) + + y = masking(x) + y_ra = @jit(masking(x_ra)) + @test y ≈ y_ra + + x_ra_array = Array(x_ra) + @test !(any(iszero, x_ra_array[1, :])) + @test !(any(iszero, x_ra_array[2, :])) + @test !(any(isone, x_ra_array[3, :])) + @test !(any(isone, x_ra_array[4, :])) + + y_ra = @jit(masking!(x_ra)) + @test y ≈ y_ra + + x_ra_array = Array(x_ra) + @test @allowscalar all(iszero, x_ra_array[1, :]) + @test @allowscalar all(iszero, x_ra_array[2, :]) + @test @allowscalar all(isone, x_ra_array[3, :]) + @test @allowscalar all(isone, x_ra_array[4, :]) +end + +function non_contiguous_setindex!(x) + x[[1, 3, 2], [1, 2, 3, 4]] .= 1.0 + return x +end + +@testset "non-contiguous setindex!" begin + x = rand(6, 6) + x_ra = Reactant.to_rarray(x) + + y = @jit(non_contiguous_setindex!(x_ra)) + y = Array(y) + x_ra = Array(x_ra) + @test all(isone, y[1:3, 1:4]) + @test all(isone, x_ra[1:3, 1:4]) + @test !all(isone, y[4:end, :]) + @test !all(isone, x_ra[4:end, :]) + @test !all(isone, y[:, 5:end]) + @test !all(isone, x_ra[:, 5:end]) +end + +@testset "dynamic indexing" begin + x = randn(5, 3) + x_ra = Reactant.to_rarray(x) + + idx = [1, 2, 3] + idx_ra = Reactant.to_rarray(idx) + + fn(x, idx) = @allowscalar x[idx, :] + + y = @jit(fn(x_ra, idx_ra)) + @test y ≈ x[idx, :] +end + +@testset "non-contiguous indexing" begin + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing1(x) = x[[1, 3, 2], :, :] + non_contiguous_indexing2(x) = x[:, [1, 2, 1, 3], [1, 3]] + + @test @jit(non_contiguous_indexing1(x_ra)) ≈ non_contiguous_indexing1(x) + @test @jit(non_contiguous_indexing2(x_ra)) ≈ non_contiguous_indexing2(x) + + x = rand(4, 2) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing3(x) = x[[1, 3, 2], :] + non_contiguous_indexing4(x) = x[:, [1, 2, 2]] + + @test @jit(non_contiguous_indexing3(x_ra)) ≈ non_contiguous_indexing3(x) + @test @jit(non_contiguous_indexing4(x_ra)) ≈ non_contiguous_indexing4(x) + + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing1!(x) = x[[1, 3, 2], :, :] .= 2 + non_contiguous_indexing2!(x) = x[:, [1, 2, 1, 3], [1, 3]] .= 2 + + @jit(non_contiguous_indexing1!(x_ra)) + non_contiguous_indexing1!(x) + @test x_ra ≈ x + + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + @jit(non_contiguous_indexing2!(x_ra)) + non_contiguous_indexing2!(x) + @test x_ra ≈ x + + x = rand(4, 2) + x_ra = Reactant.to_rarray(x) + + non_contiguous_indexing3!(x) = x[[1, 3, 2], :] .= 2 + non_contiguous_indexing4!(x) = x[:, [1, 2, 2]] .= 2 + + @jit(non_contiguous_indexing3!(x_ra)) + non_contiguous_indexing3!(x) + @test x_ra ≈ x + + x = rand(4, 2) + x_ra = Reactant.to_rarray(x) + + @jit(non_contiguous_indexing4!(x_ra)) + non_contiguous_indexing4!(x) + @test x_ra ≈ x +end + +@testset "indexing with traced arrays" begin + x = rand(4, 4, 3) + idx1 = [1, 3, 2] + idx3 = [1, 2, 1, 3] + + x_ra = Reactant.to_rarray(x) + idx1_ra = Reactant.to_rarray(idx1) + idx3_ra = Reactant.to_rarray(idx3) + + getindex1(x, idx1) = x[idx1, :, :] + getindex2(x, idx1) = x[:, idx1, :] + getindex3(x, idx3) = x[:, :, idx3] + getindex4(x, idx1, idx3) = x[idx1, :, idx3] + + @test @jit(getindex1(x_ra, idx1_ra)) ≈ getindex1(x, idx1) + @test @jit(getindex2(x_ra, idx1_ra)) ≈ getindex2(x, idx1) + @test @jit(getindex3(x_ra, idx3_ra)) ≈ getindex3(x, idx3) + @test @jit(getindex4(x_ra, idx1_ra, idx3_ra)) ≈ getindex4(x, idx1, idx3) +end + +@testset "linear indexing" begin + x = rand(4, 4, 3) + x_ra = Reactant.to_rarray(x) + + getindex_linear_scalar(x, idx) = @allowscalar x[idx] + + @testset for i in 1:length(x) + @test @jit(getindex_linear_scalar(x_ra, i)) ≈ getindex_linear_scalar(x, i) + @test @jit( + getindex_linear_scalar(x_ra, Reactant.to_rarray(i; track_numbers=Number)) + ) ≈ getindex_linear_scalar(x, i) + end + + idx = rand(1:length(x), 8) + idx_ra = Reactant.to_rarray(idx) + + getindex_linear_vector(x, idx) = x[idx] + + @test @jit(getindex_linear_vector(x_ra, idx_ra)) ≈ getindex_linear_vector(x, idx) + @test @jit(getindex_linear_vector(x_ra, idx)) ≈ getindex_linear_vector(x, idx) +end + +@testset "Boolean Indexing" begin + x_ra = Reactant.to_rarray(rand(Float32, 4, 16)) + idxs_ra = Reactant.to_rarray(rand(Bool, 16)) + + fn(x, idxs) = x[:, idxs] + + @test_throws ErrorException @jit(fn(x_ra, idxs_ra)) + + res = @jit fn(x_ra, Array(idxs_ra)) + @test res ≈ fn(Array(x_ra), Array(idxs_ra)) +end + +@testset "inconsistent indexing" begin + x_ra = Reactant.to_rarray(rand(3, 4, 3)) + idx_ra = Reactant.to_rarray(1; track_numbers=Number) + + fn1(x) = x[:, :, 1] + fn2(x, idx) = x[:, :, idx] + fn3(x, idx) = x[idx, :, 1] + + @test ndims(@jit(fn1(x_ra))) == 2 + @test ndims(@jit(fn2(x_ra, idx_ra))) == 2 + @test ndims(@jit(fn3(x_ra, idx_ra))) == 1 +end + +@testset "High-Dimensional Array Indexing" begin + x_ra = Reactant.to_rarray(rand(5, 4, 3)) + idx1_ra = Reactant.to_rarray(rand(1:5, 2, 2, 3)) + idx2_ra = Reactant.to_rarray(rand(1:4, 2, 2, 3)) + idx3 = rand(1:3, 2, 2, 3) + + fn(x, idx1, idx2, idx3) = x[idx1, idx2, idx3] + + @test @jit(fn(x_ra, idx1_ra, idx2_ra, idx3)) ≈ + fn(Array(x_ra), Array(idx1_ra), Array(idx2_ra), idx3) +end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index ea39556f95..cd804d150e 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -169,3 +169,17 @@ mul_symmetric(x) = Symmetric(x) * x @test @jit(fn(x_ra)) ≈ fn(x) end end + +@testset "kron" begin + @testset for T in (Int64, Float64, ComplexF64) + @testset for (x_sz, y_sz) in [ + ((3, 4), (2, 5)), ((3, 4), (2,)), ((3,), (2, 5)), ((3,), (5,)), ((10,), ()) + ] + x = x_sz == () ? rand(T) : rand(T, x_sz) + y = y_sz == () ? rand(T) : rand(T, y_sz) + x_ra = Reactant.to_rarray(x; track_numbers=Number) + y_ra = Reactant.to_rarray(y; track_numbers=Number) + @test @jit(kron(x_ra, y_ra)) ≈ kron(x, y) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index be17750042..9c5e036909 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,6 +57,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") @safetestset "Control Flow" include("control_flow.jl") @safetestset "Sorting" include("sorting.jl") + @safetestset "Indexing" include("indexing.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" diff --git a/test/wrapped_arrays.jl b/test/wrapped_arrays.jl index 5069e665d0..da8ae34eb2 100644 --- a/test/wrapped_arrays.jl +++ b/test/wrapped_arrays.jl @@ -213,6 +213,11 @@ function broadcast_reshaped_array(x, idx1, idx2::Number) return y[idx1, idx2] .+ 1 end +function broadcast_reshaped_array(x, idx1) + y = reshape(x, 20, 2) + return y[idx1, :] .+ 1 +end + @testset "Broadcast reshaped array" begin x_ra = Reactant.to_rarray(rand(5, 4, 2)) idx1_ra = Reactant.to_rarray(rand(1:20, 4)) @@ -227,4 +232,14 @@ end @test broadcast_reshaped_array(Array(x_ra), Array(idx1_ra), Int64(idx3)) ≈ @jit(broadcast_reshaped_array(x_ra, idx1_ra, idx3)) ≈ @jit(broadcast_reshaped_array(x_ra, Array(idx1_ra), Int64(idx3))) + + @test broadcast_reshaped_array(Array(x_ra), Array(idx1_ra)) ≈ + @jit(broadcast_reshaped_array(x_ra, idx1_ra)) ≈ + @jit(broadcast_reshaped_array(x_ra, Array(idx1_ra))) +end + +@testset "reshaped subarray indexing" begin + fn(x) = view(x, 1:2) .+ 1 + x_ra = Reactant.to_rarray(rand(3, 4, 3)) + @test @jit(fn(x_ra)) == fn(Array(x_ra)) end