Skip to content

Commit

Permalink
Rewrite BMHTF interaction interface
Browse files Browse the repository at this point in the history
  • Loading branch information
jngrad committed Sep 5, 2022
1 parent 8576bf8 commit 53368cc
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 112 deletions.
47 changes: 19 additions & 28 deletions src/core/nonbonded_interactions/bmhtf-nacl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,35 +25,26 @@
#include "bmhtf-nacl.hpp"

#ifdef BMHTF_NACL
#include "interactions.hpp"
#include "nonbonded_interactions/nonbonded_interaction_data.hpp"

#include <utils/constants.hpp>

int BMHTF_set_params(int part_type_a, int part_type_b, double A, double B,
double C, double D, double sig, double cut) {
double shift, dist2, pw6;
IA_parameters *data = get_ia_param_safe(part_type_a, part_type_b);

if (!data)
return ES_ERROR;

dist2 = cut * cut;
pw6 = dist2 * dist2 * dist2;
shift = -(A * exp(B * (sig - cut)) - C / pw6 - D / pw6 / dist2);

data->bmhtf.A = A;
data->bmhtf.B = B;
data->bmhtf.C = C;
data->bmhtf.D = D;
data->bmhtf.sig = sig;
data->bmhtf.cut = cut;
data->bmhtf.computed_shift = shift;

/* broadcast interaction parameters */
mpi_bcast_ia_params(part_type_a, part_type_b);

return ES_OK;
#include <utils/math/int_pow.hpp>

#include <stdexcept>

BMHTF_Parameters::BMHTF_Parameters(double a, double b, double c, double d,
double sig, double cutoff)
: A{a}, B{b}, C{c}, D{d}, sig{sig}, cut{cutoff} {
if (a < 0.) {
throw std::domain_error("BMHTF parameter 'a' has to be >= 0");
}
if (c < 0.) {
throw std::domain_error("BMHTF parameter 'c' has to be >= 0");
}
if (d < 0.) {
throw std::domain_error("BMHTF parameter 'd' has to be >= 0");
}
computed_shift = C / Utils::int_pow<6>(cut) + D / Utils::int_pow<8>(cut) -
A * std::exp(B * (sig - cut));
}

#endif
#endif // BMHTF_NACL
6 changes: 1 addition & 5 deletions src/core/nonbonded_interactions/bmhtf-nacl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,10 @@

#include "nonbonded_interaction_data.hpp"

#include <utils/Vector.hpp>
#include <utils/math/int_pow.hpp>

#include <cmath>

int BMHTF_set_params(int part_type_a, int part_type_b, double A, double B,
double C, double D, double sig, double cut);

/** Calculate BMHTF force factor */
inline double BMHTF_pair_force_factor(IA_parameters const &ia_params,
double dist) {
Expand All @@ -67,5 +63,5 @@ inline double BMHTF_pair_energy(IA_parameters const &ia_params, double dist) {
return 0.0;
}

#endif
#endif // BMHTF_NACL
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ static double recalc_maximal_cutoff(const IA_parameters &data) {
#endif

#ifdef BMHTF_NACL
max_cut_current = std::max(max_cut_current, data.bmhtf.cut);
max_cut_current = std::max(max_cut_current, data.bmhtf.max_cutoff());
#endif

#ifdef MORSE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ struct BMHTF_Parameters {
double sig = 0.0;
double cut = INACTIVE_CUTOFF;
double computed_shift = 0.0;
BMHTF_Parameters() = default;
BMHTF_Parameters(double A, double B, double C, double D, double sig,
double cut);
double max_cutoff() const { return cut; }
};

/** Morse potential */
Expand Down
17 changes: 0 additions & 17 deletions src/python/espressomd/interactions.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,6 @@ cdef extern from "nonbonded_interactions/nonbonded_interaction_data.hpp":
int n
double k0

cdef struct BMHTF_Parameters:
double A
double B
double C
double D
double sig
double cut
double computed_shift

cdef struct Morse_Parameters:
double eps
double alpha
Expand Down Expand Up @@ -114,8 +105,6 @@ cdef extern from "nonbonded_interactions/nonbonded_interaction_data.hpp":

SmoothStep_Parameters smooth_step

BMHTF_Parameters bmhtf

Morse_Parameters morse

Buckingham_Parameters buckingham
Expand Down Expand Up @@ -149,12 +138,6 @@ IF SMOOTH_STEP:
double d, int n, double eps,
double k0, double sig,
double cut)
IF BMHTF_NACL:
cdef extern from "nonbonded_interactions/bmhtf-nacl.hpp":
int BMHTF_set_params(int part_type_a, int part_type_b,
double A, double B, double C,
double D, double sig, double cut)

IF MORSE:
cdef extern from "nonbonded_interactions/morse.hpp":
int morse_set_params(int part_type_a, int part_type_b,
Expand Down
88 changes: 27 additions & 61 deletions src/python/espressomd/interactions.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -938,51 +938,41 @@ IF SMOOTH_STEP == 1:

IF BMHTF_NACL == 1:

cdef class BMHTFInteraction(NonBondedInteraction):
@script_interface_register
class BMHTFInteraction(NewNonBondedInteraction):
"""BMHTF interaction.
def validate_params(self):
"""Check that parameters are valid.
Methods
-------
set_params()
Set or update parameters for the interaction.
Parameters marked as required become optional once the
interaction has been activated for the first time;
subsequent calls to this method update the existing values.
"""
if self._params["a"] < 0:
raise ValueError("BMHTF a has to be >=0")
if self._params["c"] < 0:
raise ValueError("BMHTF c has to be >=0")
if self._params["d"] < 0:
raise ValueError("BMHTF d has to be >=0")
if self._params["cutoff"] < 0:
raise ValueError("BMHTF cutoff has to be >=0")
Parameters
----------
a : :obj:`float`
Magnitude of exponential part of the interaction.
b : :obj:`float`
Exponential factor of the interaction.
c : :obj:`float`
Magnitude of the term decaying with the sixth power of r.
d : :obj:`float`
Magnitude of the term decaying with the eighth power of r.
sig : :obj:`float`
Shift in the exponent.
cutoff : :obj:`float`
Cutoff distance of the interaction.
"""

def _get_params_from_es_core(self):
cdef IA_parameters * ia_params
ia_params = get_ia_param_safe(self._part_types[0],
self._part_types[1])
return {
"a": ia_params.bmhtf.A,
"b": ia_params.bmhtf.B,
"c": ia_params.bmhtf.C,
"d": ia_params.bmhtf.D,
"sig": ia_params.bmhtf.sig,
"cutoff": ia_params.bmhtf.cut,
}
_so_name = "Interactions::InteractionBMHTF"

def is_active(self):
"""Check if interaction is active.
"""
return (self._params["a"] > 0) and (
self._params["c"] > 0) and (self._params["d"] > 0)

def _set_params_in_es_core(self):
if BMHTF_set_params(self._part_types[0],
self._part_types[1],
self._params["a"],
self._params["b"],
self._params["c"],
self._params["d"],
self._params["sig"],
self._params["cutoff"]):
raise Exception("Could not set BMHTF parameters")
return self.a > 0. and self.c > 0. and self.d > 0.

def default_params(self):
"""Python dictionary of default parameters.
Expand All @@ -996,28 +986,6 @@ IF BMHTF_NACL == 1:
"""
return "BMHTF"

def set_params(self, **kwargs):
"""
Set parameters for the BMHTF interaction.
Parameters
----------
a : :obj:`float`
Magnitude of exponential part of the interaction.
b : :obj:`float`
Exponential factor of the interaction.
c : :obj:`float`
Magnitude of the term decaying with the sixth power of r.
d : :obj:`float`
Magnitude of the term decaying with the eighth power of r.
sig : :obj:`float`
Shift in the exponent.
cutoff : :obj:`float`
Cutoff distance of the interaction.
"""
super().set_params(**kwargs)

def valid_keys(self):
"""All parameters that can be set.
Expand Down Expand Up @@ -1438,8 +1406,6 @@ class NonBondedInteractionHandle(ScriptInterfaceHelper):
self.soft_sphere = SoftSphereInteraction(_type1, _type2)
IF SMOOTH_STEP:
self.smooth_step = SmoothStepInteraction(_type1, _type2)
IF BMHTF_NACL:
self.bmhtf = BMHTFInteraction(_type1, _type2)
IF MORSE:
self.morse = MorseInteraction(_type1, _type2)
IF BUCKINGHAM:
Expand Down
38 changes: 38 additions & 0 deletions src/script_interface/interactions/NonBondedInteraction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,34 @@ class InteractionGaussian
};
#endif // GAUSSIAN

#ifdef BMHTF_NACL
class InteractionBMHTF
: public InteractionPotentialInterface<::BMHTF_Parameters> {
protected:
CoreInteraction IA_parameters::*get_ptr_offset() const override {
return &::IA_parameters::bmhtf;
}

public:
InteractionBMHTF() {
add_parameters({
make_autoparameter(&CoreInteraction::A, "a"),
make_autoparameter(&CoreInteraction::B, "b"),
make_autoparameter(&CoreInteraction::C, "c"),
make_autoparameter(&CoreInteraction::D, "d"),
make_autoparameter(&CoreInteraction::sig, "sig"),
make_autoparameter(&CoreInteraction::cut, "cutoff"),
});
}

void make_new_instance(VariantMap const &params) override {
m_ia_si = make_shared_from_args<CoreInteraction, double, double, double,
double, double, double>(
params, "a", "b", "c", "d", "sig", "cutoff");
}
};
#endif // BMHTF_NACL

class NonBondedInteractionHandle
: public AutoParameters<NonBondedInteractionHandle> {
std::array<int, 2> m_types = {-1, -1};
Expand All @@ -382,6 +410,9 @@ class NonBondedInteractionHandle
#ifdef GAUSSIAN
std::shared_ptr<InteractionGaussian> m_gaussian;
#endif
#ifdef BMHTF_NACL
std::shared_ptr<InteractionBMHTF> m_bmhtf;
#endif

template <class T>
auto make_autoparameter(std::shared_ptr<T> &member, const char *key) const {
Expand Down Expand Up @@ -420,6 +451,9 @@ class NonBondedInteractionHandle
#endif
#ifdef GAUSSIAN
make_autoparameter(m_gaussian, "gaussian"),
#endif
#ifdef BMHTF_NACL
make_autoparameter(m_bmhtf, "bmhtf"),
#endif
});
}
Expand Down Expand Up @@ -486,6 +520,10 @@ class NonBondedInteractionHandle
#ifdef GAUSSIAN
set_member<InteractionGaussian>(
m_gaussian, "gaussian", "Interactions::InteractionGaussian", params);
#endif
#ifdef BMHTF_NACL
set_member<InteractionBMHTF>(m_bmhtf, "bmhtf",
"Interactions::InteractionBMHTF", params);
#endif
}

Expand Down
3 changes: 3 additions & 0 deletions src/script_interface/interactions/initialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ void initialize(Utils::Factory<ObjectHandle> *om) {
#ifdef GAUSSIAN
om->register_new<InteractionGaussian>("Interactions::InteractionGaussian");
#endif
#ifdef BMHTF_NACL
om->register_new<InteractionBMHTF>("Interactions::InteractionBMHTF");
#endif
}
} // namespace Interactions
} // namespace ScriptInterface
8 changes: 8 additions & 0 deletions testsuite/python/interactions_non-bonded_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,14 @@ def test_gaussian_exceptions(self):
("eps", "sig")
)

@utx.skipIfMissingFeatures("BMHTF_NACL")
def test_bmhtf_exceptions(self):
self.check_potential_exceptions(
espressomd.interactions.BMHTFInteraction,
{"a": 3., "b": 2., "c": 1., "d": 4., "sig": 0.13, "cutoff": 1.2},
("a", "c", "d")
)


if __name__ == "__main__":
ut.main()

0 comments on commit 53368cc

Please sign in to comment.