Skip to content

Commit 2118ee2

Browse files
authored
feat: more indexing support (#608)
* feat: overload LinearAlgebra.kron * test: kron * feat: more indexing support * refactor: move tests around a bit * fix: cleanup implementation and add tests
1 parent 534bea3 commit 2118ee2

File tree

6 files changed

+301
-274
lines changed

6 files changed

+301
-274
lines changed

src/TracedRArray.jl

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::CartesianIndex{N}) where {
154154
end
155155

156156
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
157-
indices = map(enumerate(indices)) do (idx, i)
158-
i isa Colon && return 1:size(a, idx)
159-
i isa CartesianIndex && return Tuple(i)
160-
i isa AbstractArray{<:Bool} && return findall(i)
161-
return i
162-
end
157+
indices = TracedUtils.normalize_indices(a, indices...)
163158

164159
use_gather_getindex = false
165160
for idxs in indices
@@ -168,7 +163,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
168163
use_gather_getindex = true
169164
break
170165
end
171-
contiguous = all(isone, diff(idxs))
166+
contiguous = all(isone, diff(vec(idxs)))
172167
if typeof(contiguous) <: Bool && !contiguous
173168
use_gather_getindex = true
174169
break
@@ -181,19 +176,11 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
181176
if any(i -> unwrapped_eltype(i) <: Bool, indices)
182177
error("Boolean indexing with TracedRArrays isn't fully supported yet.")
183178
end
184-
idxs = map(indices) do i
185-
i isa Number && return fill(i, 1)
186-
return i
187-
end
188-
indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), idxs)
189-
indices_list = generate_index_list(indices_list...)
190-
res = Ops.gather_getindex(a, indices_list)
191-
res = Ops.reshape(res, length.(idxs)...)
192-
ddims = findall(indices) do idx
193-
return idx isa Integer || idx isa TracedRNumber{<:Integer}
194-
end
195-
isempty(ddims) || return materialize_traced_array(dropdims(res; dims=Tuple(ddims)))
196-
return res
179+
indices, integer_indices, result_size, _ = TracedUtils.traced_indices(indices...)
180+
res = Ops.gather_getindex(a, generate_index_list(indices...))
181+
isempty(integer_indices) ||
182+
(res = materialize_traced_array(dropdims(res; dims=integer_indices)))
183+
return Ops.reshape(res, result_size)
197184
end
198185

199186
start_indices = map(indices) do i
@@ -233,12 +220,7 @@ maybe_assert_scalar_setindexing(args...) = nothing
233220
function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
234221
maybe_assert_scalar_setindexing(a, indices...)
235222

236-
indices = map(enumerate(indices)) do (idx, i)
237-
i isa Colon && return 1:size(a, idx)
238-
i isa CartesianIndex && return Tuple(i)
239-
i isa AbstractArray{<:Bool} && return findall(i)
240-
return i
241-
end
223+
indices = TracedUtils.normalize_indices(a, indices...)
242224

243225
use_scatter_setindex = false
244226
for idxs in indices

