Skip to content

Commit

Permalink
Makes the rf and x_space implementation consistent.
Browse files Browse the repository at this point in the history
  • Loading branch information
ahurta92 committed Jul 1, 2024
1 parent 2261ca7 commit 7c6687d
Showing 1 changed file with 34 additions and 65 deletions.
99 changes: 34 additions & 65 deletions src/apps/molresponse/response_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ namespace madness {
* @param num_orbitals
*/
response_space(World &world, size_t num_states, size_t num_orbitals)
: num_states(num_states), num_orbitals(num_orbitals), x(response_matrix(num_states)),
active(num_states) {
: num_states(num_states), num_orbitals(num_orbitals), x(response_matrix(num_states)), active(num_states) {
for (auto &state: x) { state = vector_real_function_3d(num_orbitals); }
reset_active();
//world.gop.fence();
Expand All @@ -122,8 +121,8 @@ namespace madness {
[[nodiscard]] response_space copy() const {
World &world = x[0][0].world();
response_space result(*this);
std::transform(x.begin(), x.end(), result.x.begin(),
[&world](const vector_real_function_3d &xi) { return madness::copy(world, xi, false); });
//copy each state
for (size_t i = 0; i < num_states; i++) { result.x[i] = madness::copy(world, x[i], false); }
world.gop.fence();
return result;
}
Expand All @@ -140,17 +139,14 @@ namespace madness {
vector_real_function_3d &operator[](size_t i) { return x.at(i); }
const vector_real_function_3d &operator[](size_t i) const { return x.at(i); }

friend auto
inplace_unary_apply(response_space &A,
const std::function<void(vector_real_function_3d &)> &func) {
friend auto inplace_unary_apply(response_space &A, const std::function<void(vector_real_function_3d &)> &func) {
auto &world = A.x[0][0].world();
for (auto &i: A.active) { func(A.x[i]); }
world.gop.fence();
}

friend auto oop_unary_apply(
const response_space &A,
const std::function<vector_real_function_3d(const vector_real_function_3d &)> &func)
friend auto oop_unary_apply(const response_space &A,
const std::function<vector_real_function_3d(const vector_real_function_3d &)> &func)
-> response_space {
auto result = A.copy();
auto &world = result.x[0][0].world();
Expand All @@ -159,10 +155,9 @@ namespace madness {
return result;
}

friend auto
binary_apply(const response_space &A, const response_space &B,
const std::function<vector_real_function_3d(vector_real_function_3d,
vector_real_function_3d)> &func)
friend auto binary_apply(
const response_space &A, const response_space &B,
const std::function<vector_real_function_3d(vector_real_function_3d, vector_real_function_3d)> &func)
-> response_space {
MADNESS_ASSERT(same_size(A, B));

Expand Down Expand Up @@ -197,46 +192,38 @@ namespace madness {
MADNESS_ASSERT(size() > 0);
MADNESS_ASSERT(same_size(*this, rhs_y));// assert that same size

auto result = binary_apply(*this, rhs_y, [&](auto xi, auto vi) {
return gaxpy_oop(1.0, xi, 1.0, vi, false);
});
auto result =
binary_apply(*this, rhs_y, [&](auto xi, auto vi) { return gaxpy_oop(1.0, xi, 1.0, vi, false); });
return result;
}

response_space operator-(const response_space &rhs_y) const {
MADNESS_ASSERT(size() > 0);
MADNESS_ASSERT(same_size(*this, rhs_y));// assert that same size
auto result = binary_apply(*this, rhs_y, [&](auto xi, auto vi) {
return gaxpy_oop(1.0, xi, -1.0, vi, false);
});
auto result =
binary_apply(*this, rhs_y, [&](auto xi, auto vi) { return gaxpy_oop(1.0, xi, -1.0, vi, false); });
return result;
}

friend response_space operator*(const response_space &y, double a) {
World &world = y.x.at(0).at(0).world();
auto multiply_scalar = [&](vector_real_function_3d &vi) {
madness::scale(world, vi, a, false);
};
auto multiply_scalar = [&](vector_real_function_3d &vi) { madness::scale(world, vi, a, false); };
auto result = y.copy();
inplace_unary_apply(result, multiply_scalar);
return result;
}

friend response_space operator*(double a, response_space &y) {
World &world = y.x.at(0).at(0).world();
auto multiply_scalar = [&](vector_real_function_3d &vi) {
madness::scale(world, vi, a, false);
};
auto multiply_scalar = [&](vector_real_function_3d &vi) { madness::scale(world, vi, a, false); };
auto result = y.copy();
inplace_unary_apply(result, multiply_scalar);
return result;
}

response_space &operator*=(double a) {
World &world = this->x[0][0].world();
auto multiply_scalar = [&](vector_real_function_3d &vi) {
madness::scale(world, vi, a, false);
};
auto multiply_scalar = [&](vector_real_function_3d &vi) { madness::scale(world, vi, a, false); };
inplace_unary_apply(*this, multiply_scalar);
return *this;
}
Expand All @@ -245,9 +232,7 @@ namespace madness {
// g[i][j] = x[i][j] * f
friend response_space operator*(const response_space &a, const Function<double, 3> &f) {
World &world = a.x.at(0).at(0).world();
auto multiply_scalar_function = [&](const vector_real_function_3d &vi) {
return mul(world, f, vi, false);
};
auto multiply_scalar_function = [&](const vector_real_function_3d &vi) { return mul(world, f, vi, false); };
return oop_unary_apply(a, multiply_scalar_function);
}

Expand All @@ -261,9 +246,7 @@ namespace madness {
response_space operator*(const Function<double, 3> &f) {
World &world = x[0][0].world();

auto multiply_scalar_function = [&](const vector_real_function_3d &vi) {
return mul(world, f, vi, false);
};
auto multiply_scalar_function = [&](const vector_real_function_3d &vi) { return mul(world, f, vi, false); };

return oop_unary_apply(*this, multiply_scalar_function);
}
Expand All @@ -273,9 +256,7 @@ namespace madness {
MADNESS_ASSERT(!a[0].empty());
World &world = a[0][0].world();

auto response_transform = [&](const vector_real_function_3d &vi) {
return transform(world, vi, b, false);
};
auto response_transform = [&](const vector_real_function_3d &vi) { return transform(world, vi, b, false); };
return oop_unary_apply(a, response_transform);
}

Expand All @@ -284,8 +265,7 @@ namespace madness {
MADNESS_ASSERT(same_size(*this, b));
auto &world = x[0][0].world();

auto a_plus_equal_b = [&](vector_real_function_3d &a,
const vector_real_function_3d &g) {
auto a_plus_equal_b = [&](vector_real_function_3d &a, const vector_real_function_3d &g) {
gaxpy(world, 1.0, a, 1.0, g, false);
};
binary_inplace(*this, b, a_plus_equal_b);
Expand Down Expand Up @@ -341,40 +321,32 @@ namespace madness {
// Mimicing standard madness calls with these 3
void zero() {
auto &world = x[0][0].world();
std::generate(x.begin(), x.end(),
[&]() { return zero_functions<double, 3>(world, num_orbitals, true); });

/*
for (size_t k = 0; k < num_states; k++) {
x[k] = zero_functions<double, 3>(x[0][0].world(), num_orbitals);
}
*/
for (int i = 0; i < num_states; i++) { x[i] = zero_functions<double, 3>(world, num_orbitals, false); }
}

void compress_rf() {
//for (size_t k = 0; k < num_states; k++) { compress(x[0][0].world(), x[k], true); }
auto &world = x[0][0].world();
std::for_each(x.begin(), x.end(), [&](auto &xi) { compress(world, xi, true); });
// compress only active states
for (auto &i: active) { compress(world, x[i], false); }
world.gop.fence();
}

void reconstruct_rf() {
//for (size_t k = 0; k < num_states; k++) { reconstruct(x[0][0].world(), x[k], true); }
auto &world = x[0][0].world();
std::for_each(x.begin(), x.end(), [&](auto &xi) { reconstruct(world, xi, true); });
// reconstruct only active states
for (auto &i: active) { reconstruct(world, x[i], false); }
world.gop.fence();
}

void truncate_rf() {
truncate_rf(FunctionDefaults<3>::get_thresh());
/*
for (size_t k = 0; k < num_states; k++) {
truncate(x[0][0].world(), x[k], FunctionDefaults<3>::get_thresh(), true);
}
*/
}
void truncate_rf() { truncate_rf(FunctionDefaults<3>::get_thresh()); }

void truncate_rf(double tol) {
auto &world = x[0][0].world();
std::for_each(x.begin(), x.end(), [&](auto &xi) { truncate(world, xi, tol, true); });
// truncate only active states
for (auto &i: active) { truncate(world, x[i], tol, false); }
world.gop.fence();
/*
for (size_t k = 0; k < num_states; k++) { truncate(x[0][0].world(), x[k], tol, true); }
*/
Expand All @@ -394,25 +366,22 @@ namespace madness {
// Scales each state (read: entire row) by corresponding vector element
// new[i] = old[i] * mat[i]
void scale(Tensor<double> &mat) {
for (size_t i = 0; i < num_states; i++)
madness::scale(x[0][0].world(), x[i], mat[i], false);
for (size_t i = 0; i < num_states; i++) madness::scale(x[0][0].world(), x[i], mat[i], false);
// x[i] = x[i] * mat[i];
}

friend bool operator==(const response_space &a, const response_space &y) {
if (!same_size(a, y)) return false;
for (size_t b = 0; b < a.size(); ++b) {
for (size_t k = 0; b < a.size_orbitals(); ++k) {
if ((a[b][k] - y[b][k]).norm2() >
FunctionDefaults<3>::get_thresh())// this may be strict
if ((a[b][k] - y[b][k]).norm2() > FunctionDefaults<3>::get_thresh())// this may be strict
return false;
}
}
return true;
}

friend Tensor<double> response_space_inner(const response_space &a,
const response_space &b) {
friend Tensor<double> response_space_inner(const response_space &a, const response_space &b) {
MADNESS_ASSERT(a.size() > 0);
MADNESS_ASSERT(a.size() == b.size());
MADNESS_ASSERT(!a[0].empty());
Expand Down

0 comments on commit 7c6687d

Please sign in to comment.