Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lu decomposition for Tensors #94

Merged
merged 10 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Muscle = "21fe5c4b-a943-414d-bf3e-516f24900631"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
ValSplit = "0625e100-946b-11ec-09cd-6328dd093154"

Expand Down
8 changes: 8 additions & 0 deletions docs/src/tensors.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,11 @@ length(Tᵢⱼₖ)
```@docs
Tenet.contract(::Tensor, ::Tensor)
```

### Factorizations

```@docs
LinearAlgebra.svd(::Tensor)
LinearAlgebra.qr(::Tensor)
LinearAlgebra.lu(::Tensor)
```
160 changes: 107 additions & 53 deletions src/Numerics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using OMEinsum
using LinearAlgebra
using UUIDs: uuid4
using SparseArrays

# TODO test array container typevar on output
for op in [
Expand Down Expand Up @@ -79,89 +80,142 @@
Base.:*(a::T, b::Number) where {T<:Tensor} = T(parent(a) * b, inds(a))
Base.:*(a::Number, b::T) where {T<:Tensor} = T(a * parent(b), inds(b))

function factorinds(tensor, left_inds, right_inds)
isdisjoint(left_inds, right_inds) ||
throw(ArgumentError("left ($left_inds) and right $(right_inds) indices must be disjoint"))

left_inds, right_inds =
isempty(left_inds) ? (setdiff(inds(tensor), right_inds), right_inds) :
isempty(right_inds) ? (left_inds, setdiff(inds(tensor), left_inds)) :
throw(ArgumentError("cannot set both left and right indices"))

all(!isempty, (left_inds, right_inds)) || throw(ArgumentError("no right-indices left in factorization"))
all(∈(inds(tensor)), left_inds ∪ right_inds) || throw(ArgumentError("indices must be in $(inds(tensor))"))

return left_inds, right_inds
end

LinearAlgebra.svd(t::Tensor{<:Any,2}; kwargs...) = Base.@invoke svd(t::Tensor; left_inds = (first(inds(t)),), kwargs...)

function LinearAlgebra.svd(t::Tensor; left_inds, kwargs...)
if isempty(left_inds)
throw(ErrorException("no left-indices in SVD factorization"))
elseif any(∉(inds(t)), left_inds)
# TODO better error exception and checks
throw(ErrorException("all left-indices must be in $(inds(t))"))
end
"""
LinearAlgebra.svd(tensor::Tensor; left_inds, right_inds, virtualind, kwargs...)

Perform SVD factorization on a tensor.

# Keyword arguments

right_inds = setdiff(inds(t), left_inds)
if isempty(right_inds)
# TODO better error exception and checks
throw(ErrorException("no right-indices in SVD factorization"))
end
- `left_inds`: left indices to be used in the SVD factorization. Defaults to all indices of `t` except `right_inds`.
- `right_inds`: right indices to be used in the SVD factorization. Defaults to all indices of `t` except `left_inds`.
- `virtualind`: name of the virtual bond. Defaults to a random `Symbol`.
"""
function LinearAlgebra.svd(tensor::Tensor; left_inds = (), right_inds = (), virtualind = Symbol(uuid4()), kwargs...)
left_inds, right_inds = factorinds(tensor, left_inds, right_inds)

virtualind ∉ inds(tensor) ||
throw(ArgumentError("new virtual bond name ($virtualind) cannot be already be present"))

# permute array
tensor = permutedims(t, (left_inds..., right_inds...))
data = reshape(parent(tensor), prod(i -> size(t, i), left_inds), prod(i -> size(t, i), right_inds))
left_sizes = map(Base.Fix1(size, tensor), left_inds)
right_sizes = map(Base.Fix1(size, tensor), right_inds)
tensor = permutedims(tensor, [left_inds..., right_inds...])
data = reshape(parent(tensor), prod(left_sizes), prod(right_sizes))

# compute SVD
U, s, V = svd(data; kwargs...)

# tensorify results
U = reshape(U, ([size(t, ind) for ind in left_inds]..., size(U, 2)))
s = Diagonal(s)
Vt = reshape(V', (size(V', 1), [size(t, ind) for ind in right_inds]...))

vlind = Symbol(uuid4())
vrind = Symbol(uuid4())

