Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RowNonZero pivoting strategy to lu #44571

Merged
merged 6 commits into from
May 13, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ Standard library changes
system image with other BLAS/LAPACK libraries is not
supported. Instead, it is recommended that the LBT mechanism be used
for swapping BLAS/LAPACK with vendor provided ones. ([#44360])
* `lu` now supports a new pivoting strategy (besides the existing
`RowMaximum()` and `NoPivot()`), that seeks the first non-zero
element among the to be factorized rows. This strategy is chosen via
`lu(A::AbstractMatrix, RowNonZero())`. ([#44571])
dkarrasch marked this conversation as resolved.
Show resolved Hide resolved
* `normalize(x, p=2)` now supports any normed vector space `x`, including scalars ([#44925]).

#### Markdown
Expand Down
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ export
LU,
LDLt,
NoPivot,
RowNonZero,
QR,
QRPivoted,
LQ,
Expand Down Expand Up @@ -173,6 +174,7 @@ struct QRIteration <: Algorithm end

abstract type PivotingStrategy end
struct NoPivot <: PivotingStrategy end
struct RowNonZero <: PivotingStrategy end
struct RowMaximum <: PivotingStrategy end
struct ColumnNorm <: PivotingStrategy end

Expand Down
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ size(F::Transpose{<:Any,<:Factorization}) = reverse(size(parent(F)))

checkpositivedefinite(info) = info == 0 || throw(PosDefException(info))
checknonsingular(info, ::RowMaximum) = info == 0 || throw(SingularException(info))
checknonsingular(info, ::RowNonZero) = info == 0 || throw(SingularException(info))
checknonsingular(info, ::NoPivot) = info == 0 || throw(ZeroPivotException(info))
checknonsingular(info) = checknonsingular(info, RowMaximum())

Expand Down
32 changes: 25 additions & 7 deletions stdlib/LinearAlgebra/src/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ function lu!(A::StridedMatrix{<:BlasFloat}, pivot::NoPivot; check::Bool = true)
return generic_lufact!(A, pivot; check = check)
end

function lu!(A::HermOrSym, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true)
function lu!(A::HermOrSym, pivot::Union{RowMaximum,NoPivot,RowNonZero} = lupivottype(T); check::Bool = true)
copytri!(A.data, A.uplo, isa(A, Hermitian))
lu!(A.data, pivot; check = check)
end
Expand Down Expand Up @@ -132,9 +132,9 @@ Stacktrace:
[...]
```
"""
lu!(A::StridedMatrix, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) =
lu!(A::StridedMatrix, pivot::Union{RowMaximum,NoPivot,RowNonZero} = lupivottype(eltype(A)); check::Bool = true) =
generic_lufact!(A, pivot; check = check)
function generic_lufact!(A::StridedMatrix{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum();
function generic_lufact!(A::StridedMatrix{T}, pivot::Union{RowMaximum,NoPivot,RowNonZero} = lupivottype(T);
check::Bool = true) where {T}
# Extract values
m, n = size(A)
Expand All @@ -156,6 +156,13 @@ function generic_lufact!(A::StridedMatrix{T}, pivot::Union{RowMaximum,NoPivot} =
amax = absi
end
end
elseif pivot === RowNonZero()
for i = k:m
if !iszero(A[i,k])
kp = i
break
end
end
end
ipiv[k] = kp
if !iszero(A[kp,k])
Expand Down Expand Up @@ -206,6 +213,8 @@ function lutype(T::Type)
S = promote_type(T, LT, UT)
end

lupivottype(::Type{T}) where {T} = RowMaximum()

# for all other types we must promote to a type which is stable under division
"""
lu(A, pivot = RowMaximum(); check = true) -> F::LU
Expand All @@ -217,9 +226,18 @@ When `check = false`, responsibility for checking the decomposition's
validity (via [`issuccess`](@ref)) lies with the user.

In most cases, if `A` is a subtype `S` of `AbstractMatrix{T}` with an element
type `T` supporting `+`, `-`, `*` and `/`, the return type is `LU{T,S{T}}`. If
pivoting is chosen (default) the element type should also support [`abs`](@ref) and
[`<`](@ref). Pivoting can be turned off by passing `pivot = NoPivot()`.
type `T` supporting `+`, `-`, `*` and `/`, the return type is `LU{T,S{T}}`.

The following pivoting strategies are supported, and chosen via the optional `pivot`
dkarrasch marked this conversation as resolved.
Show resolved Hide resolved
argument:

* `RowMaximum()` (default): the standard pivoting strategy; the pivot corresponds
to the element of maximum absolute value among the remaining, to be factorized rows.
This pivoting strategy requires the element type to also support [`abs`](@ref) and
[`<`](@ref).
dkarrasch marked this conversation as resolved.
Show resolved Hide resolved
* `RowNonZero()`: the pivot corresponds to the first non-zero element among the remaining,
to be factorized rows.
dkarrasch marked this conversation as resolved.
Show resolved Hide resolved
* `NoPivot()`: pivoting turned off.
dkarrasch marked this conversation as resolved.
Show resolved Hide resolved

The individual components of the factorization `F` can be accessed via [`getproperty`](@ref):

Expand Down Expand Up @@ -275,7 +293,7 @@ julia> l == F.L && u == F.U && p == F.p
true
```
"""
function lu(A::AbstractMatrix{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where {T}
function lu(A::AbstractMatrix{T}, pivot::Union{RowMaximum,NoPivot,RowNonZero} = lupivottype(T); check::Bool = true) where {T}
lu!(_lucopy(A, lutype(T)), pivot; check = check)
end
# TODO: remove for Julia v2.0
Expand Down
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -452,18 +452,28 @@ Base.adjoint(a::ModInt{n}) where {n} = ModInt{n}(conj(a))
Base.transpose(a::ModInt{n}) where {n} = a # see Issue 20978
LinearAlgebra.Adjoint(a::ModInt{n}) where {n} = adjoint(a)
LinearAlgebra.Transpose(a::ModInt{n}) where {n} = transpose(a)
LinearAlgebra.lupivottype(::Type{ModInt{n}}) where {n} = RowNonZero()

@testset "Issue 22042" begin
A = [ModInt{2}(1) ModInt{2}(0); ModInt{2}(1) ModInt{2}(1)]
b = [ModInt{2}(1), ModInt{2}(0)]

@test A*(A\b) == b
@test A*(lu(A)\b) == b
@test A*(lu(A, NoPivot())\b) == b
@test A*(lu(A, RowNonZero())\b) == b
@test_throws MethodError lu(A, RowMaximum())

# Needed for pivoting:
Base.abs(a::ModInt{n}) where {n} = a
Base.:<(a::ModInt{n}, b::ModInt{n}) where {n} = a.k < b.k
@test A*(lu(A, RowMaximum())\b) == b

A = [ModInt{2}(0) ModInt{2}(1); ModInt{2}(1) ModInt{2}(1)]
@test A*(A\b) == b
@test A*(lu(A)\b) == b
@test A*(lu(A, RowMaximum())\b) == b
@test A*(lu(A, RowNonZero())\b) == b
end

@testset "Issue 18742" begin
Expand Down