src/TracedUtils.jl

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -64,30 +64,26 @@ function get_ancestor_indices(
6464
x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}}, indices...
6565
) where {T,N,M}
6666
@assert length(indices) == N "Expected $N indices, got $(length(indices))"
67+
indices = normalize_indices(x, indices...)
6768
if any(is_traced, indices)
68-
final_size = Vector{Int64}(undef, N)
69-
ddims = Int64[]
70-
for (i, idx) in enumerate(indices)
71-
@assert ndims(idx) == 1 || ndims(idx) == 0 "Unsupported feature. Please file an issue."
72-
ndims(idx) == 0 && push!(ddims, i)
73-
final_size[i] = length(idx)
74-
end
69+
indices, integer_indices, result_size, flattened_size = traced_indices(indices...)
7570
linear_indices = mapreduce(+, enumerate(indices)) do (i, idx)
7671
bcasted_idxs = Ops.broadcast_in_dim(
77-
idx, ndims(idx) == 0 ? Int64[] : Int64[i], final_size
72+
idx, ndims(idx) == 0 ? Int64[] : Int64[i], flattened_size
7873
)
7974
Base.stride(x, i) .* (bcasted_idxs .- 1)
8075
end
8176
linear_indices = linear_indices .+ 1
8277
parent_linear_indices_all = collect(LinearIndices(size(parent(x))))
83-
parent_linear_indices = TracedUtils.promote_to(
78+
parent_linear_indices = promote_to(
8479
TracedRArray{Int64,ndims(parent_linear_indices_all)}, parent_linear_indices_all
8580
)[linear_indices]
86-
isempty(ddims) || (
81+
isempty(integer_indices) || (
8782
parent_linear_indices = materialize_traced_array(
88-
dropdims(parent_linear_indices; dims=Tuple(ddims))
83+
dropdims(parent_linear_indices; dims=integer_indices)
8984
)
9085
)
86+
parent_linear_indices = Ops.reshape(parent_linear_indices, result_size)
9187
return (parent_linear_indices,)
9288
else
9389
# Have this as a separate code-path since we can generate non-dynamic indexing
@@ -106,7 +102,7 @@ function set_mlir_data!(
106102
end
107103

108104
function set_mlir_data!(x::AnyTracedRArray{T}, data) where {T}
109-
ancestor_indices = TracedUtils.get_ancestor_indices(x, axes(x)...)
105+
ancestor_indices = get_ancestor_indices(x, axes(x)...)
110106
setindex!(Reactant.ancestor(x), TracedRArray{T}(data), ancestor_indices...)
111107
return x
112108
end
@@ -317,7 +313,7 @@ elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x
317313
struct TypeCast{T<:ReactantPrimitive} <: Function end
318314

319315
function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2}
320-
return TracedUtils.promote_to(TracedRNumber{T}, x)
316+
return promote_to(TracedRNumber{T}, x)
321317
end
322318

323319
function elem_apply(::Type{T}, x::TracedRArray) where {T<:ReactantPrimitive}
@@ -434,7 +430,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
434430
batch_inputs = MLIR.IR.Value[]
435431

436432
for a in linear_args
437-
idx, path = TracedUtils.get_argidx(a)
433+
idx, path = get_argidx(a)
438434
if idx == 1 && fnwrap
439435
push_val!(batch_inputs, f, path[3:end])
440436
else
@@ -455,20 +451,20 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
455451
residx = 1
456452

457453
for a in linear_results
458-
if TracedUtils.has_residx(a)
459-
path = TracedUtils.get_residx(a)
460-
TracedUtils.set!(result, path[2:end], MLIR.IR.result(res, residx))
454+
if has_residx(a)
455+
path = get_residx(a)
456+
set!(result, path[2:end], MLIR.IR.result(res, residx))
461457
residx += 1
462458
else
463-
idx, path = TracedUtils.get_argidx(a)
459+
idx, path = get_argidx(a)
464460
if idx == 1 && fnwrap
465-
TracedUtils.set!(f, path[3:end], MLIR.IR.result(res, residx))
461+
set!(f, path[3:end], MLIR.IR.result(res, residx))
466462
residx += 1
467463
else
468464
if fnwrap
469465
idx -= 1
470466
end
471-
TracedUtils.set!(args[idx], path[3:end], MLIR.IR.result(res, residx))
467+
set!(args[idx], path[3:end], MLIR.IR.result(res, residx))
472468
residx += 1
473469
end
474470
end
@@ -523,4 +519,30 @@ end
523519
return Ops.broadcast_in_dim(x, collect(Int64, 1:ndims(x)), collect(Int64, rsize))
524520
end
525521

522+
function normalize_indices(a::AbstractArray, indices...)
523+
return map(enumerate(indices)) do (i, idx)
524+
idx isa Colon && return collect(Int64, 1:size(a, i))
525+
idx isa CartesianIndex && return Tuple(idx)
526+
idx isa AbstractArray{Bool} && return findall(idx)
527+
return idx
528+
end
529+
end
530+
531+
function traced_indices(indices...)
532+
integer_indices = Int64[]
533+
result_size = Int64[]
534+
flattened_size = Int64[length(idx) for idx in indices]
535+
new_indices = map(enumerate(indices)) do (i, idx)
536+
if idx isa Number
537+
push!(integer_indices, i)
538+
idx isa TracedRNumber && return idx
539+
return promote_to(TracedRNumber{Int}, idx)
540+
end
541+
append!(result_size, [size(idx)...])
542+
idx isa TracedRArray && return materialize_traced_array(vec(idx))
543+
return promote_to(TracedRArray{Int,1}, vec(idx))
544+
end
545+
return new_indices, Tuple(integer_indices), result_size, flattened_size
546+
end
547+
526548
end

0 commit comments

Comments
 (0)