Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VectorInterface #65

Merged
merged 61 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
3575759
minimalvec VectorInterface
lkdvos Dec 1, 2022
4efb684
lanczos factorisations
lkdvos Dec 1, 2022
cf4ef9a
revert orthonormal to non-breaking
lkdvos Dec 1, 2022
e606650
small changes
lkdvos Dec 1, 2022
2fcc721
arnoldi working
lkdvos Dec 1, 2022
651fc91
gkl factorizations
lkdvos Dec 1, 2022
4fe184e
cg linsolve
lkdvos Dec 1, 2022
0ae46fe
gmres working
lkdvos Dec 1, 2022
e12c880
working bicgstab
lkdvos Dec 1, 2022
8222705
fix typo
lkdvos Dec 1, 2022
1015060
eigsolve working
lkdvos Dec 1, 2022
8c2d97c
schursolve
lkdvos Dec 1, 2022
70f6eda
geneigsolve
lkdvos Dec 1, 2022
b46c0e4
svdsolve
lkdvos Dec 1, 2022
3458c6a
fix missing conj
lkdvos Dec 1, 2022
17ba6ad
expintegrator, docs and minimalvec
lkdvos Dec 1, 2022
93d1b20
vectypes
lkdvos Dec 1, 2022
a74d1e8
small changes
lkdvos Dec 1, 2022
7ca7656
add compat VectorInterface
lkdvos Dec 1, 2022
a29d80b
formatting/style
lkdvos Dec 1, 2022
04ab151
inplace add!
lkdvos Dec 1, 2022
d00ce95
fix typo
lkdvos Dec 1, 2022
edeb75d
bangbang factorisations
lkdvos Dec 2, 2022
2de745b
bangbang gkl
lkdvos Dec 2, 2022
17f01c7
linsolve bangbang
lkdvos Dec 2, 2022
f66c754
bangbang eigsolve
lkdvos Dec 2, 2022
1ce9db9
bangbang geineigsolve but doesnt work?
lkdvos Dec 7, 2022
5df6d9a
update tests
lkdvos Dec 7, 2022
d024f13
recursivevec and svdsolve and expintegrator
lkdvos Dec 19, 2022
8936c8b
update vectorinterface implementations
lkdvos Dec 23, 2022
55f471f
fix formatting
lkdvos Dec 23, 2022
2693ae2
bangbang expintegrator
lkdvos Dec 23, 2022
7978bb8
remove stray typo
lkdvos Jan 3, 2023
e8f96f0
fix formatter complaint
lkdvos Jan 3, 2023
4eeda28
AD support for linsolve and eigsolve
lkdvos Jan 20, 2023
b48060d
fix dependency
lkdvos Feb 11, 2023
1f6cbd9
Add manual CI trigger
lkdvos May 23, 2023
6f21eea
VectorInterface updates and small fixes
lkdvos Oct 16, 2023
750519d
Merge remote-tracking branch 'origin/master' into vectorinterface
lkdvos Feb 9, 2024
3ad86f6
Formatter
lkdvos Feb 9, 2024
9d70505
Remove duplicate norm definitions
lkdvos Feb 22, 2024
e649abb
Some test updates
lkdvos Feb 22, 2024
9e07794
Explicitly import Testsetup code
lkdvos Feb 22, 2024
e4539ed
Formatter
lkdvos Feb 22, 2024
b783de4
Undo accidental file delete
lkdvos Feb 22, 2024
4d457fa
remove legacy format-check
lkdvos Feb 22, 2024
2d64a66
Remove add! and scale! definitions for outplace minimalvec
lkdvos Feb 22, 2024
35204d6
fix stray undefined variable
lkdvos Feb 22, 2024
dc4cee3
Fix typo
lkdvos Feb 23, 2024
0ab35b5
Replace `dot` with `inner`
lkdvos Feb 23, 2024
41b92e0
change some stray linearalgebra operations
lkdvos Feb 23, 2024
4d90360
Fix some ad rules
lkdvos Feb 23, 2024
689152a
formatter
lkdvos Feb 23, 2024
7aa723f
remove eigsolve AD rules
lkdvos Mar 5, 2024
81d64a1
Rewrite tests
lkdvos Mar 5, 2024
c2c45b0
Add `stack` in tests for VERSION < 1.9
lkdvos Mar 5, 2024
53ccce5
Revert computing AD adjoint operator change
lkdvos Mar 7, 2024
84c96a4
Clean up Givens rotation
lkdvos Mar 7, 2024
a58cced
Address comments
lkdvos Mar 7, 2024
d31b115
Rename `precision` to `tolerance`
lkdvos Mar 7, 2024
7a27157
Refactor tests II
lkdvos Mar 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
*.jl.*.cov
*.jl.mem
.DS_Store
Manifest.toml
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
name = "KrylovKit"
uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
authors = ["Jutho Haegeman"]
version = "0.6.1"
version = "0.7.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"

