diff --git a/ext/ReactantOffsetArraysExt.jl b/ext/ReactantOffsetArraysExt.jl index d71bf72f38..cbb691124c 100644 --- a/ext/ReactantOffsetArraysExt.jl +++ b/ext/ReactantOffsetArraysExt.jl @@ -63,4 +63,12 @@ function Base.getindex(a::OffsetVector{<:TracedRNumber}, indices::Int) return parent(a)[J] end +function Reactant.TracedUtils.get_ancestor_and_indices_inner( + x::OffsetArray{<:TracedRNumber,N}, indices::Vararg{Any,N} +) where {N} + return Reactant.TracedUtils.get_ancestor_and_indices( + parent(x), map(parentindex, axes(x), indices)... + ) +end + end diff --git a/src/Indexing.jl b/src/Indexing.jl index 7ff156f694..16c4484f11 100644 --- a/src/Indexing.jl +++ b/src/Indexing.jl @@ -83,7 +83,8 @@ end function Base.getindex( a::AnyTracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N} ) where {T,N} - return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, index...)...) + ancestor, idxs = TracedUtils.get_ancestor_and_indices(a, index...) + return getindex(ancestor, idxs...) end function Base.getindex( @@ -93,15 +94,18 @@ function Base.getindex( end function Base.getindex(a::AnyTracedRArray{T,N}, linear_indices) where {T,N} - return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, linear_indices)...) + ancestor, idxs = TracedUtils.get_ancestor_and_indices(a, linear_indices) + return getindex(ancestor, idxs...) end function Base.getindex(a::AnyTracedRArray{T,1}, indices) where {T} - return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices)...) + ancestor, idxs = TracedUtils.get_ancestor_and_indices(a, indices) + return getindex(ancestor, idxs...) end function Base.getindex(a::AnyTracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} - return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...) + ancestor, idxs = TracedUtils.get_ancestor_and_indices(a, indices...) + return getindex(ancestor, idxs...) end ### Specialize certain dispatches for better codegen diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 260e81e714..a61dfa289e 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -100,15 +100,15 @@ function set_mlir_data!(x::Base.ReshapedArray{TracedRNumber{T}}, data) where {T} return x end -function get_ancestor_indices( +function get_ancestor_and_indices( x::Base.ReshapedArray{TracedRNumber{T},N}, indices::Vector{CartesianIndex{N}} ) where {T,N} linear_indices = LinearIndices(size(x))[indices] parent_linear_indices = LinearIndices(size(parent(x)))[linear_indices] - return (parent_linear_indices,) + return (parent(x), (parent_linear_indices,)) end -function get_ancestor_indices( +function get_ancestor_and_indices( x::Base.ReshapedArray{TracedRNumber{T},N}, indices... ) where {T,N} @assert length(indices) == N "Expected $N indices, got $(length(indices))" @@ -134,13 +134,13 @@ function get_ancestor_indices( ) ) parent_linear_indices = @opcall reshape(parent_linear_indices, result_size) - return (parent_linear_indices,) + return (parent(x), (parent_linear_indices,)) else # Have this as a separate code-path since we can generate non-dynamic indexing cartesian_indices = CartesianIndex.(Iterators.product(indices...)) linear_indices = LinearIndices(size(x))[cartesian_indices] parent_linear_indices = LinearIndices(size(parent(x)))[linear_indices] - return (parent_linear_indices,) + return (parent(x), (parent_linear_indices,)) end end @@ -152,53 +152,55 @@ function set_mlir_data!( end function set_mlir_data!(x::AnyTracedRArray{T}, data) where {T} - ancestor_indices = get_ancestor_indices(x, axes(x)...) + ancestor, ancestor_indices = get_ancestor_and_indices(x, axes(x)...) setindex!(Reactant.ancestor(x), TracedRArray{T}(data), ancestor_indices...) return x end -get_ancestor_indices(::TracedRArray, indices) = indices -get_ancestor_indices(::TracedRArray, indices, args...) = (indices, args...) +get_ancestor_and_indices(a::TracedRArray, indices) = (a, indices) +get_ancestor_and_indices(a::TracedRArray, indices, args...) = (a, (indices, args...)) -get_ancestor_indices(::Array{<:TracedRNumber}, indices...) = indices -get_ancestor_indices(::Array{<:TracedRNumber}, indices, args...) = (indices, args...) +get_ancestor_and_indices(a::Array{<:TracedRNumber}, indices...) = (a, indices) +function get_ancestor_and_indices(a::Array{<:TracedRNumber}, indices, args...) + return (a, (indices, args...)) +end -function get_ancestor_indices(x::AnyTracedRArray, indices...) - return get_ancestor_indices_inner(x, indices...) # redirect to avoid ambiguity +function get_ancestor_and_indices(x::AnyTracedRArray, indices...) + return get_ancestor_and_indices_inner(x, indices...) # redirect to avoid ambiguity end -function get_ancestor_indices(x::AnyTracedRArray, indices, args...) - return get_ancestor_indices_inner(x, indices, args...) # redirect to avoid ambiguity +function get_ancestor_and_indices(x::AnyTracedRArray, indices, args...) + return get_ancestor_and_indices_inner(x, indices, args...) # redirect to avoid ambiguity end -function get_ancestor_indices_inner( +function get_ancestor_and_indices_inner( x::AnyTracedRArray{T,N}, indices::Vararg{Any,N} ) where {T,N} - return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...) + return get_ancestor_and_indices(parent(x), Base.reindex(parentindices(x), indices)...) end -function get_ancestor_indices_inner(x::AnyTracedRArray{T,1}, indices) where {T} - return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)) +function get_ancestor_and_indices_inner(x::AnyTracedRArray{T,1}, indices) where {T} + return get_ancestor_and_indices(parent(x), Base.reindex(parentindices(x), indices)) end -function get_ancestor_indices_inner( +function get_ancestor_and_indices_inner( x::AnyTracedRArray{T,N}, linear_indices::AbstractArray ) where {T,N} - idxs = _get_ancestor_indices_linear(x, linear_indices) - return idxs isa Tuple ? idxs : (idxs,) + a, idxs = _get_ancestor_and_indices_linear(x, linear_indices) + return a, (idxs isa Tuple ? idxs : (idxs,)) end -function get_ancestor_indices_inner( +function get_ancestor_and_indices_inner( x::AnyTracedRArray{T,1}, linear_indices::AbstractArray ) where {T} - idxs = _get_ancestor_indices_linear(x, linear_indices) - return idxs isa Tuple ? idxs : (idxs,) + a, idxs = _get_ancestor_and_indices_linear(x, linear_indices) + return a, (idxs isa Tuple ? idxs : (idxs,)) end -function _get_ancestor_indices_linear(x::AnyTracedRArray, indices::AbstractArray) +function _get_ancestor_and_indices_linear(x::AnyTracedRArray, indices::AbstractArray) indices = CartesianIndices(x)[indices] pidxs = parentindices(x) parent_indices = map(indices) do idx CartesianIndex(Base.reindex(pidxs, (idx.I...,))) end - return get_ancestor_indices(parent(x), parent_indices) + return get_ancestor_and_indices(parent(x), parent_indices) end Base.@nospecializeinfer function batch_ty( diff --git a/test/integration/offsetarrays.jl b/test/integration/offsetarrays.jl index a8bb4f899e..bf6d533ad1 100644 --- a/test/integration/offsetarrays.jl +++ b/test/integration/offsetarrays.jl @@ -17,3 +17,14 @@ end tval = @jit scalar_index(rOA) @test tval ≈ oval end + +@testset "OffsetArray View" begin + U = zeros(Float64, 128, 128, 1) + vU = OffsetArray(U, -7:120, -7:120, 1:1) + rU = Reactant.to_rarray(vU) + + @jit fill!(@view(rU[1:112, 1:112, 1]), 1.0) + fill!(@view(vU[1:112, 1:112, 1]), 1.0) + + @test parent(rU) ≈ parent(vU) +end