From cdbe858a3ea028e47c1409851a85a33f20bf0807 Mon Sep 17 00:00:00 2001 From: Claus Fieker Date: Fri, 13 Dec 2024 10:06:11 +0100 Subject: [PATCH] sanitize the solve_triu and sync with AA (#1958) Co-authored-by: Max Horn Co-authored-by: Tommy Hofmann --- src/flint/fmpz_mat.jl | 18 +++++++++++++++--- test/flint/fmpz_mat-test.jl | 4 ++-- test/generic/Matrix-test.jl | 2 +- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/flint/fmpz_mat.jl b/src/flint/fmpz_mat.jl index dda2c543e..f94aa3c95 100644 --- a/src/flint/fmpz_mat.jl +++ b/src/flint/fmpz_mat.jl @@ -1554,7 +1554,7 @@ function _solve_dixon(a::ZZMatrix, b::ZZMatrix) end #XU = B. only the upper triangular part of U is used -function _solve_triu_left(U::ZZMatrix, b::ZZMatrix) +function AbstractAlgebra._solve_triu_left(U::ZZMatrix, b::ZZMatrix) n = ncols(U) m = nrows(b) R = base_ring(U) @@ -1595,8 +1595,14 @@ function _solve_triu_left(U::ZZMatrix, b::ZZMatrix) return X end -#UX = B -function _solve_triu(U::ZZMatrix, b::ZZMatrix) +#UX = B, U has to be upper triangular +#I think due to the Strassen calling path, where Strasse.solve(side = :left) +#call directly AA.solve_left, this has to be in AA and cannot be independent. +function AbstractAlgebra._solve_triu(U::ZZMatrix, b::ZZMatrix; side::Symbol=:left) + if side == :left + return AbstractAlgebra._solve_triu_left(U, b) + end + @assert side == :right n = nrows(U) m = ncols(b) X = zero(b) @@ -1638,6 +1644,12 @@ function _solve_triu(U::ZZMatrix, b::ZZMatrix) return X end +#solves Ax = B for A lower triagular. if f != 0 (f is true), the diagonal +#is assumed to be 1 and not actually used. +#the upper part of A is not used/ touched. +#one cannot assert is_lower_triangular as this is used for the inplace +#lu decomposition where the matrix is full, encoding an upper triangular +#using the diagonal and a lower triangular with trivial diagonal function AbstractAlgebra._solve_tril!(A::ZZMatrix, B::ZZMatrix, C::ZZMatrix, f::Int = 0) # a x u ax = u diff --git a/test/flint/fmpz_mat-test.jl b/test/flint/fmpz_mat-test.jl index fcdd232dd..f361bda45 100644 --- a/test/flint/fmpz_mat-test.jl +++ b/test/flint/fmpz_mat-test.jl @@ -710,9 +710,9 @@ end @test AbstractAlgebra.Solve.matrix_normal_form_type(A) === AbstractAlgebra.Solve.HermiteFormTrait() b = matrix(ZZ, 1, 2, [1, 6]) - @test Nemo._solve_triu_left(A, b) == matrix(ZZ, 1, 2, [1, 1]) + @test AbstractAlgebra._solve_triu_left(A, b) == matrix(ZZ, 1, 2, [1, 1]) b = matrix(ZZ, 2, 1, [3, 4]) - @test Nemo._solve_triu(A, b) == matrix(ZZ, 2, 1, [1, 1]) + @test AbstractAlgebra._solve_triu(A, b; side = :right) == matrix(ZZ, 2, 1, [1, 1]) b = matrix(ZZ, 2, 1, [1, 7]) c = similar(b) AbstractAlgebra._solve_tril!(c, A, b) diff --git a/test/generic/Matrix-test.jl b/test/generic/Matrix-test.jl index 3acf4234b..28a53bf83 100644 --- a/test/generic/Matrix-test.jl +++ b/test/generic/Matrix-test.jl @@ -135,7 +135,7 @@ end M = randmat_triu(S, -100:100) b = rand(U, -100:100) - x = AbstractAlgebra._solve_triu(M, b, false) + x = AbstractAlgebra._solve_triu_right(M, b; unipotent = false) @test M*x == b end