Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(compiler): distributed execution - on-demand key transfer to rem… #720

Merged
merged 1 commit into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ typedef struct FFT {
FFT() = delete;
FFT(size_t polynomial_size);
FFT(FFT &other) = delete;
FFT(const FFT &other) = delete;
FFT(FFT &&other);
~FFT();

Expand All @@ -42,7 +43,7 @@ typedef struct RuntimeContext {

RuntimeContext() = delete;
RuntimeContext(ServerKeyset serverKeyset);
~RuntimeContext() {
virtual ~RuntimeContext() {
#ifdef CONCRETELANG_CUDA_SUPPORT
for (int i = 0; i < num_devices; ++i) {
if (bsk_gpu[i] != nullptr)
Expand All @@ -53,27 +54,30 @@ typedef struct RuntimeContext {
#endif
};

const uint64_t *keyswitch_key_buffer(size_t keyId) {
virtual const uint64_t *keyswitch_key_buffer(size_t keyId) {
return serverKeyset.lweKeyswitchKeys[keyId].getBuffer().data();
}

const std::complex<double> *fourier_bootstrap_key_buffer(size_t keyId) {
virtual const std::complex<double> *
fourier_bootstrap_key_buffer(size_t keyId) {
return fourier_bootstrap_keys[keyId]->data();
}

const uint64_t *fp_keyswitch_key_buffer(size_t keyId) {
virtual const uint64_t *fp_keyswitch_key_buffer(size_t keyId) {
return serverKeyset.packingKeyswitchKeys[keyId].getRawPtr();
}

const struct Fft *fft(size_t keyId) { return ffts[keyId].fft; }
virtual const struct Fft *fft(size_t keyId) { return ffts[keyId].fft; }

const ServerKeyset getKeys() const { return serverKeyset; }

private:
protected:
ServerKeyset serverKeyset;
std::vector<std::shared_ptr<std::vector<std::complex<double>>>>
fourier_bootstrap_keys;
std::vector<FFT> ffts;
std::pair<FFT, std::shared_ptr<std::vector<std::complex<double>>>>
convert_to_fourier_domain(LweBootstrapKey &bsk);

#ifdef CONCRETELANG_CUDA_SUPPORT
public:
Expand Down Expand Up @@ -144,6 +148,24 @@ typedef struct RuntimeContext {
#endif
} RuntimeContext;

// RuntimeContext variant that overrides every virtual key/FFT accessor of
// the base class. Only the declarations live here; the definitions are
// out of line.
// NOTE(review): presumably used on non-root nodes to obtain evaluation keys
// on demand and keep them in the maps below — confirm against the
// out-of-line definitions.
struct DistributedRuntimeContext : public RuntimeContext {

// Reuse the base-class constructors unchanged.
using RuntimeContext::RuntimeContext;
// Overrides of the base accessors; each takes the id of the requested key.
const uint64_t *keyswitch_key_buffer(size_t keyId) override;
const std::complex<double> *
fourier_bootstrap_key_buffer(size_t keyId) override;
const uint64_t *fp_keyswitch_key_buffer(size_t keyId) override;
const struct Fft *fft(size_t keyId) override;

private:
// NOTE(review): name suggests it makes the bootstrap key `keyId` available
// on the current node — definition not visible here, confirm.
void getBSKonNode(size_t keyId);
// NOTE(review): presumably serialises access to the maps below — confirm.
std::mutex cm_guard;
// Per-key-id storage, one map per kind of key material.
std::map<size_t, LweKeyswitchKey> ksks;
std::map<size_t, std::shared_ptr<std::vector<std::complex<double>>>> fbks;
std::map<size_t, FFT> dffts;
std::map<size_t, PackingKeyswitchKey> pksks;
};

} // namespace concretelang
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ namespace concretelang {
namespace dfr {

struct RuntimeContextManager;
namespace {
static void *dl_handle;
static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
} // namespace
extern RuntimeContextManager *_dfr_node_level_runtime_context_manager;

template <typename LweKeyType> struct KeyWrapper {
std::vector<LweKeyType> keys;
Expand Down Expand Up @@ -109,21 +106,40 @@ struct RuntimeContextManager {
// TODO: this is only ok so long as we don't change keys. Once we
// use multiple keys, should have a map.
RuntimeContext *context;
bool allocated = false;
bool lazy_key_transfer = false;

RuntimeContextManager() {
RuntimeContextManager(bool lazy = false) : lazy_key_transfer(lazy) {
context = nullptr;
_dfr_node_level_runtime_context_manager = this;
}

void setContext(void *ctx) {
assert(context == nullptr &&
"Only one RuntimeContext can be used at a time.");
context = (RuntimeContext *)ctx;

if (lazy_key_transfer) {
if (!_dfr_is_root_node()) {
context =
new mlir::concretelang::DistributedRuntimeContext(ServerKeyset());
allocated = true;
}
return;
}

// When the root node does not require a context, we still need to
// broadcast an empty keyset to the remote nodes: they cannot know this
// ahead of time and therefore cannot skip waiting for the broadcast.
// Instantiate an empty context for this purpose.
if (_dfr_is_root_node() && ctx == nullptr) {
context = new mlir::concretelang::RuntimeContext(ServerKeyset());
allocated = true;
}

// Root node broadcasts the evaluation keys and each remote
// instantiates a local RuntimeContext.
if (_dfr_is_root_node()) {
RuntimeContext *context = (RuntimeContext *)ctx;

KeyWrapper<LweKeyswitchKey> kskw(context->getKeys().lweKeyswitchKeys);
KeyWrapper<LweBootstrapKey> bskw(context->getKeys().lweBootstrapKeys);
KeyWrapper<PackingKeyswitchKey> pkskw(
Expand Down Expand Up @@ -153,12 +169,23 @@ struct RuntimeContextManager {

// Drops the current context and resets the manager so that a new context
// can be installed with setContext().
//
// Ownership: on the root node the context is normally supplied by the
// caller, so it is freed only when this manager allocated it itself
// (`allocated`); on every other node the context is always freed here.
void clearContext() {
  // Single, guarded delete: the context must never be released twice.
  if (context != nullptr && (!_dfr_is_root_node() || allocated))
    delete context;
  context = nullptr;
  // Forget the ownership flag; otherwise a caller-owned context installed
  // later on the root node would be deleted by the next clearContext().
  allocated = false;
}
};

// Key accessors: each takes a key id and returns the corresponding key type
// wrapped in a KeyWrapper. Only declared here; defined elsewhere.
// NOTE(review): presumably executed on the node that holds the keyset to
// serve keys to other nodes — confirm against the definitions.
KeyWrapper<LweKeyswitchKey> getKsk(size_t keyId);
KeyWrapper<LweBootstrapKey> getBsk(size_t keyId);
KeyWrapper<PackingKeyswitchKey> getPKsk(size_t keyId);

// Wrap the three functions above as HPX plain actions (_get_ksk_action,
// _get_bsk_action, _get_pksk_action) so they can be invoked through HPX.
HPX_DEFINE_PLAIN_ACTION(getKsk, _get_ksk_action);
HPX_DEFINE_PLAIN_ACTION(getBsk, _get_bsk_action);
HPX_DEFINE_PLAIN_ACTION(getPKsk, _get_pksk_action);

} // namespace dfr
} // namespace concretelang
} // namespace mlir

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ namespace dfr {

struct WorkFunctionRegistry;
namespace {
static void *dl_handle;
static WorkFunctionRegistry *_dfr_node_level_work_function_registry;
}
} // namespace

struct WorkFunctionRegistry {
WorkFunctionRegistry() { _dfr_node_level_work_function_registry = this; }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
add_compile_options(-fsized-deallocation)

if(CONCRETELANG_CUDA_SUPPORT)
add_library(ConcretelangRuntime SHARED context.cpp simulation.cpp wrappers.cpp DFRuntime.cpp GPUDFG.cpp)
add_library(ConcretelangRuntime SHARED context.cpp simulation.cpp wrappers.cpp DFRuntime.cpp key_manager.cpp
GPUDFG.cpp)
target_link_libraries(ConcretelangRuntime PRIVATE hwloc)
else()
add_library(ConcretelangRuntime SHARED context.cpp simulation.cpp wrappers.cpp DFRuntime.cpp StreamEmulator.cpp)
add_library(ConcretelangRuntime SHARED context.cpp simulation.cpp wrappers.cpp DFRuntime.cpp key_manager.cpp
StreamEmulator.cpp)
endif()

add_dependencies(ConcretelangRuntime concrete_cpu concrete_cpu_noise_model concrete-protocol)
Expand Down
19 changes: 13 additions & 6 deletions compilers/concrete-compiler/compiler/lib/Runtime/DFRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,16 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
num_nodes = hpx::get_num_localities().get();

new WorkFunctionRegistry();
new RuntimeContextManager();

char *env = getenv("DFR_LAZY_KEY_TRANSFER");
bool lazy = false;
if (env != nullptr)
if (!strncmp(env, "True", 4) || !strncmp(env, "true", 4) ||
!strncmp(env, "On", 2) || !strncmp(env, "on", 2) ||
!strncmp(env, "1", 1))
lazy = true;
new RuntimeContextManager(lazy);

_dfr_jit_phase_barrier = new hpx::distributed::barrier(
"phase_barrier", num_nodes, hpx::get_locality_id());
_dfr_startup_barrier = new hpx::distributed::barrier(
Expand Down Expand Up @@ -351,14 +360,12 @@ void _dfr_start(int64_t use_dfr_p, void *ctx) {

assert(init_guard == active && "DFR runtime failed to initialise");

// If DFR is used and a runtime context is needed, and execution is
// distributed, then broadcast from root to all compute nodes.
if (num_nodes > 1 && (ctx || !_dfr_is_root_node())) {
// If execution is distributed, then broadcast (possibly an empty)
// context from root to all compute nodes.
if (num_nodes > 1) {
BEGIN_TIME(&broadcast_timer);
_dfr_node_level_runtime_context_manager->setContext(ctx);
}
// If this is not JIT, then the remote nodes never reach _dfr_stop,
// so root should not instantiate this barrier.
if (_dfr_is_root_node())
_dfr_startup_barrier->wait();

Expand Down
Loading
Loading