Skip to content

Commit

Permalink
Move to deferred factory, fix issues
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Aug 12, 2024
1 parent 9da6f53 commit 24c6f7d
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 57 deletions.
55 changes: 31 additions & 24 deletions core/distributed/preconditioner/schwarz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,11 @@ void Schwarz<ValueType, LocalIndexType, GlobalIndexType>::apply_dense_impl(
}

if (this->coarse_solver_ != nullptr && this->galerkin_ops_ != nullptr) {
auto restrict = this->galerkin_ops_->get_restrict_op();
auto prolong = this->galerkin_ops_->get_prolong_op();
auto restrict = as<gko::multigrid::MultigridLevel>(this->galerkin_ops_)
->get_restrict_op();
auto prolong = as<gko::multigrid::MultigridLevel>(this->galerkin_ops_)
->get_prolong_op();
GKO_ASSERT(this->half_ != nullptr);

restrict->apply(dense_b, this->csol_);
this->coarse_solver_->apply(this->csol_, this->csol_);
Expand Down Expand Up @@ -119,28 +122,32 @@ void Schwarz<ValueType, LocalIndexType, GlobalIndexType>::generate(
}


if (parameters_.galerkin_ops_factory && parameters_.coarse_solver_factory) {
this->galerkin_ops_ = as<multigrid::MultigridLevel>(
share(parameters_.galerkin_ops_factory->generate(dist_mat)));
auto coarse =
as<experimental::distributed::Matrix<ValueType, LocalIndexType,
GlobalIndexType>>(
this->galerkin_ops_->get_coarse_op());
auto exec = coarse->get_executor();
auto comm = coarse->get_communicator();
this->coarse_solver_ =
parameters_.coarse_solver_factory->generate(coarse);
// TODO: Set correct rhs and stride.
auto cs_ncols = 1; // dense_x->get_size()[1];
auto cs_local_nrows = coarse->get_local_matrix()->get_size()[0];
auto cs_global_nrows = coarse->get_size()[0];
auto cs_local_size = dim<2>(cs_local_nrows, cs_ncols);
auto cs_global_size = dim<2>(cs_global_nrows, cs_ncols);
this->csol_ = gko::share(dist_vec::create(exec, comm, cs_global_size,
cs_local_size,
1 /*dense_x->get_stride()*/));
// this->temp_ = this->csol->clone();
this->half_ = gko::share(gko::initialize<Vector>({0.5}, exec));
if (parameters_.galerkin_ops && parameters_.coarse_solver) {
this->galerkin_ops_ =
share(parameters_.galerkin_ops->generate(system_matrix));
if (as<gko::multigrid::MultigridLevel>(this->galerkin_ops_)
->get_coarse_op()) {
auto coarse =
as<experimental::distributed::Matrix<ValueType, LocalIndexType,
GlobalIndexType>>(
as<gko::multigrid::MultigridLevel>(this->galerkin_ops_)
->get_coarse_op());
auto exec = coarse->get_executor();
auto comm = coarse->get_communicator();
this->coarse_solver_ =
share(parameters_.coarse_solver->generate(coarse));
// TODO: Set correct rhs and stride.
auto cs_ncols = 1; // dense_x->get_size()[1];
auto cs_local_nrows = coarse->get_local_matrix()->get_size()[0];
auto cs_global_nrows = coarse->get_size()[0];
auto cs_local_size = dim<2>(cs_local_nrows, cs_ncols);
auto cs_global_size = dim<2>(cs_global_nrows, cs_ncols);
this->csol_ = gko::share(
dist_vec::create(exec, comm, cs_global_size, cs_local_size,
1 /*dense_x->get_stride()*/));
// this->temp_ = this->csol->clone();
this->half_ = gko::share(gko::initialize<Vector>({0.5}, exec));
}
}
}

Expand Down
36 changes: 14 additions & 22 deletions core/test/mpi/distributed/preconditioner/schwarz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
#include "core/test/utils.hpp"


namespace {


template <typename ValueLocalGlobalIndexType>
class SchwarzFactory : public ::testing::Test {
protected:
Expand Down Expand Up @@ -44,8 +41,8 @@ class SchwarzFactory : public ::testing::Test {
{
schwarz = Schwarz::build()
.with_local_solver(jacobi_factory)
.with_galerkin_ops_factory(pgm_factory)
.with_coarse_solver_factory(pgm_factory)
.with_galerkin_ops(pgm_factory)
.with_coarse_solver(cg_factory)
.on(exec)
->generate(mtx);
}
Expand All @@ -63,10 +60,10 @@ class SchwarzFactory : public ::testing::Test {
ASSERT_EQ(a->get_size(), b->get_size());
ASSERT_EQ(a->get_parameters().local_solver,
b->get_parameters().local_solver);
ASSERT_EQ(a->get_parameters().galerkin_ops_factory,
b->get_parameters().galerkin_ops_factory);
ASSERT_EQ(a->get_parameters().coarse_solver_factory,
b->get_parameters().coarse_solver_factory);
ASSERT_EQ(a->get_parameters().galerkin_ops,
b->get_parameters().galerkin_ops);
ASSERT_EQ(a->get_parameters().coarse_solver,
b->get_parameters().coarse_solver);
}

std::shared_ptr<const gko::Executor> exec;
Expand Down Expand Up @@ -96,15 +93,13 @@ TYPED_TEST(SchwarzFactory, CanSetLocalFactory)

TYPED_TEST(SchwarzFactory, CanSetGalerkinOpsFactory)
{
ASSERT_EQ(this->schwarz->get_parameters().galerkin_ops_factory,
this->pgm_factory);
ASSERT_EQ(this->schwarz->get_parameters().galerkin_ops, this->pgm_factory);
}


TYPED_TEST(SchwarzFactory, CanSetCoarseSolverFactory)
{
ASSERT_EQ(this->schwarz->get_parameters().coarse_solver_factory,
this->cg_factory);
ASSERT_EQ(this->schwarz->get_parameters().coarse_solver, this->cg_factory);
}


Expand All @@ -128,8 +123,8 @@ TYPED_TEST(SchwarzFactory, CanBeCopied)
auto cg = gko::share(Cg::build().on(this->exec));
auto copy = Schwarz::build()
.with_local_solver(bj)
.with_galerkin_ops_factory(pgm)
.with_coarse_solver_factory(cg)
.with_galerkin_ops(pgm)
.with_coarse_solver(cg)
.on(this->exec)
->generate(Mtx::create(this->exec, MPI_COMM_WORLD));

Expand All @@ -152,8 +147,8 @@ TYPED_TEST(SchwarzFactory, CanBeMoved)
auto cg = gko::share(Cg::build().on(this->exec));
auto copy = Schwarz::build()
.with_local_solver(bj)
.with_galerkin_ops_factory(pgm)
.with_coarse_solver_factory(cg)
.with_galerkin_ops(pgm)
.with_coarse_solver(cg)
.on(this->exec)
->generate(Mtx::create(this->exec, MPI_COMM_WORLD));

Expand All @@ -169,8 +164,8 @@ TYPED_TEST(SchwarzFactory, CanBeCleared)

ASSERT_EQ(this->schwarz->get_size(), gko::dim<2>(0, 0));
ASSERT_EQ(this->schwarz->get_parameters().local_solver, nullptr);
ASSERT_EQ(this->schwarz->get_parameters().galerkin_ops_factory, nullptr);
ASSERT_EQ(this->schwarz->get_parameters().coarse_solver_factory, nullptr);
ASSERT_EQ(this->schwarz->get_parameters().galerkin_ops, nullptr);
ASSERT_EQ(this->schwarz->get_parameters().coarse_solver, nullptr);
}


Expand All @@ -185,6 +180,3 @@ TYPED_TEST(SchwarzFactory, PassExplicitFactory)

ASSERT_EQ(factory->get_parameters().local_solver, jacobi_factory);
}


} // namespace
11 changes: 5 additions & 6 deletions examples/distributed-solver/distributed-solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,11 @@ int main(int argc, char* argv[])
if (schw_type == "multi-level") {
Ainv =
solver::build()
.with_preconditioner(
schwarz::build()
.with_local_solver(local_solver)
.with_galerkin_ops_factory(pgm_fac)
.with_coarse_solver_factory(coarse_solver)
.on(exec))
.with_preconditioner(schwarz::build()
.with_local_solver(local_solver)
.with_galerkin_ops(pgm_fac)
.with_coarse_solver(coarse_solver)
.on(exec))
.with_criteria(
gko::stop::Iteration::build().with_max_iters(num_iters).on(
exec),
Expand Down
11 changes: 6 additions & 5 deletions include/ginkgo/core/distributed/preconditioner/schwarz.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <ginkgo/core/distributed/matrix.hpp>
#include <ginkgo/core/distributed/vector.hpp>
#include <ginkgo/core/multigrid/multigrid_level.hpp>
#include <ginkgo/core/solver/solver_base.hpp>


namespace gko {
Expand Down Expand Up @@ -80,14 +81,14 @@ class Schwarz
* Operator factory to generate the triplet (prolong_op, coarse_op,
* restrict_op).
*/
std::shared_ptr<const LinOpFactory> GKO_FACTORY_PARAMETER_SCALAR(
galerkin_ops_factory, nullptr);
std::shared_ptr<const LinOpFactory> GKO_DEFERRED_FACTORY_PARAMETER(
galerkin_ops);

/**
* Coarse solver factory.
*/
std::shared_ptr<const LinOpFactory> GKO_FACTORY_PARAMETER_SCALAR(
coarse_solver_factory, nullptr);
std::shared_ptr<const LinOpFactory> GKO_DEFERRED_FACTORY_PARAMETER(
coarse_solver);
};
GKO_ENABLE_LIN_OP_FACTORY(Schwarz, parameters, Factory);
GKO_ENABLE_BUILD_METHOD(Factory);
Expand Down Expand Up @@ -141,7 +142,7 @@ class Schwarz
void set_solver(std::shared_ptr<const LinOp> new_solver);

std::shared_ptr<const LinOp> local_solver_;
std::shared_ptr<const multigrid::MultigridLevel> galerkin_ops_;
std::shared_ptr<const LinOp> galerkin_ops_;
std::shared_ptr<const LinOp> coarse_solver_;
std::shared_ptr<LinOp> csol_;
std::shared_ptr<const LinOp> half_;
Expand Down
28 changes: 28 additions & 0 deletions test/mpi/preconditioner/schwarz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <ginkgo/core/log/logger.hpp>
#include <ginkgo/core/matrix/csr.hpp>
#include <ginkgo/core/matrix/dense.hpp>
#include <ginkgo/core/multigrid/pgm.hpp>
#include <ginkgo/core/preconditioner/jacobi.hpp>
#include <ginkgo/core/solver/bicgstab.hpp>
#include <ginkgo/core/solver/cg.hpp>
Expand Down Expand Up @@ -59,6 +60,9 @@ class SchwarzPreconditioner : public CommonMpiTestFixture {
using solver_type = gko::solver::Bicgstab<value_type>;
using local_prec_type =
gko::preconditioner::Jacobi<value_type, local_index_type>;
using coarse_solver_type =
gko::preconditioner::Jacobi<value_type, local_index_type>;
using galerkin_ops_type = gko::multigrid::Pgm<value_type, local_index_type>;
using local_matrix_type = gko::matrix::Csr<value_type, local_index_type>;
using non_dist_matrix_type =
gko::matrix::Csr<value_type, global_index_type>;
Expand Down Expand Up @@ -125,6 +129,8 @@ class SchwarzPreconditioner : public CommonMpiTestFixture {
std::shared_ptr<gko::LinOpFactory> non_dist_solver_factory;
std::shared_ptr<gko::LinOpFactory> dist_solver_factory;
std::shared_ptr<gko::LinOpFactory> local_solver_factory;
std::shared_ptr<gko::LinOpFactory> pgm_factory;
std::shared_ptr<gko::LinOpFactory> coarse_solver_factory;

void assert_equal_to_non_distributed_vector(
std::shared_ptr<dist_vec_type> dist_vec,
Expand Down Expand Up @@ -271,6 +277,28 @@ TYPED_TEST(SchwarzPreconditioner, CanApplyPreconditioner)
}


TYPED_TEST(SchwarzPreconditioner, CanApplyMultilevelPreconditioner)
{
using value_type = typename TestFixture::value_type;
using prec = typename TestFixture::dist_prec_type;

auto precond_factory = prec::build()
.with_local_solver(this->local_solver_factory)
.with_coarse_solver(this->coarse_solver_factory)
.with_galerkin_ops(this->pgm_factory)
.on(this->exec);
auto local_precond =
this->local_solver_factory->generate(this->non_dist_mat);
auto precond = precond_factory->generate(this->dist_mat);

precond->apply(this->dist_b.get(), this->dist_x.get());
local_precond->apply(this->non_dist_b.get(), this->non_dist_x.get());

this->assert_equal_to_non_distributed_vector(this->dist_x,
this->non_dist_x);
}


TYPED_TEST(SchwarzPreconditioner, CanAdvancedApplyPreconditioner)
{
using value_type = typename TestFixture::value_type;
Expand Down

0 comments on commit 24c6f7d

Please sign in to comment.