@@ -64,30 +64,26 @@ function get_ancestor_indices(
6464 x:: WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}} , indices...
6565) where {T,N,M}
6666 @assert length (indices) == N " Expected $N indices, got $(length (indices)) "
67+ indices = normalize_indices (x, indices... )
6768 if any (is_traced, indices)
68- final_size = Vector {Int64} (undef, N)
69- ddims = Int64[]
70- for (i, idx) in enumerate (indices)
71- @assert ndims (idx) == 1 || ndims (idx) == 0 " Unsupported feature. Please file an issue."
72- ndims (idx) == 0 && push! (ddims, i)
73- final_size[i] = length (idx)
74- end
69+ indices, integer_indices, result_size, flattened_size = traced_indices (indices... )
7570 linear_indices = mapreduce (+ , enumerate (indices)) do (i, idx)
7671 bcasted_idxs = Ops. broadcast_in_dim (
77- idx, ndims (idx) == 0 ? Int64[] : Int64[i], final_size
72+ idx, ndims (idx) == 0 ? Int64[] : Int64[i], flattened_size
7873 )
7974 Base. stride (x, i) .* (bcasted_idxs .- 1 )
8075 end
8176 linear_indices = linear_indices .+ 1
8277 parent_linear_indices_all = collect (LinearIndices (size (parent (x))))
83- parent_linear_indices = TracedUtils . promote_to (
78+ parent_linear_indices = promote_to (
8479 TracedRArray{Int64,ndims (parent_linear_indices_all)}, parent_linear_indices_all
8580 )[linear_indices]
86- isempty (ddims ) || (
81+ isempty (integer_indices ) || (
8782 parent_linear_indices = materialize_traced_array (
88- dropdims (parent_linear_indices; dims= Tuple (ddims) )
83+ dropdims (parent_linear_indices; dims= integer_indices )
8984 )
9085 )
86+ parent_linear_indices = Ops. reshape (parent_linear_indices, result_size)
9187 return (parent_linear_indices,)
9288 else
9389 # Have this as a separate code-path since we can generate non-dynamic indexing
@@ -106,7 +102,7 @@ function set_mlir_data!(
106102end
107103
108104function set_mlir_data! (x:: AnyTracedRArray{T} , data) where {T}
109- ancestor_indices = TracedUtils . get_ancestor_indices (x, axes (x)... )
105+ ancestor_indices = get_ancestor_indices (x, axes (x)... )
110106 setindex! (Reactant. ancestor (x), TracedRArray {T} (data), ancestor_indices... )
111107 return x
112108end
@@ -317,7 +313,7 @@ elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x
317313struct TypeCast{T<: ReactantPrimitive } <: Function end
318314
319315function (:: TypeCast{T} )(x:: TracedRNumber{T2} ) where {T,T2}
320- return TracedUtils . promote_to (TracedRNumber{T}, x)
316+ return promote_to (TracedRNumber{T}, x)
321317end
322318
323319function elem_apply (:: Type{T} , x:: TracedRArray ) where {T<: ReactantPrimitive }
@@ -434,7 +430,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
434430 batch_inputs = MLIR. IR. Value[]
435431
436432 for a in linear_args
437- idx, path = TracedUtils . get_argidx (a)
433+ idx, path = get_argidx (a)
438434 if idx == 1 && fnwrap
439435 push_val! (batch_inputs, f, path[3 : end ])
440436 else
@@ -455,20 +451,20 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
455451 residx = 1
456452
457453 for a in linear_results
458- if TracedUtils . has_residx (a)
459- path = TracedUtils . get_residx (a)
460- TracedUtils . set! (result, path[2 : end ], MLIR. IR. result (res, residx))
454+ if has_residx (a)
455+ path = get_residx (a)
456+ set! (result, path[2 : end ], MLIR. IR. result (res, residx))
461457 residx += 1
462458 else
463- idx, path = TracedUtils . get_argidx (a)
459+ idx, path = get_argidx (a)
464460 if idx == 1 && fnwrap
465- TracedUtils . set! (f, path[3 : end ], MLIR. IR. result (res, residx))
461+ set! (f, path[3 : end ], MLIR. IR. result (res, residx))
466462 residx += 1
467463 else
468464 if fnwrap
469465 idx -= 1
470466 end
471- TracedUtils . set! (args[idx], path[3 : end ], MLIR. IR. result (res, residx))
467+ set! (args[idx], path[3 : end ], MLIR. IR. result (res, residx))
472468 residx += 1
473469 end
474470 end
523519 return Ops. broadcast_in_dim (x, collect (Int64, 1 : ndims (x)), collect (Int64, rsize))
524520end
525521
522+ function normalize_indices (a:: AbstractArray , indices... )
523+ return map (enumerate (indices)) do (i, idx)
524+ idx isa Colon && return collect (Int64, 1 : size (a, i))
525+ idx isa CartesianIndex && return Tuple (idx)
526+ idx isa AbstractArray{Bool} && return findall (idx)
527+ return idx
528+ end
529+ end
530+
531+ function traced_indices (indices... )
532+ integer_indices = Int64[]
533+ result_size = Int64[]
534+ flattened_size = Int64[length (idx) for idx in indices]
535+ new_indices = map (enumerate (indices)) do (i, idx)
536+ if idx isa Number
537+ push! (integer_indices, i)
538+ idx isa TracedRNumber && return idx
539+ return promote_to (TracedRNumber{Int}, idx)
540+ end
541+ append! (result_size, [size (idx)... ])
542+ idx isa TracedRArray && return materialize_traced_array (vec (idx))
543+ return promote_to (TracedRArray{Int,1 }, vec (idx))
544+ end
545+ return new_indices, Tuple (integer_indices), result_size, flattened_size
546+ end
547+
526548end
0 commit comments