From 91e4c9b49848e0bcb5a6ad625ed9236f92267f88 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 27 Feb 2025 18:00:18 -0500 Subject: [PATCH] fix: allow one arg of overloaded_mul to be a regular array --- src/stdlibs/LinearAlgebra.jl | 39 ++++++++++++++++++++---------- test/integration/linear_algebra.jl | 26 +++++++++----------- 2 files changed, 37 insertions(+), 28 deletions(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 7ff8bcbebd..2b141862af 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -168,12 +168,13 @@ end # Core functions function overloaded_mul!( @nospecialize(C::TracedRArray{T,1}), - @nospecialize(A::AnyTracedRMatrix), - @nospecialize(B::AnyTracedRVector), + @nospecialize(A::AbstractMatrix), + @nospecialize(B::AbstractVector), α::Number=true, β::Number=false, ) where {T} - # TODO: The reshape operations are not getting optimized, we should directly call dot_general + # TODO: The reshape operations are not getting optimized, we should directly call + # dot_general rC = Ops.reshape(C, length(C), 1) overloaded_mul!(rC, A, reshape(B, :, 1), α, β) C.mlir_data = get_mlir_data(vec(rC)) @@ -182,8 +183,8 @@ end function overloaded_mul!( @nospecialize(C::TracedRArray{T,2}), - @nospecialize(A::AnyTracedRMatrix), - @nospecialize(B::AnyTracedRVector), + @nospecialize(A::AbstractMatrix), + @nospecialize(B::AbstractVector), α::Number=true, β::Number=false, ) where {T} @@ -193,11 +194,14 @@ end function overloaded_mul!( @nospecialize(C::TracedRArray{T,2} where {T}), - @nospecialize(A::AnyTracedRMatrix), - @nospecialize(B::AnyTracedRMatrix), + @nospecialize(A::AbstractMatrix), + @nospecialize(B::AbstractMatrix), α::Number=true, β::Number=false, ) + A = TracedUtils.promote_to(TracedRArray{unwrapped_eltype(A),2}, A) + B = TracedUtils.promote_to(TracedRArray{unwrapped_eltype(B),2}, B) + if size(C) != (size(A, 1), size(B, 2)) throw( DimensionMismatch( @@ -399,24 +403,33 @@ end function LinearAlgebra.axpy!(α::Number, x::TracedRArray{T}, y::TracedRArray{T}) where {T} if length(x) != length(y) - throw(DimensionMismatch(lazy"x has length $(length(x)), but y has length $(length(y))")) + throw( + DimensionMismatch( + lazy"x has length $(length(x)), but y has length $(length(y))" + ), + ) end ax = Ops.multiply(x, TracedUtils.broadcast_to_size(T(α), size(x))) - + set_mlir_data!(y, get_mlir_data(Ops.add(y, ax))) return y end -function LinearAlgebra.axpby!(α::Number, x::TracedRArray{T}, β::Number, y::TracedRArray{T}) where {T} +function LinearAlgebra.axpby!( + α::Number, x::TracedRArray{T}, β::Number, y::TracedRArray{T} +) where {T} if length(x) != length(y) - throw(DimensionMismatch(lazy"x has length $(length(x)), but y has length $(length(y))")) + throw( + DimensionMismatch( + lazy"x has length $(length(x)), but y has length $(length(y))" + ), + ) end ax = Ops.multiply(x, TracedUtils.broadcast_to_size(T(α), size(x))) by = Ops.multiply(y, TracedUtils.broadcast_to_size(T(β), size(y))) - + set_mlir_data!(y, get_mlir_data(Ops.add(ax, by))) return y end - end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 114b352524..a49768310e 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -57,6 +57,7 @@ end @test @jit(muladd2(A_ra, x_ra, b_ra)) ≈ muladd2(A, x, b) @test @jit(muladd_5arg(A_ra, x_ra, b_ra)) ≈ muladd2(A, x, b) @test @jit(muladd_5arg2(A_ra, x_ra, b_ra)) ≈ 2 .* A * x .+ b + @test @jit(A_ra * x) ≈ A * x @test @jit(mul_with_view1(A_ra, x_ra)) ≈ mul_with_view1(A, x) @@ -189,7 +190,7 @@ end x = rand(Int64, 4) x_ra = Reactant.to_rarray(x) y = rand(Int64, 4) - y_ra = Reactant.to_rarray(y) + y_ra = Reactant.to_rarray(y) @jit axpy!(α, x_ra, y_ra) @test y_ra ≈ axpy!(α, x, y) @@ -198,7 +199,7 @@ end x = rand(4) x_ra = Reactant.to_rarray(x) y = rand(4) - y_ra = Reactant.to_rarray(y) + y_ra = Reactant.to_rarray(y) @jit axpy!(α, x_ra, y_ra) @test y_ra ≈ axpy!(α, x, y) @@ -208,19 +209,18 @@ end Y = rand(3, 5) X_ra = Reactant.to_rarray(X) Y_ra = Reactant.to_rarray(Y) - + @jit axpy!(α, X_ra, Y_ra) @test Y_ra ≈ axpy!(α, X, Y) - + α = 3.2 + 1im x = rand(Complex{Float32}, 4) x_ra = Reactant.to_rarray(x) y = rand(Complex{Float32}, 4) - y_ra = Reactant.to_rarray(y) + y_ra = Reactant.to_rarray(y) @jit axpy!(α, x_ra, y_ra) @test y_ra ≈ axpy!(α, x, y) - end @testset "axpby!" begin @@ -229,7 +229,7 @@ end x = rand(Int64, 4) x_ra = Reactant.to_rarray(x) y = rand(Int64, 4) - y_ra = Reactant.to_rarray(y) + y_ra = Reactant.to_rarray(y) @jit axpby!(α, x_ra, β, y_ra) @test y_ra ≈ axpby!(α, x, β, y) @@ -239,7 +239,7 @@ end x = rand(4) x_ra = Reactant.to_rarray(x) y = rand(4) - y_ra = Reactant.to_rarray(y) + y_ra = Reactant.to_rarray(y) @jit axpby!(α, x_ra, β, y_ra) @test y_ra ≈ axpby!(α, x, β, y) @@ -249,21 +249,17 @@ end Y = rand(3, 5) X_ra = Reactant.to_rarray(X) Y_ra = Reactant.to_rarray(Y) - + @jit axpby!(α, X_ra, β, Y_ra) @test Y_ra ≈ axpby!(α, X, β, Y) - + α = 3.2 + 1im β = 2.1 - 4.2im x = rand(Complex{Float32}, 4) x_ra = Reactant.to_rarray(x) y = rand(Complex{Float32}, 4) - y_ra = Reactant.to_rarray(y) + y_ra = Reactant.to_rarray(y) @jit axpby!(α, x_ra, β, y_ra) @test y_ra ≈ axpby!(α, x, β, y) - end - - -