Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ForwardDiff for lattice strain DFPT response #1054

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
6 changes: 4 additions & 2 deletions src/terms/Hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@ function LinearAlgebra.mul!(Hψ, H::Hamiltonian, ψ)
end
end
# need `deepcopy` here to copy the elements of the array of arrays ψ (not just pointers)
Base.:*(H::Hamiltonian, ψ) = mul!(deepcopy(ψ), H, ψ)
function Base.:*(H::Hamiltonian, ψ)
result = ψ * one(eltype(H.basis)) # Includes type promotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not the same thing as a deepcopy, be careful !

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is allocating new memory recursively since the multiplication-by-scalar is broadcasted automatically for abstract arrays. I now expanded the comment to clarify the intent of allocation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for my edification, what was wrong with the deepcopy?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The deepcopy was missing a type promotion to be able to hold the result, for the case when the Hamiltonian contains Duals but psi is Float64

mul!(result, H, ψ)
end

# Loop through bands, IFFT to get ψ in real space, loop through terms, FFT and accumulate into Hψ
# For the common DftHamiltonianBlock there is an optimized version below
Expand Down
27 changes: 14 additions & 13 deletions src/workarounds/forwarddiff_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,35 +157,32 @@
basis_primal = construct_value(basis_dual)
scfres = self_consistent_field(basis_primal; kwargs...)

## Compute external perturbation (contained in ham_dual) and from matvec with bands
# Compute explicit density perturbation (including strain)
ρ_basis = compute_density(basis_dual, scfres.ψ, scfres.occupation)

Check warning on line 161 in src/workarounds/forwarddiff_rules.jl

View check run for this annotation

Codecov / codecov/patch

src/workarounds/forwarddiff_rules.jl#L161

Added line #L161 was not covered by tests

# Compute external perturbation (contained in ham_dual)
Hψ_dual = let
occupation_dual = [T.(occk) for occk in scfres.occupation]
ψ_dual = [Complex.(T.(real(ψk)), T.(imag(ψk))) for ψk in scfres.ψ]
ρ_dual = compute_density(basis_dual, ψ_dual, occupation_dual)
εF_dual = T(scfres.εF) # Only needed for entropy term
eigenvalues_dual = [T.(εk) for εk in scfres.eigenvalues]
ham_dual = energy_hamiltonian(basis_dual, ψ_dual, occupation_dual;
ρ=ρ_dual, eigenvalues=eigenvalues_dual,
εF=εF_dual).ham
ham_dual * ψ_dual
ham_dual = energy_hamiltonian(basis_dual, scfres.ψ, scfres.occupation;

Check warning on line 165 in src/workarounds/forwarddiff_rules.jl

View check run for this annotation

Codecov / codecov/patch

src/workarounds/forwarddiff_rules.jl#L165

Added line #L165 was not covered by tests
ρ=ρ_basis, scfres.eigenvalues,
scfres.εF).ham
ham_dual * scfres.ψ

Check warning on line 168 in src/workarounds/forwarddiff_rules.jl

View check run for this annotation

Codecov / codecov/patch

src/workarounds/forwarddiff_rules.jl#L168

Added line #L168 was not covered by tests
end

## Implicit differentiation
# Implicit differentiation
response.verbose && println("Solving response problem")
δresults = ntuple(ForwardDiff.npartials(T)) do α
δHextψ = [ForwardDiff.partials.(δHextψk, α) for δHextψk in Hψ_dual]
solve_ΩplusK_split(scfres, -δHextψ; tol=last(scfres.history_Δρ), response.verbose)
end

## Convert and combine
# Convert and combine
DT = Dual{ForwardDiff.tagtype(T)}
ψ = map(scfres.ψ, getfield.(δresults, :δψ)...) do ψk, δψk...
map(ψk, δψk...) do ψnk, δψnk...
Complex(DT(real(ψnk), real.(δψnk)),
DT(imag(ψnk), imag.(δψnk)))
end
end
ρ = map((ρi, δρi...) -> DT(ρi, δρi), scfres.ρ, getfield.(δresults, :δρ)...)
eigenvalues = map(scfres.eigenvalues, getfield.(δresults, :δeigenvalues)...) do εk, δεk...
map((εnk, δεnk...) -> DT(εnk, δεnk), εk, δεk...)
end
Expand All @@ -194,6 +191,10 @@
end
εF = DT(scfres.εF, getfield.(δresults, :δεF)...)

# Add contributions from the basis (ρ_basis) and implicit diff (ρ_response)
ρ_response = map((ρi, δρi...) -> DT(ρi, δρi), zero(scfres.ρ), getfield.(δresults, :δρ)...)
ρ = ρ_basis + ρ_response

Check warning on line 196 in src/workarounds/forwarddiff_rules.jl

View check run for this annotation

Codecov / codecov/patch

src/workarounds/forwarddiff_rules.jl#L195-L196

Added lines #L195 - L196 were not covered by tests

# TODO Could add δresults[α].δVind the dual part of the total local potential in ham_dual
# and in this way return a ham that represents also the total change in Hamiltonian

Expand Down
42 changes: 42 additions & 0 deletions test/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,48 @@
end
end

@testitem "Strain sensitivity using ForwardDiff" #=
=# tags=[:dont_test_mpi] setup=[TestCases] begin
using DFTK
using ForwardDiff
using LinearAlgebra
using ComponentArrays
using PseudoPotentialData
aluminium = TestCases.aluminium
Ecut = 5
kgrid = [2, 2, 2]
model = model_DFT(aluminium.lattice, aluminium.atoms, aluminium.positions;
functionals=LDA(), temperature=1e-2, smearing=Smearing.Gaussian())
basis = PlaneWaveBasis(model; Ecut, kgrid)
nbandsalg = FixedBands(; n_bands_converge=10)
response = ResponseOptions(; verbose=true)

function compute_properties(ε)
model_strained = Model(model; lattice=(1 + ε) * model.lattice)
basis = PlaneWaveBasis(model_strained; Ecut, kgrid)
scfres = self_consistent_field(basis; tol=1e-10, nbandsalg, response)
ComponentArray(
eigenvalues=stack([ev[1:10] for ev in scfres.eigenvalues]),
ρ=scfres.ρ,
energies=collect(values(scfres.energies)),
εF=scfres.εF,
occupation=reduce(vcat, scfres.occupation),
)
end

dx = ForwardDiff.derivative(compute_properties, 0.)

h = 1e-4
x1 = compute_properties(-h)
x2 = compute_properties(+h)
dx_findiff = (x2 - x1) / 2h
@test norm(dx.ρ - dx_findiff.ρ) * sqrt(basis.dvol) < 1e-6
@test maximum(abs, dx.eigenvalues - dx_findiff.eigenvalues) < 1e-6
@test maximum(abs, dx.energies - dx_findiff.energies) < 1e-5
@test dx.εF - dx_findiff.εF < 1e-6
@test maximum(abs, dx.occupation - dx_findiff.occupation) < 1e-5
end

@testitem "scfres PSP sensitivity using ForwardDiff" #=
=# tags=[:dont_test_mpi] setup=[TestCases] begin
using DFTK
Expand Down
Loading