From 686804d2c9f94e8b51de56320dabfaf6c630c17e Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 30 Jul 2024 13:29:24 +0000 Subject: [PATCH] LAPACK: Aggressive constprop to concretely infer syev!/syevd! (#55295) Currently, these are inferred as a 2-Tuple of possible return types depending on `jobz`, but since `jobz` is usually a constant, we may propagate it aggressively and have the return types inferred concretely. --- stdlib/LinearAlgebra/src/lapack.jl | 8 ++++---- stdlib/LinearAlgebra/test/lapack.jl | 10 ++++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/stdlib/LinearAlgebra/src/lapack.jl b/stdlib/LinearAlgebra/src/lapack.jl index e9cfacfcd0cfd..6d1d871ed85fd 100644 --- a/stdlib/LinearAlgebra/src/lapack.jl +++ b/stdlib/LinearAlgebra/src/lapack.jl @@ -5329,7 +5329,7 @@ for (syev, syevr, syevd, sygvd, elty) in # INTEGER INFO, LDA, LWORK, N # * .. Array Arguments .. # DOUBLE PRECISION A( LDA, * ), W( * ), WORK( * ) - function syev!(jobz::AbstractChar, uplo::AbstractChar, A::AbstractMatrix{$elty}) + Base.@constprop :aggressive function syev!(jobz::AbstractChar, uplo::AbstractChar, A::AbstractMatrix{$elty}) require_one_based_indexing(A) @chkvalidparam 1 jobz ('N', 'V') chkuplo(uplo) @@ -5429,7 +5429,7 @@ for (syev, syevr, syevd, sygvd, elty) in # * .. Array Arguments .. # INTEGER IWORK( * ) # DOUBLE PRECISION A( LDA, * ), W( * ), WORK( * ) - function syevd!(jobz::AbstractChar, uplo::AbstractChar, A::AbstractMatrix{$elty}) + Base.@constprop :aggressive function syevd!(jobz::AbstractChar, uplo::AbstractChar, A::AbstractMatrix{$elty}) require_one_based_indexing(A) @chkvalidparam 1 jobz ('N', 'V') chkstride1(A) @@ -5526,7 +5526,7 @@ for (syev, syevr, syevd, sygvd, elty, relty) in # * .. Array Arguments .. # DOUBLE PRECISION RWORK( * ), W( * ) # COMPLEX*16 A( LDA, * ), WORK( * ) - function syev!(jobz::AbstractChar, uplo::AbstractChar, A::AbstractMatrix{$elty}) + Base.@constprop :aggressive function syev!(jobz::AbstractChar, uplo::AbstractChar, A::AbstractMatrix{$elty}) require_one_based_indexing(A) @chkvalidparam 1 jobz ('N', 'V') chkstride1(A) @@ -5639,7 +5639,7 @@ for (syev, syevr, syevd, sygvd, elty, relty) in # INTEGER IWORK( * ) # DOUBLE PRECISION RWORK( * ) # COMPLEX*16 A( LDA, * ), WORK( * ) - function syevd!(jobz::AbstractChar, uplo::AbstractChar, A::AbstractMatrix{$elty}) + Base.@constprop :aggressive function syevd!(jobz::AbstractChar, uplo::AbstractChar, A::AbstractMatrix{$elty}) require_one_based_indexing(A) @chkvalidparam 1 jobz ('N', 'V') chkstride1(A) diff --git a/stdlib/LinearAlgebra/test/lapack.jl b/stdlib/LinearAlgebra/test/lapack.jl index fd14dad4634a8..f05d7d99c2437 100644 --- a/stdlib/LinearAlgebra/test/lapack.jl +++ b/stdlib/LinearAlgebra/test/lapack.jl @@ -889,4 +889,14 @@ end @test UpperTriangular(A) == UpperTriangular(B) end +@testset "inference in syev!/syevd!" begin + for T in (Float32, Float64), CT in (T, Complex{T}) + A = rand(CT, 4,4) + @inferred (A -> LAPACK.syev!('N', 'U', A))(A) + @inferred (A -> LAPACK.syev!('V', 'U', A))(A) + @inferred (A -> LAPACK.syevd!('N', 'U', A))(A) + @inferred (A -> LAPACK.syevd!('V', 'U', A))(A) + end +end + end # module TestLAPACK