Skip to content

Commit

Permalink
Move FiniteDifferences support to package extension (#132)
Browse files Browse the repository at this point in the history
* Move FiniteDifferences support to package extension

* Refactor implementation
  • Loading branch information
lkdvos authored Jun 24, 2024
1 parent b026cf2 commit 9787d05
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 26 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"

[extensions]
TensorKitChainRulesCoreExt = "ChainRulesCore"
TensorKitFiniteDifferencesExt = "FiniteDifferences"

[compat]
Aqua = "0.6, 0.7, 0.8"
Expand All @@ -27,8 +29,8 @@ ChainRulesTestUtils = "1"
Combinatorics = "1"
FiniteDifferences = "0.12"
HalfIntegers = "1"
LinearAlgebra = "1"
LRUCache = "1.0.2"
LinearAlgebra = "1"
PackageExtensionCompat = "1"
Random = "1"
Strided = "2"
Expand Down
30 changes: 30 additions & 0 deletions ext/TensorKitFiniteDifferencesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
module TensorKitFiniteDifferencesExt

using TensorKit
using TensorKit: sqrtdim, isqrtdim
using VectorInterface: scale!
using FiniteDifferences

function FiniteDifferences.to_vec(t::T) where {T<:TensorKit.TrivialTensorMap}
vec, from_vec = to_vec(t.data)
return vec, x -> T(from_vec(x), codomain(t), domain(t))
end
function FiniteDifferences.to_vec(t::AbstractTensorMap)
# convert to vector of vectors to make use of existing functionality
vec_of_vecs = [b * sqrtdim(c) for (c, b) in blocks(t)]
vec, back = FiniteDifferences.to_vec(vec_of_vecs)

function from_vec(x)
t′ = similar(t)
xvec_of_vecs = back(x)
for (i, (c, b)) in enumerate(blocks(t′))
scale!(b, xvec_of_vecs[i], isqrtdim(c))
end
return t′
end

return vec, from_vec
end
FiniteDifferences.to_vec(t::TensorKit.AdjointTensorMap) = to_vec(copy(t))

end
25 changes: 0 additions & 25 deletions test/ad.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using ChainRulesCore
using ChainRulesTestUtils
using Random
using FiniteDifferences
using LinearAlgebra

const _repartition = @static if isdefined(Base, :get_extension)
Expand All @@ -21,30 +20,6 @@ function ChainRulesTestUtils.test_approx(actual::AbstractTensorMap,
ChainRulesTestUtils.@test_msg msg isapprox(b, block(expected, c); kwargs...)
end
end
function FiniteDifferences.to_vec(t::T) where {T<:TensorKit.TrivialTensorMap}
vec, from_vec = to_vec(t.data)
return vec, x -> T(from_vec(x), codomain(t), domain(t))
end
function FiniteDifferences.to_vec(t::AbstractTensorMap)
vec = mapreduce(vcat, blocks(t); init=scalartype(t)[]) do (c, b)
return reshape(b, :) .* sqrt(dim(c))
end
vec_real = scalartype(t) <: Real ? vec : collect(reinterpret(real(scalartype(t)), vec))

function from_vec(x_real)
x = scalartype(t) <: Real ? x_real : reinterpret(scalartype(t), x_real)
t′ = similar(t)
ctr = 0
for (c, b) in blocks(t′)
n = length(b)
copyto!(b, reshape(view(x, ctr .+ (1:n)), size(b)) ./ sqrt(dim(c)))
ctr += n
end
return t′
end
return vec_real, from_vec
end
FiniteDifferences.to_vec(t::TensorKit.AdjointTensorMap) = to_vec(copy(t))

function _randomize!(a::TensorMap)
for b in values(blocks(a))
Expand Down

0 comments on commit 9787d05

Please sign in to comment.