[compat]
Aqua = "0.6, 0.7, 0.8"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
FiniteDifferences = "0.12"
GPUArraysCore = "0.1"
VectorInterface = "0.4"
LinearAlgebra = "1"
Random = "1"
Printf = "1"
Expand Down
20 changes: 4 additions & 16 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,8 @@ However, KrylovKit.jl distinguishes itself from the previous packages in the fol
2. KrylovKit does not assume that the vectors involved in the problem are actual subtypes
of `AbstractVector`. Any Julia object that behaves as a vector is supported, so in
particular higher-dimensional arrays or any custom user type that supports the
following functions (with `v` and `w` two instances of this type and `α, β` scalars
(i.e. `Number`)):
* `Base.:*(α, v)`: multiply `v` with a scalar `α`, which can be of a different scalar
type; in particular this method is used to create vectors similar to `v` but with a
different type of underlying scalars.
* `Base.similar(v)`: a way to construct vectors which are exactly similar to `v`
* `LinearAlgebra.mul!(w, v, α)`: out of place scalar multiplication; multiply
vector `v` with scalar `α` and store the result in `w`
* `LinearAlgebra.rmul!(v, α)`: in-place scalar multiplication of `v` with `α`; in
particular with `α = false`, `v` is the corresponding zero vector
* `LinearAlgebra.axpy!(α, v, w)`: store in `w` the result of `α*v + w`
* `LinearAlgebra.axpby!(α, v, β, w)`: store in `w` the result of `α*v + β*w`
* `LinearAlgebra.dot(v,w)`: compute the inner product of two vectors
* `LinearAlgebra.norm(v)`: compute the 2-norm of a vector
interface as defined in
[`VectorInterface.jl`](https://github.com/Jutho/VectorInterface.jl)

Algorithms in KrylovKit.jl are tested against such a minimal implementation (named
`MinimalVec`) in the test suite. This type is only defined in the tests. However,
Expand All @@ -84,14 +72,14 @@ However, KrylovKit.jl distinguishes itself from the previous packages in the fol
* [`RecursiveVec`](@ref) can be used for grouping a set of vectors into a single
vector like structure (can be used recursively). This is more robust than trying to
use nested `Vector{<:Vector}` types.
* [`InnerProductVec`](@ref) can be used to redefine the inner product (i.e. `dot`)
* [`InnerProductVec`](@ref) can be used to redefine the inner product (i.e. `inner`)
and corresponding norm (`norm`) of an already existing vector like object. The
latter should help with implementing certain type of preconditioners.

## Current functionality

The following algorithms are currently implemented
* `linsolve`: [`CG`](@ref), [`GMRES`](@ref)
* `linsolve`: [`CG`](@ref), [`GMRES`](@ref), [`BiCGStab`](@ref)
* `eigsolve`: a Krylov-Schur algorithm (i.e. with tick restarts) for extremal eigenvalues
of normal (i.e. not generalized) eigenvalue problems, corresponding to
[`Lanczos`](@ref) for real symmetric or complex hermitian linear maps, and to
Expand Down
4 changes: 2 additions & 2 deletions docs/src/man/implementation.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ KrylovKit.orthonormalize

The expansion coefficients of a general vector in terms of a given orthonormal basis can be obtained as
```@docs
KrylovKit.project!
KrylovKit.project!!
```
whereas the inverse calculation is obtained as
```@docs
KrylovKit.unproject!
KrylovKit.unproject!!
```

An orthonormal basis can be transformed using a rank-1 update using
Expand Down
5 changes: 3 additions & 2 deletions src/KrylovKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ The high level interface of KrylovKit is provided by the following functions:
computes a linear combination of the `ϕⱼ` functions which generalize `ϕ₀(z) = exp(z)`.
"""
module KrylovKit

using VectorInterface
using VectorInterface: add!!
using LinearAlgebra
using Printf
using ChainRulesCore
using GPUArraysCore
const IndexRange = AbstractRange{Int}

export linsolve, eigsolve, geneigsolve, svdsolve, schursolve, exponentiate, expintegrator
export orthogonalize, orthogonalize!, orthonormalize, orthonormalize!
export orthogonalize, orthogonalize!!, orthonormalize, orthonormalize!!
export basis, rayleighquotient, residual, normres, rayleighextension
export initialize, initialize!, expand!, shrink!
export ClassicalGramSchmidt, ClassicalGramSchmidt2, ClassicalGramSchmidtIR
Expand Down
32 changes: 17 additions & 15 deletions src/adrules/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,20 @@
∂self = NoTangent()
∂x₀ = ZeroTangent()
∂algorithm = NoTangent()
(∂b, reverse_info) = linsolve(fᴴ, x̄, (zero(a₀) * zero(a₁)) * x̄, algorithm,
conj(a₀), conj(a₁))
T = VectorInterface.promote_scale(VectorInterface.promote_scale(x̄, a₀),
scalartype(a₁))
∂b, reverse_info = linsolve(fᴴ, x̄, zerovector(x̄, T), algorithm, conj(a₀),
conj(a₁))
if reverse_info.converged == 0
@warn "Linear problem for reverse rule did not converge." reverse_info
end
∂f = @thunk(f_pullback(-conj(a₁) * ∂b)[1])
∂a₀ = @thunk(-dot(x, ∂b))
∂f = @thunk(f_pullback(scale(∂b, -conj(a₁)))[1])
∂a₀ = @thunk(-inner(x, ∂b))
# ∂a₁ = @thunk(-dot(f(x), ∂b))
if a₀ == zero(a₀) && a₁ == one(a₁)
∂a₁ = @thunk(-dot(b, ∂b))
∂a₁ = @thunk(-inner(b, ∂b))

Check warning on line 80 in src/adrules/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/adrules/linsolve.jl#L80

Added line #L80 was not covered by tests
else
∂a₁ = @thunk(-dot((b - a₀ * x) / a₁, ∂b))
∂a₁ = @thunk(-inner(scale!!(add(b, x, -a₀), inv(a₁)), ∂b))
end
return ∂self, ∂f, ∂b, ∂x₀, ∂algorithm, ∂a₀, ∂a₁
end
Expand All @@ -91,17 +93,17 @@
(x, info) = linsolve(A, b, x₀, algorithm, a₀, a₁)

if Δb isa ChainRulesCore.AbstractZero
rhs = zero(b)
rhs = zerovector(b)

Check warning on line 96 in src/adrules/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/adrules/linsolve.jl#L96

Added line #L96 was not covered by tests
else
rhs = (1 - Δa₁) * Δb
rhs = scale(Δb, (1 - Δa₁))

Check warning on line 98 in src/adrules/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/adrules/linsolve.jl#L98

Added line #L98 was not covered by tests
end
if !iszero(Δa₀)
rhs = axpy!(-Δa₀, x, rhs)
rhs = add!!(rhs, x, -Δa₀)

Check warning on line 101 in src/adrules/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/adrules/linsolve.jl#L101

Added line #L101 was not covered by tests
end
if !iszero(ΔA)
rhs = mul!(rhs, ΔA, x, -a₁, true)
end
(Δx, forward_info) = linsolve(A, rhs, zero(rhs), algorithm, a₀, a₁)
(Δx, forward_info) = linsolve(A, rhs, zerovector(rhs), algorithm, a₀, a₁)

Check warning on line 106 in src/adrules/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/adrules/linsolve.jl#L106

Added line #L106 was not covered by tests
if info.converged > 0 && forward_info.converged == 0
@warn "The tangent linear problem did not converge, whereas the primal linear problem did."
end
Expand All @@ -121,17 +123,17 @@
(x, info) = linsolve(f, b, x₀, algorithm, a₀, a₁)

if Δb isa AbstractZero
rhs = false * b
rhs = zerovector(b)

Check warning on line 126 in src/adrules/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/adrules/linsolve.jl#L126

Added line #L126 was not covered by tests
else
rhs = (1 - Δa₁) * Δb
rhs = scale(Δb, (1 - Δa₁))

Check warning on line 128 in src/adrules/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/adrules/linsolve.jl#L128

Added line #L128 was not covered by tests
end
if !iszero(Δa₀)
rhs = axpy!(-Δa₀, x, rhs)
rhs = add!!(rhs, x, -Δa₀)

Check warning on line 131 in src/adrules/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/adrules/linsolve.jl#L131

Added line #L131 was not covered by tests
end
if !(Δf isa AbstractZero)
rhs = axpy!(-a₁, frule_via_ad(config, (Δf, ZeroTangent()), f, x), rhs)
rhs = add!!(rhs, frule_via_ad(config, (Δf, ZeroTangent()), f, x), -a₀)

Check warning on line 134 in src/adrules/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/adrules/linsolve.jl#L134

Added line #L134 was not covered by tests
end
(Δx, forward_info) = linsolve(f, rhs, false * rhs, algorithm, a₀, a₁)
(Δx, forward_info) = linsolve(f, rhs, zerovector(rhs), algorithm, a₀, a₁)

Check warning on line 136 in src/adrules/linsolve.jl

View check run for this annotation

Codecov / codecov/patch

src/adrules/linsolve.jl#L136

Added line #L136 was not covered by tests
if info.converged > 0 && forward_info.converged == 0
@warn "The tangent linear problem did not converge, whereas the primal linear problem did."
end
Expand Down
2 changes: 1 addition & 1 deletion src/apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ apply(f, x) = f(x)
function apply(operator, x, α₀, α₁)
y = apply(operator, x)
if α₀ != zero(α₀) || α₁ != one(α₁)
axpby!(α₀, x, α₁, y)
y = add!!(y, x, α₀, α₁)
end

return y
Expand Down
6 changes: 3 additions & 3 deletions src/dense/givens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ end

function _rmul!(b::OrthonormalBasis, G::Givens)
q1, q2 = b[G.i1], b[G.i2]
q1old = mul!(similar(q1), q1, true)
q1 = axpby!(-conj(G.s), q2, G.c, q1)
q2 = axpby!(G.s, q1old, G.c, q2)
q1′ = add(q1, q2, -conj(G.s), G.c)
q2′ = add!!(q2, q1, G.s, G.c)
b[G.i1], b[G.i2] = q1′, q2
return b
end

Expand Down
6 changes: 3 additions & 3 deletions src/dense/reflector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ function LinearAlgebra.rmul!(b::OrthonormalBasis, H::Householder)
r = H.r
β = H.β
iszero(β) && return b
w = similar(b[first(r)])
w = zerovector(b[first(r)])
@inbounds begin
unproject!(w, b, v, 1, 0, r)
rank1update!(b, w, v, -β, 1, r)
w = unproject!!(w, b, v, 1, 0, r)
b = rank1update!(b, w, v, -β, 1, r)
end
return b
end
6 changes: 3 additions & 3 deletions src/eigsolve/arnoldi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ function schursolve(A, x₀, howmany::Int, which::Selector, alg::Arnoldi)
[B * u for u in cols(U, 1:howmany)]
end
residuals = let r = residual(fact)
[last(u) * r for u in cols(U, 1:howmany)]
[scale(r, last(u)) for u in cols(U, 1:howmany)]
end
normresiduals = [normres(fact) * abs(last(u)) for u in cols(U, 1:howmany)]

Expand Down Expand Up @@ -145,7 +145,7 @@ function eigsolve(A, x₀, howmany::Int, which::Selector, alg::Arnoldi)
[B * v for v in cols(V)]
end
residuals = let r = residual(fact)
[last(v) * r for v in cols(V)]
[scale(r, last(v)) for v in cols(V)]
end
normresiduals = [normres(fact) * abs(last(v)) for v in cols(V)]

Expand Down Expand Up @@ -271,7 +271,7 @@ function _schursolve(A, x₀, howmany::Int, which::Selector, alg::Arnoldi)
B = basis(fact)
basistransform!(B, view(U, :, 1:keep))
r = residual(fact)
B[keep + 1] = rmul!(r, 1 / normres(fact))
B[keep + 1] = scale!!(r, 1 / normres(fact))

# Shrink Arnoldi factorization
fact = shrink!(fact, keep)
Expand Down
Loading
Loading