34 changes: 8 additions & 26 deletions src/TracedRArray.jl
@@ -154,12 +154,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::CartesianIndex{N}) where {
 end

 function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
-    indices = map(enumerate(indices)) do (idx, i)
-        i isa Colon && return 1:size(a, idx)
-        i isa CartesianIndex && return Tuple(i)
-        i isa AbstractArray{<:Bool} && return findall(i)
-        return i
-    end
+    indices = TracedUtils.normalize_indices(a, indices...)

     use_gather_getindex = false
     for idxs in indices
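The hand-rolled index canonicalization is replaced by a single call to the new `TracedUtils.normalize_indices` helper added further down in src/TracedUtils.jl. A minimal plain-Julia sketch of the mapping it performs, using an ordinary array rather than a traced one:

# Sketch only: what normalize_indices produces for each kind of index.
a = zeros(4, 5)
collect(Int64, 1:size(a, 1))                # Colon `:` in dimension 1 -> explicit range
Tuple(CartesianIndex(2, 3))                  # CartesianIndex -> its coordinate tuple (2, 3)
findall([true, false, true, false, true])    # Bool mask -> positions of the `true` entries, [1, 3, 5]
# integers, ranges, and traced indices pass through unchanged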
@@ -168,7 +163,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
             use_gather_getindex = true
             break
         end
-        contiguous = all(isone, diff(idxs))
+        contiguous = all(isone, diff(vec(idxs)))
         if typeof(contiguous) <: Bool && !contiguous
             use_gather_getindex = true
             break
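Wrapping the index collection in `vec` before calling `diff` is what keeps the contiguity check working for non-vector index arrays: `diff` without a `dims` keyword is only defined for vectors. A small illustration:

# Why `vec` matters here: `diff` has a keyword-free method only for vectors.
idxs = reshape(collect(1:6), 2, 3)   # hypothetical non-vector index array
all(isone, diff(vec(idxs)))          # true — 1:6 is contiguous
# diff(idxs)                         # would throw: matrices need the `dims` keyword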
@@ -181,19 +176,11 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
         if any(i -> unwrapped_eltype(i) <: Bool, indices)
             error("Boolean indexing with TracedRArrays isn't fully supported yet.")
         end
-        idxs = map(indices) do i
-            i isa Number && return fill(i, 1)
-            return i
-        end
-        indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), idxs)
-        indices_list = generate_index_list(indices_list...)
-        res = Ops.gather_getindex(a, indices_list)
-        res = Ops.reshape(res, length.(idxs)...)
-        ddims = findall(indices) do idx
-            return idx isa Integer || idx isa TracedRNumber{<:Integer}
-        end
-        isempty(ddims) || return materialize_traced_array(dropdims(res; dims=Tuple(ddims)))
-        return res
+        indices, integer_indices, result_size, _ = TracedUtils.traced_indices(indices...)
+        res = Ops.gather_getindex(a, generate_index_list(indices...))
+        isempty(integer_indices) ||
+            (res = materialize_traced_array(dropdims(res; dims=integer_indices)))
+        return Ops.reshape(res, result_size)
     end

     start_indices = map(indices) do i
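For context, a sketch of the kind of call that reaches this gather-based branch; dynamic integer indices cannot be lowered to static slices, so they flow through `traced_indices` and `Ops.gather_getindex`. The sketch assumes Reactant's usual `Reactant.to_rarray` and `@jit` entry points.

using Reactant   # assumption: public to_rarray / @jit entry points
x   = Reactant.to_rarray(rand(Float32, 4, 4))
idx = Reactant.to_rarray([2, 4])          # traced integer indices
getrows(a, i) = a[i, :]                    # dynamic indices -> gather branch when traced
y = @jit getrows(x, idx)                   # 2×4 result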
@@ -233,12 +220,7 @@ maybe_assert_scalar_setindexing(args...) = nothing
 function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
     maybe_assert_scalar_setindexing(a, indices...)

-    indices = map(enumerate(indices)) do (idx, i)
-        i isa Colon && return 1:size(a, idx)
-        i isa CartesianIndex && return Tuple(i)
-        i isa AbstractArray{<:Bool} && return findall(i)
-        return i
-    end
+    indices = TracedUtils.normalize_indices(a, indices...)

     use_scatter_setindex = false
     for idxs in indices
62 changes: 42 additions & 20 deletions src/TracedUtils.jl
@@ -64,30 +64,26 @@ function get_ancestor_indices(
     x::WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}}, indices...
 ) where {T,N,M}
     @assert length(indices) == N "Expected $N indices, got $(length(indices))"
+    indices = normalize_indices(x, indices...)
     if any(is_traced, indices)
-        final_size = Vector{Int64}(undef, N)
-        ddims = Int64[]
-        for (i, idx) in enumerate(indices)
-            @assert ndims(idx) == 1 || ndims(idx) == 0 "Unsupported feature. Please file an issue."
-            ndims(idx) == 0 && push!(ddims, i)
-            final_size[i] = length(idx)
-        end
+        indices, integer_indices, result_size, flattened_size = traced_indices(indices...)
        linear_indices = mapreduce(+, enumerate(indices)) do (i, idx)
             bcasted_idxs = Ops.broadcast_in_dim(
-                idx, ndims(idx) == 0 ? Int64[] : Int64[i], final_size
+                idx, ndims(idx) == 0 ? Int64[] : Int64[i], flattened_size
             )
             Base.stride(x, i) .* (bcasted_idxs .- 1)
         end
         linear_indices = linear_indices .+ 1
         parent_linear_indices_all = collect(LinearIndices(size(parent(x))))
-        parent_linear_indices = TracedUtils.promote_to(
+        parent_linear_indices = promote_to(
             TracedRArray{Int64,ndims(parent_linear_indices_all)}, parent_linear_indices_all
         )[linear_indices]
-        isempty(ddims) || (
+        isempty(integer_indices) || (
             parent_linear_indices = materialize_traced_array(
-                dropdims(parent_linear_indices; dims=Tuple(ddims))
+                dropdims(parent_linear_indices; dims=integer_indices)
             )
         )
+        parent_linear_indices = Ops.reshape(parent_linear_indices, result_size)
         return (parent_linear_indices,)
     else
         # Have this as a separate code-path since we can generate non-dynamic indexing
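The traced branch above linearizes the per-dimension indices with the usual column-major stride formula before gathering from the parent array. The same arithmetic in plain Julia, for reference:

# Column-major linearization used above: linear = 1 + sum over d of stride_d * (i_d - 1).
p = reshape(collect(1:24), 2, 3, 4)        # stand-in for the parent array
i = (2, 3, 4)                              # one index per dimension
linear = 1 + sum(stride(p, d) * (i[d] - 1) for d in 1:3)
p[linear] == p[i...]                       # true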
@@ -106,7 +102,7 @@ function set_mlir_data!(
 end

 function set_mlir_data!(x::AnyTracedRArray{T}, data) where {T}
-    ancestor_indices = TracedUtils.get_ancestor_indices(x, axes(x)...)
+    ancestor_indices = get_ancestor_indices(x, axes(x)...)
     setindex!(Reactant.ancestor(x), TracedRArray{T}(data), ancestor_indices...)
     return x
 end
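`set_mlir_data!` on a wrapped array works by locating the corresponding positions in the unwrapped ancestor and writing through them. The plain-Julia analogue of that write-through behaviour:

# Writing through a reshaped view mutates the parent; get_ancestor_indices +
# setindex! reproduce the same effect for traced wrappers.
p = zeros(2, 3)
v = reshape(p, 3, 2)
v[2, 1] = 7.0
p[2, 1] == 7.0       # true — linear position 2 in both layouts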
@@ -317,7 +313,7 @@ elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x
 struct TypeCast{T<:ReactantPrimitive} <: Function end

 function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2}
-    return TracedUtils.promote_to(TracedRNumber{T}, x)
+    return promote_to(TracedRNumber{T}, x)
 end

 function elem_apply(::Type{T}, x::TracedRArray) where {T<:ReactantPrimitive}
@@ -434,7 +430,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
     batch_inputs = MLIR.IR.Value[]

     for a in linear_args
-        idx, path = TracedUtils.get_argidx(a)
+        idx, path = get_argidx(a)
         if idx == 1 && fnwrap
             push_val!(batch_inputs, f, path[3:end])
         else
@@ -455,20 +451,20 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
     residx = 1

     for a in linear_results
-        if TracedUtils.has_residx(a)
-            path = TracedUtils.get_residx(a)
-            TracedUtils.set!(result, path[2:end], MLIR.IR.result(res, residx))
+        if has_residx(a)
+            path = get_residx(a)
+            set!(result, path[2:end], MLIR.IR.result(res, residx))
             residx += 1
         else
-            idx, path = TracedUtils.get_argidx(a)
+            idx, path = get_argidx(a)
             if idx == 1 && fnwrap
-                TracedUtils.set!(f, path[3:end], MLIR.IR.result(res, residx))
+                set!(f, path[3:end], MLIR.IR.result(res, residx))
                 residx += 1
             else
                 if fnwrap
                     idx -= 1
                 end
-                TracedUtils.set!(args[idx], path[3:end], MLIR.IR.result(res, residx))
+                set!(args[idx], path[3:end], MLIR.IR.result(res, residx))
                 residx += 1
             end
         end
@@ -523,4 +519,30 @@ end
     return Ops.broadcast_in_dim(x, collect(Int64, 1:ndims(x)), collect(Int64, rsize))
 end

+function normalize_indices(a::AbstractArray, indices...)
+    return map(enumerate(indices)) do (i, idx)
+        idx isa Colon && return collect(Int64, 1:size(a, i))
+        idx isa CartesianIndex && return Tuple(idx)
+        idx isa AbstractArray{Bool} && return findall(idx)
+        return idx
+    end
+end
+
+function traced_indices(indices...)
+    integer_indices = Int64[]
+    result_size = Int64[]
+    flattened_size = Int64[length(idx) for idx in indices]
+    new_indices = map(enumerate(indices)) do (i, idx)
+        if idx isa Number
+            push!(integer_indices, i)
+            idx isa TracedRNumber && return idx
+            return promote_to(TracedRNumber{Int}, idx)
+        end
+        append!(result_size, [size(idx)...])
+        idx isa TracedRArray && return materialize_traced_array(vec(idx))
+        return promote_to(TracedRArray{Int,1}, vec(idx))
+    end
+    return new_indices, Tuple(integer_indices), result_size, flattened_size
+end
+
 end
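A sketch of the bookkeeping `traced_indices` performs for a mixed index tuple such as `(2, 3:4, 1:4)` into a 4×4×4 array (after `normalize_indices`): the scalar's position is recorded in `integer_indices` so it can be dropped from the result, `result_size` collects the shapes of the array-like indices, and `flattened_size` keeps one length per index for the broadcast in `get_ancestor_indices`. The traced promotion is omitted here.

# Bookkeeping sketch (traced promotion omitted)
indices = (2, 3:4, 1:4)
integer_indices = Tuple(i for (i, idx) in enumerate(indices) if idx isa Number)   # (1,)
result_size = Int64[]                                                             # -> [2, 4]
foreach(idx -> idx isa Number || append!(result_size, size(idx)), indices)
flattened_size = Int64[length(idx) for idx in indices]                            # [1, 2, 4]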
52 changes: 52 additions & 0 deletions src/stdlibs/LinearAlgebra.jl
@@ -6,6 +6,7 @@ using ..Reactant:
     AnyTracedRArray,
     AnyTracedRMatrix,
     AnyTracedRVector,
+    AnyTracedRVecOrMat,
     unwrapped_eltype,
     Ops,
     MLIR
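A hedged usage sketch of what the new methods below enable: Kronecker products of traced arrays can be compiled like any other traced operation. The sketch assumes Reactant's `Reactant.to_rarray` and `@jit` entry points.

using Reactant, LinearAlgebra   # assumption: public to_rarray / @jit entry points
A = Reactant.to_rarray(rand(Float32, 2, 3))
B = Reactant.to_rarray(rand(Float32, 4, 5))
K = @jit kron(A, B)             # 8×15 result, dispatching to the methods below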
@@ -347,4 +348,55 @@ function LinearAlgebra.ldiv!(
     return B
 end

+# Kronecker Product
+function LinearAlgebra.kron(
+    x::AnyTracedRVecOrMat{T1}, y::AnyTracedRVecOrMat{T2}
+) where {T1,T2}
+    x = materialize_traced_array(x)
+    y = materialize_traced_array(y)
+    z = similar(x, Base.promote_op(*, T1, T2), LinearAlgebra._kronsize(x, y))
+    LinearAlgebra.kron!(z, x, y)
+    return z
+end
+
+function LinearAlgebra.kron(x::AnyTracedRVector{T1}, y::AnyTracedRVector{T2}) where {T1,T2}
+    x = materialize_traced_array(x)
+    y = materialize_traced_array(y)
+    z = similar(x, Base.promote_op(*, T1, T2), length(x) * length(y))
+    LinearAlgebra.kron!(z, x, y)
+    return z
+end
+
+function LinearAlgebra.kron!(C::AnyTracedRVector, A::AnyTracedRVector, B::AnyTracedRVector)
+    LinearAlgebra.kron!(
+        reshape(C, length(B), length(A)), reshape(A, 1, length(A)), reshape(B, length(B), 1)
+    )
+    return C
+end
+
+function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTracedRMatrix)
+    A = materialize_traced_array(A)
+    B = materialize_traced_array(B)
+
+    final_shape = Int64[size(B, 1), size(A, 1), size(B, 2), size(A, 2)]
+
+    A = Ops.broadcast_in_dim(A, Int64[2, 4], final_shape)
+    B = Ops.broadcast_in_dim(B, Int64[1, 3], final_shape)
+
+    C_tmp = Ops.reshape(Ops.multiply(A, B), size(C)...)
+    set_mlir_data!(C, get_mlir_data(C_tmp))
+
+    return C
+end
+
+function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRVector, B::AnyTracedRMatrix)
+    LinearAlgebra._kron!(C, reshape(A, length(A), 1), B)
+    return C
+end
+
+function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTracedRVector)
+    LinearAlgebra._kron!(C, A, reshape(B, length(B), 1))
+    return C
+end
+
 end
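The matrix `_kron!` above forms the Kronecker product without explicit loops: both operands are broadcast into a rank-4 tensor of shape (size(B,1), size(A,1), size(B,2), size(A,2)), multiplied elementwise, and collapsed back to the shape of C. A plain-Julia check of that shape algebra:

using LinearAlgebra
# Verify the broadcast/reshape formulation of kron used above.
A = rand(2, 3); B = rand(4, 5)
A4 = reshape(A, 1, size(A, 1), 1, size(A, 2))   # A occupies dims 2 and 4
B4 = reshape(B, size(B, 1), 1, size(B, 2), 1)   # B occupies dims 1 and 3
C  = reshape(A4 .* B4, size(A, 1) * size(B, 1), size(A, 2) * size(B, 2))
C ≈ kron(A, B)                                   # true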