Skip to content

Commit

Permalink
More mul overloads (#446)
Browse files Browse the repository at this point in the history
* More mul overloads

* fixup

* fix

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: Avik Pal <avikpal@mit.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 1, 2025
1 parent 913f653 commit b6096ee
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 7 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[sources.ReactantCore]
path = "lib/ReactantCore"

[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
Expand Down Expand Up @@ -74,3 +71,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

[sources.ReactantCore]
path = "lib/ReactantCore"
8 changes: 6 additions & 2 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,12 @@ for (cT, aT, bT) in (
C::$cT, A::$aT, B::$bT, α::Number, β::Number
)
A, B = aos_to_soa(A), aos_to_soa(B)
if use_overlayed_version((C, A, B))
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
C2 = aos_to_soa(C)
if use_overlayed_version((C2, A, B))
TracedLinearAlgebra.overloaded_mul!(C2, A, B, α, β)
if C2 !== C
C .= C2
end
else
LinearAlgebra.mul!(C, A, B, α, β)
end
Expand Down
20 changes: 20 additions & 0 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using ..Reactant:
Ops,
MLIR,
ancestor,
allowscalar,
aos_to_soa,
unwrapped_eltype
using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array

Expand All @@ -29,6 +31,9 @@ function Base.convert(::Type{TracedRArray{T,N}}, x::AbstractArray) where {T,N}
end
x isa WrappedTracedRArray &&
return convert(TracedRArray{T,N}, materialize_traced_array(x))
if eltype(x) <: TracedRNumber
return convert(TracedRArray{T,N}, aos_to_soa(x))
end
return convert(TracedRArray{T,N}, Ops.constant(collect(x)))
end

Expand Down Expand Up @@ -460,6 +465,21 @@ function _copyto!(dest::AnyTracedRArray, bc::Broadcasted)
return dest
end

function _copyto!(dest::AbstractArray{<:TracedRNumber}, bc::Broadcasted)
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
isempty(dest) && return dest

bc = Broadcast.preprocess(dest, bc)

args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)

res = TracedUtils.elem_apply(bc.f, args...)
for I in 1:length(dest)
dest[I] = Reactant.@allowscalar res[I]
end
return dest
end

dispatch_val(x) = x
dispatch_val(::Val{D}) where {D} = D

Expand Down
34 changes: 32 additions & 2 deletions src/stdlibs/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ..Reactant:
AnyTracedRArray,
AnyTracedRMatrix,
AnyTracedRVector,
unwrapped_eltype,
Ops,
MLIR

Expand Down Expand Up @@ -190,12 +191,12 @@ function overloaded_mul!(
end

function overloaded_mul!(
@nospecialize(C::TracedRArray{T,2}),
@nospecialize(C::TracedRArray{T,2} where {T}),
@nospecialize(A::AnyTracedRMatrix),
@nospecialize(B::AnyTracedRMatrix),
α::Number=true,
β::Number=false,
) where {T}
)
if size(C) != (size(A, 1), size(B, 2))
throw(
DimensionMismatch(
Expand All @@ -207,6 +208,7 @@ function overloaded_mul!(
throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B))"))
end

T = unwrapped_eltype(C)
tmp = Ops.dot_general(
T.(materialize_traced_array(A)),
T.(materialize_traced_array(B));
Expand Down Expand Up @@ -317,4 +319,32 @@ function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0)
return indices
end

function LinearAlgebra.ldiv!(
B::Union{
AbstractArray{<:TracedRNumber{T},1},
AbstractArray{<:TracedRNumber{T},2},
AnyTracedRArray{T,1},
AnyTracedRArray{T,2},
},
D::Diagonal,
A::AbstractVecOrMat,
) where {T}
LinearAlgebra.require_one_based_indexing(A, B)
dd = D.diag
d = length(dd)
m, n = size(A, 1), size(A, 2)
m′, n′ = size(B, 1), size(B, 2)
m == d || throw(DimensionMismatch("right hand side has $m rows but D is $d by $d"))
(m, n) == (m′, n′) ||
throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′"))
B .= dd .\ A
# OG implementation below, we don't currently support the conditional throw exception
#j = findfirst(iszero, D.diag)
#isnothing(j) || throw(SingularException(j))
#@inbounds for j = 1:n, i = 1:m
# B[i, j] = dd[i] \ A[i, j]
#end
return B
end

end

0 comments on commit b6096ee

Please sign in to comment.