Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/MGmol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@ extern Timer updateCenters_tm;
std::set<int> Signal::recv_;

template <class OrbitalsType>
MGmol<OrbitalsType>::MGmol(MPI_Comm comm, std::ostream& os) : os_(os)
MGmol<OrbitalsType>::MGmol(MPI_Comm comm, std::ostream& os,
std::string input_filename, std::string lrs_filename,
std::string constraints_filename)
: os_(os)
{
comm_ = comm;

Expand All @@ -136,6 +139,14 @@ MGmol<OrbitalsType>::MGmol(MPI_Comm comm, std::ostream& os) : os_(os)
forces_ = nullptr;

energy_ = nullptr;

setupFromInput(input_filename);

setupLRs(lrs_filename);

setupConstraintsFromInput(constraints_filename);

setup();
}

template <class OrbitalsType>
Expand Down
14 changes: 8 additions & 6 deletions src/MGmol.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,12 @@ class MGmol : public MGmolInterface
template <typename MemorySpaceType>
int initial();
void initialMasks();
int setupLRsFromInput(const std::string input_file);
int setupLRsFromInput(const std::string filename);

void setup();
int setupLRs(const std::string input_file) override;
int setupFromInput(const std::string input_file) override;
int setupConstraintsFromInput(const std::string input_file) override;

// timers
static Timer total_tm_;
Expand All @@ -166,7 +171,8 @@ class MGmol : public MGmolInterface
public:
Electrostatic* electrostat_;

MGmol(MPI_Comm comm, std::ostream& os);
MGmol(MPI_Comm comm, std::ostream& os, std::string input_filename,
std::string lrs_filename, std::string constraints_filename);

~MGmol() override;

Expand Down Expand Up @@ -274,10 +280,6 @@ class MGmol : public MGmolInterface
void set_forces(std::vector<std::vector<double>>& f);
int nions() { return ions_->getNumIons(); }
double getTotalEnergy();
void setup();
int setupLRs(const std::string input_file) override;
int setupFromInput(const std::string input_file) override;
int setupConstraintsFromInput(const std::string input_file) override;
void cleanup();
void geomOptimSetup();
void geomOptimQuench();
Expand Down
27 changes: 9 additions & 18 deletions src/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,35 +108,26 @@ int main(int argc, char** argv)
int ret = ct.checkOptions();
if (ret < 0) return ret;

unsigned ngpts[3] = { ct.ngpts_[0], ct.ngpts_[1], ct.ngpts_[2] };
double origin[3] = { ct.ox_, ct.oy_, ct.oz_ };
const double cell[3] = { ct.lx_, ct.ly_, ct.lz_ };
Mesh::setup(mmpi.commSpin(), ngpts, origin, cell, ct.lap_type);

mmpi.bcastGlobal(input_filename);
mmpi.bcastGlobal(lrs_filename);

// Enter main scope
{
MGmolInterface* mgmol;
if (ct.isLocMode())
mgmol = new MGmol<LocGridOrbitals>(global_comm, *MPIdata::sout);
mgmol = new MGmol<LocGridOrbitals>(global_comm, *MPIdata::sout,
input_filename, lrs_filename, constraints_filename);
else
mgmol
= new MGmol<ExtendedGridOrbitals>(global_comm, *MPIdata::sout);

unsigned ngpts[3] = { ct.ngpts_[0], ct.ngpts_[1], ct.ngpts_[2] };
double origin[3] = { ct.ox_, ct.oy_, ct.oz_ };
const double cell[3] = { ct.lx_, ct.ly_, ct.lz_ };
Mesh::setup(mmpi.commSpin(), ngpts, origin, cell, ct.lap_type);

mgmol->setupFromInput(input_filename);

if (ct.isLocMode() || ct.init_loc == 1) mgmol->setupLRs(lrs_filename);

mgmol->setupConstraintsFromInput(constraints_filename);

mgmol_setup();
mgmol = new MGmol<ExtendedGridOrbitals>(global_comm, *MPIdata::sout,
input_filename, lrs_filename, constraints_filename);

if (!tcheck)
{
mgmol->setup();

mgmol->run();
}
else
Expand Down
27 changes: 1 addition & 26 deletions src/mgmol_run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,38 +86,13 @@ int mgmol_init(MPI_Comm comm)
return 0;
}

