Skip to content

Commit

Permalink
514 create a builder class that can make a backward euler solver (#527)
Browse files Browse the repository at this point in the history
* starting to piece together the builder

* hiding some template details behind a solver impl

* removing a lot of template template parameters

* removing more template templates, compiling with gcc

* chaning function name and adding a type alias

* adding state policy in progress

* update for new state policies

* getting a solver builder to work

* it compiles with clang on my machine

* making sure profiling option compiles

* attempting to get cuda to compile

* i think nvidia all compiles now

* starting to make the jit compiler a singleton

* trying to pull my old changes

* saving progress just in case since I've made some

* it compiles

* removing more lambdas

* suppresing openmp checks

* updating suppresion file

* addressing PR comments

* correcting implementation

* reverting readme example

* addressing PR comments

---------

Co-authored-by: Matt Dawson <mattdawson@ucar.edu>
  • Loading branch information
K20shores and mattldawson committed May 29, 2024
1 parent 0466a1f commit 369aede
Show file tree
Hide file tree
Showing 88 changed files with 1,749 additions and 2,283 deletions.
1 change: 0 additions & 1 deletion docker/Dockerfile.intel
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ RUN apt update \
ca-certificates \
cmake \
cmake-curses-gui \
curl \
libcurl4-openssl-dev \
libhdf5-dev \
m4 \
Expand Down
2 changes: 2 additions & 0 deletions docker/Dockerfile.llvm
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ RUN dnf -y update \
make \
zlib-devel \
llvm-devel \
openmpi-devel \
valgrind \
&& dnf clean all

Expand All @@ -24,6 +25,7 @@ RUN mkdir /build \
-D MICM_ENABLE_CLANG_TIDY:BOOL=FALSE \
-D MICM_ENABLE_LLVM:BOOL=TRUE \
-D MICM_ENABLE_MEMCHECK:BOOL=TRUE \
-D MICM_ENABLE_OPENMP:BOOL=TRUE \
../micm \
&& make install -j 8

Expand Down
21 changes: 3 additions & 18 deletions examples/profile_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
namespace fs = std::filesystem;
using namespace micm;

template<template<class> class MatrixType, template<class> class SparseMatrixType>
template<template<class> class MatrixType, class SparseMatrixType>
int Run(const char* filepath, const char* initial_conditions, const std::string& matrix_ordering_type)
{
using SolverType = RosenbrockSolver<MatrixType, SparseMatrixType>;
Expand Down Expand Up @@ -114,24 +114,9 @@ int Run(const char* filepath, const char* initial_conditions, const std::string&
return 0;
}

template<class T>
using SparseMatrixParam = micm::SparseMatrix<T>;
template<class T>
using Vector1MatrixParam = micm::VectorMatrix<T, 1>;
template<class T>
using Vector10MatrixParam = micm::VectorMatrix<T, 10>;
template<class T>
using Vector100MatrixParam = micm::VectorMatrix<T, 100>;
template<class T>
template<typename T>
using Vector1000MatrixParam = micm::VectorMatrix<T, 1000>;
template<class T>
using Vector1SparseMatrixParam = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<1>>;
template<class T>
using Vector10SparseMatrixParam = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<10>>;
template<class T>
using Vector100SparseMatrixParam = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<100>>;
template<class T>
using Vector1000SparseMatrixParam = micm::SparseMatrix<T, micm::SparseMatrixVectorOrdering<1000>>;
using Vector1000SparseMatrixParam = micm::SparseMatrix<double, micm::SparseMatrixVectorOrdering<1000>>;

int main(const int argc, const char* argv[])
{
Expand Down
99 changes: 60 additions & 39 deletions include/micm/jit/jit_compiler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
enum class MicmJitErrc
{
InvalidMatrix = MICM_JIT_ERROR_CODE_INVALID_MATRIX,
MissingJitFunction = MICM_JIT_ERROR_CODE_MISSING_JIT_FUNCTION
MissingJitFunction = MICM_JIT_ERROR_CODE_MISSING_JIT_FUNCTION,
FailedToBuild = MICM_JIT_ERROR_CODE_FAILED_TO_BUILD
};

namespace std
Expand Down Expand Up @@ -84,6 +85,7 @@ inline std::error_code make_error_code(MicmJitErrc e)
namespace micm
{

// a singleton class
class JitCompiler
{
private:
Expand All @@ -99,23 +101,23 @@ namespace micm
llvm::orc::JITDylib &main_lib_;

public:
JitCompiler(
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
llvm::orc::JITTargetMachineBuilder machine_builder,
llvm::DataLayout data_layout)
: execution_session_(std::move(execution_session)),
data_layout_(std::move(data_layout)),
mangle_(*this->execution_session_, this->data_layout_),
object_layer_(*this->execution_session_, []() { return std::make_unique<llvm::SectionMemoryManager>(); }),
compile_layer_(
*this->execution_session_,
object_layer_,
std::make_unique<llvm::orc::ConcurrentIRCompiler>(std::move(machine_builder))),
optimize_layer_(*this->execution_session_, compile_layer_, OptimizeModule),
main_lib_(this->execution_session_->createBareJITDylib("<main>"))
// Delete the copy constructor and assignment operator
JitCompiler(const JitCompiler &) = delete;
JitCompiler &operator=(const JitCompiler &) = delete;

static JitCompiler &GetInstance()
{
main_lib_.addGenerator(
llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(data_layout_.getGlobalPrefix())));
static std::unique_ptr<JitCompiler> instance;
if (!instance)
{
auto expectedInstance = Create();
if (!expectedInstance)
{
throw std::system_error(make_error_code(MicmJitErrc::FailedToBuild));
}
instance = std::move(*expectedInstance);
}
return *instance;
}

~JitCompiler()
Expand All @@ -124,28 +126,6 @@ namespace micm
execution_session_->reportError(std::move(Err));
}

static llvm::Expected<std::shared_ptr<JitCompiler>> Create()
{
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();

auto EPC = llvm::orc::SelfExecutorProcessControl::Create();
if (!EPC)
return EPC.takeError();

auto execution_session = std::make_unique<llvm::orc::ExecutionSession>(std::move(*EPC));

llvm::orc::JITTargetMachineBuilder machine_builder(execution_session->getExecutorProcessControl().getTargetTriple());

auto data_layout = machine_builder.getDefaultDataLayoutForTarget();
if (!data_layout)
return data_layout.takeError();

return std::make_shared<JitCompiler>(
std::move(execution_session), std::move(machine_builder), std::move(*data_layout));
}

const llvm::DataLayout &GetDataLayout() const
{
return data_layout_;
Expand All @@ -171,6 +151,47 @@ namespace micm
}

private:
static llvm::Expected<std::unique_ptr<JitCompiler>> Create()
{
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();

auto EPC = llvm::orc::SelfExecutorProcessControl::Create();
if (!EPC)
return EPC.takeError();

auto execution_session = std::make_unique<llvm::orc::ExecutionSession>(std::move(*EPC));

llvm::orc::JITTargetMachineBuilder machine_builder(execution_session->getExecutorProcessControl().getTargetTriple());

auto data_layout = machine_builder.getDefaultDataLayoutForTarget();
if (!data_layout)
return data_layout.takeError();

return llvm::Expected<std::unique_ptr<JitCompiler>>(std::unique_ptr<JitCompiler>(
new JitCompiler(std::move(execution_session), std::move(machine_builder), std::move(*data_layout))));
}

JitCompiler(
std::unique_ptr<llvm::orc::ExecutionSession> execution_session,
llvm::orc::JITTargetMachineBuilder machine_builder,
llvm::DataLayout data_layout)
: execution_session_(std::move(execution_session)),
data_layout_(std::move(data_layout)),
mangle_(*this->execution_session_, this->data_layout_),
object_layer_(*this->execution_session_, []() { return std::make_unique<llvm::SectionMemoryManager>(); }),
compile_layer_(
*this->execution_session_,
object_layer_,
std::make_unique<llvm::orc::ConcurrentIRCompiler>(std::move(machine_builder))),
optimize_layer_(*this->execution_session_, compile_layer_, OptimizeModule),
main_lib_(this->execution_session_->createBareJITDylib("<main>"))
{
main_lib_.addGenerator(
llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(data_layout_.getGlobalPrefix())));
}

static llvm::Expected<llvm::orc::ThreadSafeModule> OptimizeModule(
llvm::orc::ThreadSafeModule threadsafe_module,
const llvm::orc::MaterializationResponsibility &responsibility)
Expand Down
16 changes: 8 additions & 8 deletions include/micm/jit/jit_function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ namespace micm
{
bool generated_ = false;
std::string name_;
std::shared_ptr<JitCompiler> compiler_;
JitCompiler* compiler_;

public:
std::unique_ptr<llvm::LLVMContext> context_;
Expand All @@ -91,7 +91,7 @@ namespace micm
JitFunction() = delete;

friend class JitFunctionBuilder;
static JitFunctionBuilder Create(std::shared_ptr<JitCompiler> compiler);
static JitFunctionBuilder Create();
JitFunction(JitFunctionBuilder& function_builder);

/// @brief Generates the function
Expand Down Expand Up @@ -138,23 +138,23 @@ namespace micm

class JitFunctionBuilder
{
std::shared_ptr<JitCompiler> compiler_;
JitCompiler* compiler_;
std::string name_;
std::vector<std::pair<std::string, JitType>> arguments_;
JitType return_type_{ JitType::Void };
friend class JitFunction;

public:
JitFunctionBuilder() = delete;
JitFunctionBuilder(std::shared_ptr<JitCompiler> compiler);
JitFunctionBuilder(JitCompiler& compiler);
JitFunctionBuilder& SetName(const std::string& name);
JitFunctionBuilder& SetArguments(const std::vector<std::pair<std::string, JitType>>& arguments);
JitFunctionBuilder& SetReturnType(JitType type);
};

inline JitFunctionBuilder JitFunction::Create(std::shared_ptr<JitCompiler> compiler)
inline JitFunctionBuilder JitFunction::Create()
{
return JitFunctionBuilder{ compiler };
return JitFunctionBuilder{ JitCompiler::GetInstance() };
}

JitFunction::JitFunction(JitFunctionBuilder& function_builder)
Expand Down Expand Up @@ -296,8 +296,8 @@ namespace micm
return TmpB.CreateAlloca(type, 0, var_name.c_str());
}

inline JitFunctionBuilder::JitFunctionBuilder(std::shared_ptr<JitCompiler> compiler)
: compiler_(compiler){};
inline JitFunctionBuilder::JitFunctionBuilder(JitCompiler& compiler)
: compiler_(&compiler){};

inline JitFunctionBuilder& JitFunctionBuilder::SetName(const std::string& name)
{
Expand Down
42 changes: 21 additions & 21 deletions include/micm/process/cuda_process_set.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@ namespace micm
template<typename OrderingPolicy>
void SetJacobianFlatIds(const SparseMatrix<double, OrderingPolicy>& matrix);

template<template<class> typename MatrixPolicy>
requires(CudaMatrix<MatrixPolicy<double>>&& VectorizableDense<MatrixPolicy<double>>) void AddForcingTerms(
const MatrixPolicy<double>& rate_constants,
const MatrixPolicy<double>& state_variables,
MatrixPolicy<double>& forcing) const;

template<template<class> typename MatrixPolicy>
requires(!CudaMatrix<MatrixPolicy<double>>) void AddForcingTerms(
const MatrixPolicy<double>& rate_constants,
const MatrixPolicy<double>& state_variables,
MatrixPolicy<double>& forcing) const;
template<typename MatrixPolicy>
requires(CudaMatrix<MatrixPolicy>&& VectorizableDense<MatrixPolicy>) void AddForcingTerms(
const MatrixPolicy& rate_constants,
const MatrixPolicy& state_variables,
MatrixPolicy& forcing) const;

template<typename MatrixPolicy>
requires(!CudaMatrix<MatrixPolicy>) void AddForcingTerms(
const MatrixPolicy& rate_constants,
const MatrixPolicy& state_variables,
MatrixPolicy& forcing) const;

template<class MatrixPolicy, class SparseMatrixPolicy>
requires(
Expand Down Expand Up @@ -101,24 +101,24 @@ namespace micm
micm::cuda::CopyJacobiFlatId(hoststruct, this->devstruct_);
}

template<template<class> class MatrixPolicy>
requires(CudaMatrix<MatrixPolicy<double>>&& VectorizableDense<MatrixPolicy<double>>) inline void CudaProcessSet::
template<class MatrixPolicy>
requires(CudaMatrix<MatrixPolicy>&& VectorizableDense<MatrixPolicy>) inline void CudaProcessSet::
AddForcingTerms(
const MatrixPolicy<double>& rate_constants,
const MatrixPolicy<double>& state_variables,
MatrixPolicy<double>& forcing) const
const MatrixPolicy& rate_constants,
const MatrixPolicy& state_variables,
MatrixPolicy& forcing) const
{
auto forcing_param = forcing.AsDeviceParam(); // we need to update forcing so it can't be constant and must be an lvalue
micm::cuda::AddForcingTermsKernelDriver(
rate_constants.AsDeviceParam(), state_variables.AsDeviceParam(), forcing_param, this->devstruct_);
}

// call the function from the base class
template<template<class> class MatrixPolicy>
requires(!CudaMatrix<MatrixPolicy<double>>) inline void CudaProcessSet::AddForcingTerms(
const MatrixPolicy<double>& rate_constants,
const MatrixPolicy<double>& state_variables,
MatrixPolicy<double>& forcing) const
template<class MatrixPolicy>
requires(!CudaMatrix<MatrixPolicy>) inline void CudaProcessSet::AddForcingTerms(
const MatrixPolicy& rate_constants,
const MatrixPolicy& state_variables,
MatrixPolicy& forcing) const
{
AddForcingTerms(rate_constants, state_variables, forcing);
}
Expand Down
Loading

0 comments on commit 369aede

Please sign in to comment.