Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,13 @@ function Base.print_array(io::IO, X::AnyConcretePJRTArray)
return Base.print_array(io, convert(Array, X))
end

function Base.showarg(io::IO, a::ConcretePJRTArray{T,N}, toplevel) where {T,N}
toplevel || print(io, "::")
print(io, "ConcretePJRTArray{$T,$N}")
Sharding.is_sharded(a) && print(io, " with sharding $(typeof(a.sharding.sharding))")
return nothing
end

function Base.show(io::IO, X::AnyConcretePJRTArray)
if isempty(X)
print(io, "<Empty Buffer eltype $(eltype(X)) of size $(size(X))>")
Expand Down
2 changes: 1 addition & 1 deletion src/TracedRArray.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module TracedRArrayOverrides

using Adapt: WrappedReshapedArray
using Adapt: WrappedReshapedArray, WrappedArray
using Base.Broadcast
using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate

Expand Down
3 changes: 3 additions & 0 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,9 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
end

function broadcast_to_size(arg::AbstractArray{<:TracedRNumber}, rsize)
if Reactant.ancestor(arg) isa TracedRArray
return broadcast_to_size(materialize_traced_array(arg), rsize)
end
return broadcast_to_size(reshape(Ops.vcat(arg...), size(arg)...), rsize)
end
broadcast_to_size(arg::AbstractArray, rsize) = broadcast_to_size(Ops.constant(arg), rsize)
Expand Down
19 changes: 19 additions & 0 deletions src/stdlibs/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using ..Reactant:
AnyTracedRMatrix,
AnyTracedRVector,
AnyTracedRVecOrMat,
WrappedTracedRArray,
unwrapped_eltype,
Ops,
MLIR
Expand All @@ -24,18 +25,36 @@ function TracedUtils.materialize_traced_array(
return permutedims(A, (2, 1))
end

function TracedUtils.materialize_traced_array(
x::Transpose{TracedRNumber{T},<:WrappedTracedRArray{T,N}}
) where {T,N}
return materialize_traced_array(transpose(materialize_traced_array(parent(x))))
end

function TracedUtils.materialize_traced_array(
x::Adjoint{TracedRNumber{T},TracedRArray{T,N}}
) where {T,N}
return conj(materialize_traced_array(transpose(parent(x))))
end

function TracedUtils.materialize_traced_array(
x::Adjoint{TracedRNumber{T},<:WrappedTracedRArray{T,N}}
) where {T,N}
return materialize_traced_array(adjoint(materialize_traced_array(parent(x))))
end

function TracedUtils.materialize_traced_array(
x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}
) where {T}
return diagm(parent(x))
end

function TracedUtils.materialize_traced_array(
x::Diagonal{TracedRNumber{T},WrappedTracedRArray{T,1}}
) where {T}
return diagm(materialize_traced_array(parent(x)))
end

function TracedUtils.materialize_traced_array(
x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}
) where {T}
Expand Down
24 changes: 24 additions & 0 deletions test/wrapped_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,27 @@ end
hlo = repr(@code_hlo(permutedims_getindex(x_ra)))
@test !occursin("stablehlo.gather", hlo)
end

function view_adjoint(x)
y = view(x, 1:2, 1:2)
return adjoint(y) .+ y
end

function view_transpose(x)
y = view(x, 1:2, 1:2)
return transpose(y) .+ y
end

function view_diagonal(x)
y = view(x, 1:2, 1:2)
return Diagonal(y) .+ y
end

@testset "2 levels of wrapping" begin
x = reshape(collect(Float32, 1:8), 2, 4)
x_ra = Reactant.to_rarray(x)

@test @jit(view_adjoint(x_ra)) ≈ view_adjoint(x)
@test @jit(view_transpose(x_ra)) ≈ view_transpose(x)
@test @jit(view_diagonal(x_ra)) ≈ view_diagonal(x)
end
Loading