U = Tensor(U, (left_inds..., vlind))
s = Tensor(s, (vlind, vrind))
Vt = Tensor(Vt, (vrind, right_inds...))
U = Tensor(reshape(U, left_sizes..., size(U, 2)), [left_inds..., virtualind])
s = Tensor(s, [virtualind])
Vt = Tensor(reshape(V, right_sizes..., size(V, 2)), [right_inds..., virtualind])

return U, s, Vt
end

LinearAlgebra.qr(t::Tensor{<:Any,2}; kwargs...) = Base.@invoke qr(t::Tensor; left_inds = (first(inds(t)),), kwargs...)

"""
LinearAlgebra.qr(t::Tensor, mode::Symbol = :reduced; left_inds = (), right_inds = (), virtualind::Symbol = Symbol(uuid4()), kwargs...
LinearAlgebra.qr(tensor::Tensor; left_inds, right_inds, virtualind, kwargs...)

Perform QR factorization on a tensor.

# Arguments

- `t::Tensor`: tensor to be factorized

# Keyword Arguments
# Keyword arguments

- `left_inds`: left indices to be used in the QR factorization. Defaults to all indices of `t` except `right_inds`.
- `right_inds`: right indices to be used in the QR factorization. Defaults to all indices of `t` except `left_inds`.
- `virtualind`: name of the virtual bond. Defaults to a random `Symbol`.
- `left_inds`: left indices to be used in the QR factorization. Defaults to all indices of `t` except `right_inds`.
- `right_inds`: right indices to be used in the QR factorization. Defaults to all indices of `t` except `left_inds`.
- `virtualind`: name of the virtual bond. Defaults to a random `Symbol`.
"""
function LinearAlgebra.qr(t::Tensor; left_inds = (), right_inds = (), virtualind::Symbol = Symbol(uuid4()), kwargs...)
isdisjoint(left_inds, right_inds) ||
throw(ArgumentError("left ($left_inds) and right $(right_inds) indices must be disjoint"))

left_inds, right_inds =
isempty(left_inds) ? (setdiff(inds(t), right_inds), right_inds) :
isempty(right_inds) ? (left_inds, setdiff(inds(t), left_inds)) :
throw(ArgumentError("cannot set both left and right indices"))

all(!isempty, (left_inds, right_inds)) || throw(ArgumentError("no right-indices left in QR factorization"))
all(∈(inds(t)), left_inds ∪ right_inds) || throw(ArgumentError("indices must be in $(inds(t))"))

virtualind ∉ inds(t) || throw(ArgumentError("new virtual bond name ($virtualind) cannot be already be present"))
function LinearAlgebra.qr(
tensor::Tensor;
left_inds = (),
right_inds = (),
virtualind::Symbol = Symbol(uuid4()),
kwargs...,
)
left_inds, right_inds = factorinds(tensor, left_inds, right_inds)

virtualind ∉ inds(tensor) ||
throw(ArgumentError("new virtual bond name ($virtualind) cannot be already be present"))

# permute array
tensor = permutedims(t, (left_inds..., right_inds...))
data = reshape(parent(tensor), prod(i -> size(t, i), left_inds), prod(i -> size(t, i), right_inds))
left_sizes = map(Base.Fix1(size, tensor), left_inds)
right_sizes = map(Base.Fix1(size, tensor), right_inds)
tensor = permutedims(tensor, [left_inds..., right_inds...])
data = reshape(parent(tensor), prod(left_sizes), prod(right_sizes))

# compute QR
F = qr(data; kwargs...)
Q, R = Matrix(F.Q), Matrix(F.R)

# tensorify results
Q = reshape(Q, ([size(t, ind) for ind in left_inds]..., size(Q, 2)))
R = reshape(R, (size(R, 1), [size(t, ind) for ind in right_inds]...))

Q = Tensor(Q, (left_inds..., virtualind))
R = Tensor(R, (virtualind, right_inds...))
Q = Tensor(reshape(Q, left_sizes..., size(Q, 2)), [left_inds..., virtualind])
R = Tensor(reshape(R, size(R, 1), right_sizes...), [virtualind, right_inds...])

return Q, R
end

