From 8c387e3b2f5cf904f3ae3b3fb82bfcee16b1fda4 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 5 Oct 2023 18:22:47 +0530 Subject: [PATCH] Aggressive constprop in LinearAlgebra.wrap (#51582) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This helps with type-stability, as the flag `tA` is usually known from the type of the matrix. On master, ```julia julia> f(A) = LinearAlgebra.wrap(A, 'N') f (generic function with 1 method) julia> @code_typed f([1;;]) CodeInfo( 1 ─ %1 = invoke LinearAlgebra.wrap(A::Matrix{Int64}, 'N'::Char)::Union{Adjoint{Int64, Matrix{Int64}}, Hermitian{Int64, Matrix{Int64}}, Symmetric{Int64, Matrix{Int64}}, Transpose{Int64, Matrix{Int64}}, Matrix{Int64}} └── return %1 ) => Union{Adjoint{Int64, Matrix{Int64}}, Hermitian{Int64, Matrix{Int64}}, Symmetric{Int64, Matrix{Int64}}, Transpose{Int64, Matrix{Int64}}, Matrix{Int64}} ``` This PR ```julia julia> @code_typed f([1;;]) CodeInfo( 1 ─ return A ) => Matrix{Int64} ``` (cherry picked from commit 0fd7f72109a8741720650f72ca41b10d95e9e39e) --- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 2 +- stdlib/LinearAlgebra/test/matmul.jl | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index 386de771d666f..85ba1d2770ba7 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -465,7 +465,7 @@ wrapper_char(A::Hermitian) = A.uplo == 'U' ? 'H' : 'h' wrapper_char(A::Hermitian{<:Real}) = A.uplo == 'U' ? 'S' : 's' wrapper_char(A::Symmetric) = A.uplo == 'U' ? 'S' : 's' -function wrap(A::AbstractVecOrMat, tA::AbstractChar) +Base.@constprop :aggressive function wrap(A::AbstractVecOrMat, tA::AbstractChar) if tA == 'N' return A elseif tA == 'T' diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index e6000a4b24e2d..86606654e911a 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -17,6 +17,14 @@ mul_wrappers = [ m -> adjoint(m), m -> transpose(m)] +@testset "wrap" begin + f(A) = LinearAlgebra.wrap(A, 'N') + A = ones(1,1) + @test @inferred(f(A)) === A + g(A) = LinearAlgebra.wrap(A, 'T') + @test @inferred(g(A)) === transpose(A) +end + @testset "matrices with zero dimensions" begin for (dimsA, dimsB, dimsC) in ( ((0, 5), (5, 3), (0, 3)),