Skip to content

Commit

Permalink
test(compiler): make benchmarks framework compatible with distributed…
Browse files Browse the repository at this point in the history
… execution.
  • Loading branch information
antoniupop committed Jan 5, 2024
1 parent 2bd2eca commit 920fe72
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "../end_to_end_tests/end_to_end_test.h"
#include "concretelang/Common/Compat.h"
#include "concretelang/TestLib/TestCircuit.h"
#include <concretelang/Runtime/DFRuntime.hpp>

#include <benchmark/benchmark.h>
#include <filesystem>
Expand Down Expand Up @@ -51,11 +52,14 @@ static void BM_ExportArguments(benchmark::State &state,
auto test = description.tests[0];
auto inputArguments = std::vector<TransportValue>();
inputArguments.reserve(test.inputs.size());

auto client = tc.getClientCircuit().value();
for (auto _ : state) {
for (size_t i = 0; i < test.inputs.size(); i++) {
auto input = client.prepareInput(test.inputs[i].getValue(), i).value();
inputArguments.push_back(input);
if (mlir::concretelang::dfr::_dfr_is_root_node()) {
for (auto _ : state) {
for (size_t i = 0; i < test.inputs.size(); i++) {
auto input = client.prepareInput(test.inputs[i].getValue(), i).value();
inputArguments.push_back(input);
}
}
inputArguments.resize(0);
}
Expand All @@ -74,10 +78,12 @@ static void BM_Evaluate(benchmark::State &state, EndToEndDesc description,
auto inputArguments = std::vector<TransportValue>();
inputArguments.reserve(test.inputs.size());

for (size_t i = 0; i < test.inputs.size(); i++) {
auto input =
clientCircuit.prepareInput(test.inputs[i].getValue(), i).value();
inputArguments.push_back(input);
if (mlir::concretelang::dfr::_dfr_is_root_node()) {
for (size_t i = 0; i < test.inputs.size(); i++) {
auto input =
clientCircuit.prepareInput(test.inputs[i].getValue(), i).value();
inputArguments.push_back(input);
}
}

auto serverCircuit = tc.getServerCircuit().value();
Expand All @@ -101,7 +107,8 @@ void registerEndToEndBenchmark(std::string suiteName,
std::vector<EndToEndDesc> descriptions,
mlir::concretelang::CompilationOptions options,
std::vector<enum Action> actions,
size_t stackSizeRequirement = 0) {
size_t stackSizeRequirement = 0,
int num_iterations = 0) {
auto optionsName = getOptionsName(options);
for (auto description : descriptions) {
options.mainFuncName = "main";
Expand Down Expand Up @@ -137,10 +144,12 @@ void registerEndToEndBenchmark(std::string suiteName,
});
break;
case Action::EVALUATE:
benchmark::RegisterBenchmark(benchName("evaluate").c_str(),
[=](::benchmark::State &st) {
BM_Evaluate(st, description, options);
});
auto bench = benchmark::RegisterBenchmark(
benchName("evaluate").c_str(), [=](::benchmark::State &st) {
BM_Evaluate(st, description, options);
});
if (num_iterations)
bench->Iterations(num_iterations);
break;
}
}
Expand Down Expand Up @@ -180,9 +189,10 @@ int main(int argc, char **argv) {
auto suiteName = llvm::sys::path::stem(descFile.path).str();
registerEndToEndBenchmark(suiteName, descFile.descriptions,
std::get<0>(options).compilationOptions, actions,
stackSizeRequirement);
stackSizeRequirement, std::get<0>(options).numIterations);
}
::benchmark::RunSpecifiedBenchmarks();
::benchmark::Shutdown();
_dfr_terminate();
return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const double TEST_ERROR_RATE = 1.0 - 0.999936657516;
typedef struct EndToEndTestOptions {
mlir::concretelang::CompilationOptions compilationOptions;
int numberOfRetry;
int numIterations;
} EndToEndTestOptions;

/// @brief Parse the command line and return a tuple contains the compilation
Expand Down Expand Up @@ -75,6 +76,18 @@ parseEndToEndCommandLine(int argc, char **argv) {
llvm::cl::desc("Enable the compression of input ciphertext"),
llvm::cl::init(false));

llvm::cl::opt<bool> distBenchmark(
"distributed",
llvm::cl::desc("Force a constant number of iterations in the benchmark "
"suite as required for distributed execution (default: 1 "
"- use --iterations=<n> to change)"),
llvm::cl::init(false));
llvm::cl::opt<int> numIterations(
"iterations",
llvm::cl::desc("Set the number of iterations for the benchmark suite "
"(only to be used with --distributed)"),
llvm::cl::init(1));

// Optimizer options
llvm::cl::opt<int> securityLevel(
"security-level",
Expand Down Expand Up @@ -143,10 +156,13 @@ parseEndToEndCommandLine(int argc, char **argv) {
f.descriptions = loadEndToEndDesc(descFile);
parsedDescriptionFiles.push_back(f);
}
int num_iterations =
(distBenchmark.getValue()) ? numIterations.getValue() : 0;
return std::make_pair(
EndToEndTestOptions{
compilationOptions,
retryFailingTests.getValue(),
num_iterations,
},
parsedDescriptionFiles);
}
Expand Down

0 comments on commit 920fe72

Please sign in to comment.