Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions ext/ReactantOffsetArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,12 @@ function Base.getindex(a::OffsetVector{<:TracedRNumber}, indices::Int)
return parent(a)[J]
end

function Reactant.TracedUtils.get_ancestor_and_indices_inner(
x::OffsetArray{<:TracedRNumber,N}, indices::Vararg{Any,N}
) where {N}
return Reactant.TracedUtils.get_ancestor_and_indices(
parent(x), map(parentindex, axes(x), indices)...
)
end

end
12 changes: 8 additions & 4 deletions src/Indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ end
function Base.getindex(
a::AnyTracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N}
) where {T,N}
return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, index...)...)
ancestor, idxs = TracedUtils.get_ancestor_and_indices(a, index...)
return getindex(ancestor, idxs...)
end

function Base.getindex(
Expand All @@ -93,15 +94,18 @@ function Base.getindex(
end

function Base.getindex(a::AnyTracedRArray{T,N}, linear_indices) where {T,N}
return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, linear_indices)...)
ancestor, idxs = TracedUtils.get_ancestor_and_indices(a, linear_indices)
return getindex(ancestor, idxs...)
end

function Base.getindex(a::AnyTracedRArray{T,1}, indices) where {T}
return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices)...)
ancestor, idxs = TracedUtils.get_ancestor_and_indices(a, indices)
return getindex(ancestor, idxs...)
end

function Base.getindex(a::AnyTracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...)
ancestor, idxs = TracedUtils.get_ancestor_and_indices(a, indices...)
return getindex(ancestor, idxs...)
end

### Specialize certain dispatches for better codegen
Expand Down
54 changes: 28 additions & 26 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,15 @@ function set_mlir_data!(x::Base.ReshapedArray{TracedRNumber{T}}, data) where {T}
return x
end

function get_ancestor_indices(
function get_ancestor_and_indices(
x::Base.ReshapedArray{TracedRNumber{T},N}, indices::Vector{CartesianIndex{N}}
) where {T,N}
linear_indices = LinearIndices(size(x))[indices]
parent_linear_indices = LinearIndices(size(parent(x)))[linear_indices]
return (parent_linear_indices,)
return (parent(x), (parent_linear_indices,))
end

function get_ancestor_indices(
function get_ancestor_and_indices(
x::Base.ReshapedArray{TracedRNumber{T},N}, indices...
) where {T,N}
@assert length(indices) == N "Expected $N indices, got $(length(indices))"
Expand All @@ -134,13 +134,13 @@ function get_ancestor_indices(
)
)
parent_linear_indices = @opcall reshape(parent_linear_indices, result_size)
return (parent_linear_indices,)
return (parent(x), (parent_linear_indices,))
else
# Have this as a separate code-path since we can generate non-dynamic indexing
cartesian_indices = CartesianIndex.(Iterators.product(indices...))
linear_indices = LinearIndices(size(x))[cartesian_indices]
parent_linear_indices = LinearIndices(size(parent(x)))[linear_indices]
return (parent_linear_indices,)
return (parent(x), (parent_linear_indices,))
end
end

Expand All @@ -152,53 +152,55 @@ function set_mlir_data!(
end

function set_mlir_data!(x::AnyTracedRArray{T}, data) where {T}
ancestor_indices = get_ancestor_indices(x, axes(x)...)
ancestor, ancestor_indices = get_ancestor_and_indices(x, axes(x)...)
setindex!(Reactant.ancestor(x), TracedRArray{T}(data), ancestor_indices...)
return x
end

get_ancestor_indices(::TracedRArray, indices) = indices
get_ancestor_indices(::TracedRArray, indices, args...) = (indices, args...)
get_ancestor_and_indices(a::TracedRArray, indices) = (a, indices)
get_ancestor_and_indices(a::TracedRArray, indices, args...) = (a, (indices, args...))

get_ancestor_indices(::Array{<:TracedRNumber}, indices...) = indices
get_ancestor_indices(::Array{<:TracedRNumber}, indices, args...) = (indices, args...)
get_ancestor_and_indices(a::Array{<:TracedRNumber}, indices...) = (a, indices)
function get_ancestor_and_indices(a::Array{<:TracedRNumber}, indices, args...)
return (a, (indices, args...))
end

function get_ancestor_indices(x::AnyTracedRArray, indices...)
return get_ancestor_indices_inner(x, indices...) # redirect to avoid ambiguity
function get_ancestor_and_indices(x::AnyTracedRArray, indices...)
return get_ancestor_and_indices_inner(x, indices...) # redirect to avoid ambiguity
end
function get_ancestor_indices(x::AnyTracedRArray, indices, args...)
return get_ancestor_indices_inner(x, indices, args...) # redirect to avoid ambiguity
function get_ancestor_and_indices(x::AnyTracedRArray, indices, args...)
return get_ancestor_and_indices_inner(x, indices, args...) # redirect to avoid ambiguity
end

function get_ancestor_indices_inner(
function get_ancestor_and_indices_inner(
x::AnyTracedRArray{T,N}, indices::Vararg{Any,N}
) where {T,N}
return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...)
return get_ancestor_and_indices(parent(x), Base.reindex(parentindices(x), indices)...)
end
function get_ancestor_indices_inner(x::AnyTracedRArray{T,1}, indices) where {T}
return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices))
function get_ancestor_and_indices_inner(x::AnyTracedRArray{T,1}, indices) where {T}
return get_ancestor_and_indices(parent(x), Base.reindex(parentindices(x), indices))
end

function get_ancestor_indices_inner(
function get_ancestor_and_indices_inner(
x::AnyTracedRArray{T,N}, linear_indices::AbstractArray
) where {T,N}
idxs = _get_ancestor_indices_linear(x, linear_indices)
return idxs isa Tuple ? idxs : (idxs,)
a, idxs = _get_ancestor_and_indices_linear(x, linear_indices)
return a, (idxs isa Tuple ? idxs : (idxs,))
end
function get_ancestor_indices_inner(
function get_ancestor_and_indices_inner(
x::AnyTracedRArray{T,1}, linear_indices::AbstractArray
) where {T}
idxs = _get_ancestor_indices_linear(x, linear_indices)
return idxs isa Tuple ? idxs : (idxs,)
a, idxs = _get_ancestor_and_indices_linear(x, linear_indices)
return a, (idxs isa Tuple ? idxs : (idxs,))
end

function _get_ancestor_indices_linear(x::AnyTracedRArray, indices::AbstractArray)
function _get_ancestor_and_indices_linear(x::AnyTracedRArray, indices::AbstractArray)
indices = CartesianIndices(x)[indices]
pidxs = parentindices(x)
parent_indices = map(indices) do idx
CartesianIndex(Base.reindex(pidxs, (idx.I...,)))
end
return get_ancestor_indices(parent(x), parent_indices)
return get_ancestor_and_indices(parent(x), parent_indices)
end

Base.@nospecializeinfer function batch_ty(
Expand Down
11 changes: 11 additions & 0 deletions test/integration/offsetarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,14 @@ end
tval = @jit scalar_index(rOA)
@test tval ≈ oval
end

@testset "OffsetArray View" begin
U = zeros(Float64, 128, 128, 1)
vU = OffsetArray(U, -7:120, -7:120, 1:1)
rU = Reactant.to_rarray(vU)

@jit fill!(@view(rU[1:112, 1:112, 1]), 1.0)
fill!(@view(vU[1:112, 1:112, 1]), 1.0)

@test parent(rU) ≈ parent(vU)
end
Loading