LinearAlgebra.lu(t::Tensor{<:Any,2}; kwargs...) = Base.@invoke lu(t::Tensor; left_inds = (first(inds(t)),), kwargs...)

Check warning on line 176 in src/Numerics.jl

View check run for this annotation

Codecov / codecov/patch

src/Numerics.jl#L176

Added line #L176 was not covered by tests

"""
LinearAlgebra.lu(tensor::Tensor; left_inds, right_inds, virtualind, kwargs...)

Perform LU factorization on a tensor.

# Keyword arguments

- `left_inds`: left indices to be used in the LU factorization. Defaults to all indices of `t` except `right_inds`.
- `right_inds`: right indices to be used in the LU factorization. Defaults to all indices of `t` except `left_inds`.
- `virtualind`: name of the virtual bond. Defaults to a random `Symbol`.
"""
function LinearAlgebra.lu(
tensor::Tensor;
left_inds = (),
right_inds = (),
virtualind = [Symbol(uuid4()), Symbol(uuid4())],
kwargs...,
)
left_inds, right_inds = factorinds(tensor, left_inds, right_inds)

i_pl, i_lu = virtualind
i_pl ∉ inds(tensor) || throw(ArgumentError("new virtual bond name ($i_pl) cannot be already be present"))
i_lu ∉ inds(tensor) || throw(ArgumentError("new virtual bond name ($i_lu) cannot be already be present"))

# permute array
left_sizes = map(Base.Fix1(size, tensor), left_inds)
right_sizes = map(Base.Fix1(size, tensor), right_inds)
tensor = permutedims(tensor, [left_inds..., right_inds...])
data = reshape(parent(tensor), prod(left_sizes), prod(right_sizes))

# compute LU
info = lu(data; kwargs...)
L = info.L
U = info.U

permutator = info.p
P = sparse(permutator, 1:length(permutator), fill(true, length(permutator)))

L = Tensor(L, [i_pl, i_lu])
U = Tensor(reshape(U, size(U, 1), right_sizes...), [i_lu, right_inds...])
P = Tensor(reshape(P, left_sizes..., size(L, 1)), [left_inds..., i_pl])

return L, U, P
end
2 changes: 1 addition & 1 deletion src/Tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ Base.selectdim(t::Tensor, d::Symbol, i) = selectdim(t, dim(t, d), i)
Base.permutedims(t::Tensor, perm) = Tensor(permutedims(parent(t), perm), getindex.((inds(t),), perm))
Base.permutedims!(dest::Tensor, src::Tensor, perm) = permutedims!(parent(dest), parent(src), perm)

function Base.permutedims(t::Tensor{T,N}, perm::NTuple{N,Symbol}) where {T,N}
function Base.permutedims(t::Tensor{T}, perm::Base.AbstractVecOrTuple{Symbol}) where {T}
perm = map(i -> findfirst(==(i), inds(t)), perm)
permutedims(t, perm)
end
Expand Down
15 changes: 8 additions & 7 deletions src/Transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,20 +249,21 @@ function transform!(tn::AbstractTensorNetwork, config::SplitSimplification)
bipartitions = Iterators.flatten(combinations(inds, r) for r in 1:(length(inds)-1))
for bipartition in bipartitions
left_inds = collect(bipartition)
right_inds = setdiff(inds, left_inds)

# perform an SVD across the bipartition
u, s, v = svd(tensor; left_inds = left_inds)
rank_s = sum(diag(s) .> config.atol)
rank_s = sum(s .> config.atol)

if rank_s < length(s)
hyperindex = only(Tenet.inds(s))

if rank_s < size(s, 1)
# truncate data
u = view(u, Tenet.inds(s)[1] => 1:rank_s)
s = view(s, (idx -> idx => 1:rank_s).(Tenet.inds(s))...)
v = view(v, Tenet.inds(s)[2] => 1:rank_s)
u = view(u, hyperindex => 1:rank_s)
s = view(s, hyperindex => 1:rank_s)
v = view(v, hyperindex => 1:rank_s)

# replace the original tensor with factorization
tensor_l = u * s
tensor_l = contract(u, s, dims = Symbol[])
tensor_r = v

push!(tn, dropdims(tensor_l))
Expand Down
Loading
Loading