Skip to content

Commit

Permalink
GBn2 force don't broadcast 2
Browse files Browse the repository at this point in the history
  • Loading branch information
jgreener64 committed Oct 2, 2024
1 parent 2f80033 commit 4d5478b
Showing 1 changed file with 24 additions and 38 deletions.
62 changes: 24 additions & 38 deletions src/interactions/implicit_solvent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -857,31 +857,17 @@ function gbsa_born_kernel!(Is, I_grads, coords_var, offset_radii_var, scaled_off
return nothing
end

# Store the results of the ij broadcasts during force calculation
struct ForceLoopResult1{T, V}
bi::T
bj::T
fi::V
fj::V
end

get_bi(r::ForceLoopResult1) = r.bi
get_bj(r::ForceLoopResult1) = r.bj

get_fi(r::ForceLoopResult1) = r.fi
get_fj(r::ForceLoopResult1) = r.fj

function gb_force_loop_1(coord_i, coord_j, i, j, charge_i, charge_j, Bi, Bj, dist_cutoff,
factor_solute, factor_solvent, kappa, boundary)
if j < i
zero_force = zero(factor_solute ./ coord_i .^ 2)
return ForceLoopResult1(zero_force[1], zero_force[1], zero_force, zero_force)
return zero_force[1], zero_force[1], zero_force, zero_force
end
dr = vector(coord_i, coord_j, boundary)
r2 = sum(abs2, dr)
if !iszero_value(dist_cutoff) && r2 > dist_cutoff^2
zero_force = zero(factor_solute ./ coord_i .^ 2)
return ForceLoopResult1(zero_force[1], zero_force[1], zero_force, zero_force)
return zero_force[1], zero_force[1], zero_force, zero_force
end
alpha2_ij = Bi * Bj
D = r2 / (4 * alpha2_ij)
Expand All @@ -903,11 +889,10 @@ function gb_force_loop_1(coord_i, coord_j, i, j, charge_i, charge_j, Bi, Bj, dis
fdr = dr * dGpol_dr
change_fs_i = fdr
change_fs_j = -fdr
return ForceLoopResult1(change_born_force_i, change_born_force_j,
change_fs_i, change_fs_j)
return change_born_force_i, change_born_force_j, change_fs_i, change_fs_j
else
zero_force = zero(factor_solute ./ coord_i .^ 2)
return ForceLoopResult1(change_born_force_i, zero_force[1], zero_force, zero_force)
return change_born_force_i, zero_force[1], zero_force, zero_force
end
end

Expand Down Expand Up @@ -935,22 +920,22 @@ end

function forces_gbsa(sys, inter, Bs, B_grads, I_grads, born_forces, atom_charges)
coords, boundary = sys.coords, sys.boundary
coords_i = @view coords[inter.is]
coords_j = @view coords[inter.js]
charges_i = @view atom_charges[inter.is]
charges_j = @view atom_charges[inter.js]
Bsi = @view Bs[inter.is]
Bsj = @view Bs[inter.js]
loop_res_1 = gb_force_loop_1.(coords_i, coords_j, inter.is, inter.js, charges_i,
charges_j, Bsi, Bsj, inter.dist_cutoff, inter.factor_solute,
inter.factor_solvent, inter.kappa, (boundary,))
born_forces_1 = born_forces .+ dropdims(sum(get_bi.(loop_res_1); dims=2); dims=2) .+
dropdims(sum(get_bj.(loop_res_1); dims=1); dims=1)
fs = dropdims(sum(get_fi.(loop_res_1); dims=2); dims=2) .+
dropdims(sum(get_fj.(loop_res_1); dims=1); dims=1)
born_forces_1 = copy(born_forces)
fs = ustrip_vec.(zero(coords)) * sys.force_units
@inbounds for i in eachindex(sys)
for j in eachindex(sys)
bi, bj, fi, fj = gb_force_loop_1(
coords[i], coords[j], i, j, atom_charges[i], atom_charges[j], Bs[i], Bs[j],
inter.dist_cutoff, inter.factor_solute, inter.factor_solvent, inter.kappa, boundary,
)
born_forces_1[i] += bi
born_forces_1[j] += bj
fs[i] = fs[i] .+ fi
fs[j] = fs[j] .+ fj
end
end

born_forces_2 = born_forces_1 .* (Bs .^ 2) .* B_grads

@inbounds for i in eachindex(sys)
for j in eachindex(sys)
f = gb_force_loop_2(coords[i], coords[j], born_forces_2[i], I_grads[i, j],
Expand Down Expand Up @@ -1199,9 +1184,10 @@ function AtomsCalculators.potential_energy(sys::System{<:Any, true}, inter::Abst
charges_j = @view atom_charges[inter.js]
Bsi = @view Bs[inter.is]
Bsj = @view Bs[inter.js]
return sum(gb_energy_loop.(coords_i, coords_j, inter.is, inter.js, charges_i,
charges_j, Bsi, Bsj, inter.oris, inter.dist_cutoff,
inter.factor_solute, inter.factor_solvent, inter.kappa,
inter.offset, inter.probe_radius, inter.sa_factor, inter.use_ACE,
(boundary,)))
return sum(gb_energy_loop.(
coords_i, coords_j, inter.is, inter.js, charges_i, charges_j, Bsi, Bsj,
inter.oris, inter.dist_cutoff, inter.factor_solute, inter.factor_solvent,
inter.kappa, inter.offset, inter.probe_radius, inter.sa_factor, inter.use_ACE,
(boundary,),
))
end

0 comments on commit 4d5478b

Please sign in to comment.