Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define svd #20

Merged
merged 2 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NamedDimsArrays"
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.3.7"
version = "0.3.8"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
91 changes: 77 additions & 14 deletions src/tensoralgebra.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using LinearAlgebra: LinearAlgebra, qr
using TensorAlgebra: TensorAlgebra, blockedperm, contract, contract!, fusedims, splitdims
using LinearAlgebra: LinearAlgebra
using TensorAlgebra:
TensorAlgebra, blockedperm, contract, contract!, fusedims, qr, splitdims, svd
using TensorAlgebra.BaseExtensions: BaseExtensions

function TensorAlgebra.contract!(
Expand Down Expand Up @@ -35,6 +36,22 @@
return contract(a1, a2)
end

# Left associative fold/reduction.
# Circumvent Base definitions:
# ```julia
# *(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix)
# *(A::AbstractMatrix, B::AbstractMatrix, C::AbstractMatrix, D::AbstractMatrix)
# ```
# that optimize matrix multiplication sequence.
function Base.:*(

Check warning on line 46 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L46

Added line #L46 was not covered by tests
a1::AbstractNamedDimsArray,
a2::AbstractNamedDimsArray,
a3::AbstractNamedDimsArray,
a_rest::AbstractNamedDimsArray...,
)
return *(*(a1, a2), a3, a_rest...)

Check warning on line 52 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L52

Added line #L52 was not covered by tests
end

function LinearAlgebra.mul!(
a_dest::AbstractNamedDimsArray,
a1::AbstractNamedDimsArray,
Expand Down Expand Up @@ -99,32 +116,78 @@
return nameddims(a_split, names_split)
end

function LinearAlgebra.qr(
function TensorAlgebra.qr(

Check warning on line 119 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L119

Added line #L119 was not covered by tests
a::AbstractNamedDimsArray,
nameddimsindices_codomain,
nameddimsindices_domain;
positive=nothing,
)
@assert isnothing(positive) || !positive
# TODO: This should be `TensorAlgebra.qr` rather than overloading `LinearAlgebra.qr`.
# TODO: Don't require wrapping in `Tuple`.
q, r = qr(
q_unnamed, r_unnamed = qr(

Check warning on line 126 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L126

Added line #L126 was not covered by tests
unname(a),
Tuple(nameddimsindices(a)),
Tuple(to_nameddimsindices(a, nameddimsindices_codomain)),
Tuple(to_nameddimsindices(a, nameddimsindices_domain)),
nameddimsindices(a),
to_nameddimsindices(a, nameddimsindices_codomain),
to_nameddimsindices(a, nameddimsindices_domain),
)
name_q = randname(dimnames(a, 1))
name_r = name_q
namedindices_q = named(last(axes(q_unnamed)), name_q)
namedindices_r = named(first(axes(r_unnamed)), name_r)
nameddimsindices_q = (

Check warning on line 136 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L132-L136

Added lines #L132 - L136 were not covered by tests
to_nameddimsindices(a, nameddimsindices_codomain)..., namedindices_q
)
name_qr = randname(nameddimsindices(a)[1])
nameddimsindices_q = (to_nameddimsindices(a, nameddimsindices_codomain)..., name_qr)
nameddimsindices_r = (name_qr, to_nameddimsindices(a, nameddimsindices_domain)...)
return nameddims(q, nameddimsindices_q), nameddims(r, nameddimsindices_r)
nameddimsindices_r = (namedindices_r, to_nameddimsindices(a, nameddimsindices_domain)...)
q = nameddims(q_unnamed, nameddimsindices_q)
r = nameddims(r_unnamed, nameddimsindices_r)
return q, r

Check warning on line 142 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L139-L142

Added lines #L139 - L142 were not covered by tests
end

function LinearAlgebra.qr(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...)
function TensorAlgebra.qr(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...)

Check warning on line 145 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L145

Added line #L145 was not covered by tests
return qr(
a,
nameddimsindices_codomain,
setdiff(nameddimsindices(a), to_nameddimsindices(a, nameddimsindices_codomain));
kwargs...,
)
end

function LinearAlgebra.qr(a::AbstractNamedDimsArray, args...; kwargs...)
return TensorAlgebra.qr(a, args...; kwargs...)

Check warning on line 155 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L154-L155

Added lines #L154 - L155 were not covered by tests
end

function TensorAlgebra.svd(

Check warning on line 158 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L158

Added line #L158 was not covered by tests
a::AbstractNamedDimsArray, nameddimsindices_codomain, nameddimsindices_domain
)
u_unnamed, s_unnamed, v_unnamed = svd(

Check warning on line 161 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L161

Added line #L161 was not covered by tests
unname(a),
nameddimsindices(a),
to_nameddimsindices(a, nameddimsindices_codomain),
to_nameddimsindices(a, nameddimsindices_domain),
)
name_u = randname(dimnames(a, 1))
name_v = randname(dimnames(a, 1))
namedindices_u = named(last(axes(u_unnamed)), name_u)
namedindices_v = named(first(axes(v_unnamed)), name_v)
nameddimsindices_u = (

Check warning on line 171 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L167-L171

Added lines #L167 - L171 were not covered by tests
to_nameddimsindices(a, nameddimsindices_codomain)..., namedindices_u
)
nameddimsindices_s = (namedindices_u, namedindices_v)
nameddimsindices_v = (namedindices_v, to_nameddimsindices(a, nameddimsindices_domain)...)
u = nameddims(u_unnamed, nameddimsindices_u)
s = nameddims(s_unnamed, nameddimsindices_s)
v = nameddims(v_unnamed, nameddimsindices_v)
return u, s, v

Check warning on line 179 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L174-L179

Added lines #L174 - L179 were not covered by tests
end

function TensorAlgebra.svd(a::AbstractNamedDimsArray, nameddimsindices_codomain; kwargs...)
return svd(

Check warning on line 183 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L182-L183

Added lines #L182 - L183 were not covered by tests
a,
nameddimsindices_codomain,
setdiff(nameddimsindices(a), to_nameddimsindices(a, nameddimsindices_codomain));
kwargs...,
)
end

function LinearAlgebra.svd(a::AbstractNamedDimsArray, args...; kwargs...)
return TensorAlgebra.svd(a, args...; kwargs...)

Check warning on line 192 in src/tensoralgebra.jl

View check run for this annotation

Codecov / codecov/patch

src/tensoralgebra.jl#L191-L192

Added lines #L191 - L192 were not covered by tests
end
29 changes: 22 additions & 7 deletions test/basics/test_tensoralgebra.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using LinearAlgebra: qr
using LinearAlgebra: qr, svd
using NamedDimsArrays: namedoneto, dename
using TensorAlgebra: TensorAlgebra, contract, fusedims, splitdims
using Test: @test, @testset, @test_broken
Expand Down Expand Up @@ -47,15 +47,30 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
dims = (2, 2, 2, 2)
i, j, k, l = namedoneto.(dims, ("i", "j", "k", "l"))

na = randn(elt, i, j)
a = randn(elt, i, j)
# TODO: Should this be allowed?
# TODO: Add support for specifying new name.
q, r = qr(na, (i,))
@test q * r ≈ na
q, r = qr(a, (i,))
@test q * r ≈ a

na = randn(elt, i, j, k, l)
a = randn(elt, i, j, k, l)
# TODO: Add support for specifying new name.
q, r = qr(a, (i, k), (j, l))
@test q * r ≈ a
end
@testset "svd" begin
dims = (2, 2, 2, 2)
i, j, k, l = namedoneto.(dims, ("i", "j", "k", "l"))

a = randn(elt, i, j)
# TODO: Should this be allowed?
# TODO: Add support for specifying new name.
u, s, v = svd(a, (i,))
@test u * s * v ≈ a

a = randn(elt, i, j, k, l)
# TODO: Add support for specifying new name.
q, r = qr(na, (i, k), (j, l))
@test contract(q, r) ≈ na
u, s, v = svd(a, (i, k), (j, l))
@test u * s * v ≈ a
end
end
Loading