int mgmol_setup()
int mgmol_check()
{
Control& ct = *(Control::instance());
Mesh* mymesh = Mesh::instance();
const pb::PEenv& myPEenv = mymesh->peenv();
MGmol_MPI& mmpi = *(MGmol_MPI::instance());

ct.checkNLrange();

LocGridOrbitals::setDotProduct(ct.dot_product_type);

if (!ct.short_sighted)
{
MatricesBlacsContext::instance().setup(mmpi.commSpin(), ct.numst);

dist_matrix::DistMatrix<DISTMATDTYPE>::setBlockSize(64);

dist_matrix::DistMatrix<DISTMATDTYPE>::setDefaultBlacsContext(
MatricesBlacsContext::instance().bcxt());

ReplicatedWorkSpace<double>::instance().setup(ct.numst);

dist_matrix::SparseDistMatrix<DISTMATDTYPE>::setNumTasksPerPartitioning(
128);

int npes = mmpi.size();
setSparseDistMatriConsolidationNumber(npes);
}
#ifdef HAVE_MAGMA
ReplicatedMatrix::setMPIcomm(mmpi.commSpin());
#endif

if (myPEenv.color() > 0)
{
std::cerr << "Code should be called with " << myPEenv.n_mpi_tasks()
Expand Down
2 changes: 1 addition & 1 deletion src/mgmol_run.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
// This file is part of MGmol. For details, see https://github.com/llnl/mgmol.
// Please also read this link https://github.com/llnl/mgmol/LICENSE
int mgmol_init(MPI_Comm comm);
int mgmol_setup();
int mgmol_check();
void mgmol_finalize();
36 changes: 36 additions & 0 deletions src/setup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
#include "LocGridOrbitals.h"
#include "MGmol.h"
#include "Potentials.h"
#include "ReplicatedWorkSpace.h"
#include "SparseDistMatrix.h"

#include "mgmol_run.h"

template <class OrbitalsType>
int MGmol<OrbitalsType>::setupFromInput(const std::string filename)
Expand Down Expand Up @@ -51,6 +55,36 @@ int MGmol<OrbitalsType>::setupFromInput(const std::string filename)
ct.setTolEnergy();
ct.setSpreadRadius();

ct.checkNLrange();

// now that we know the number of states, we can set a few other static
// data
if (!ct.short_sighted)
{
MatricesBlacsContext::instance().setup(mmpi.commSpin(), ct.numst);

dist_matrix::DistMatrix<DISTMATDTYPE>::setBlockSize(64);

dist_matrix::DistMatrix<DISTMATDTYPE>::setDefaultBlacsContext(
MatricesBlacsContext::instance().bcxt());

ReplicatedWorkSpace<double>::instance().setup(ct.numst);

dist_matrix::SparseDistMatrix<DISTMATDTYPE>::setNumTasksPerPartitioning(
128);

int npes = mmpi.size();
setSparseDistMatriConsolidationNumber(npes);
}

#ifdef HAVE_MAGMA
ReplicatedMatrix::setMPIcomm(mmpi.commSpin());
#endif

LocGridOrbitals::setDotProduct(ct.dot_product_type);

mgmol_check();

return 0;
}

Expand All @@ -59,6 +93,8 @@ int MGmol<OrbitalsType>::setupLRs(const std::string filename)
{
Control& ct = *(Control::instance());

if (!(ct.isLocMode() || ct.init_loc == 1)) return 0;

// create localization regions
Mesh* mymesh = Mesh::instance();
const pb::Grid& mygrid = mymesh->grid();
Expand Down