Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MSM optimizations #20

Merged
merged 12 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,25 @@ jobs:
- name: Build prover Android ARM64
run: |
mkdir -p build_prover_android && cd build_prover_android
cmake .. -DTARGET_PLATFORM=ANDROID -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android -DUSE_OPENMP=OFF
cmake .. -DTARGET_PLATFORM=ANDROID -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android -DBUILD_TESTS=OFF -DUSE_OPENMP=OFF
make -j4 && make install

- name: Build prover Android ARM64 with OpenMP
run: |
mkdir -p build_prover_android_openmp && cd build_prover_android_openmp
cmake .. -DTARGET_PLATFORM=ANDROID -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_openmp -DUSE_OPENMP=ON
cmake .. -DTARGET_PLATFORM=ANDROID -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_openmp -DBUILD_TESTS=OFF -DUSE_OPENMP=ON
make -j4 && make install

- name: Build prover Android x86_64
run: |
mkdir -p build_prover_android_x86_64 && cd build_prover_android_x86_64
cmake .. -DTARGET_PLATFORM=ANDROID_x86_64 -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_x86_64 -DUSE_OPENMP=OFF
cmake .. -DTARGET_PLATFORM=ANDROID_x86_64 -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_x86_64 -DBUILD_TESTS=OFF -DUSE_OPENMP=OFF
make -j4 && make install

- name: Build prover Android x86_64 with OpenMP
run: |
mkdir -p build_prover_android_openmp_x86_64 && cd build_prover_android_openmp_x86_64
cmake .. -DTARGET_PLATFORM=ANDROID_x86_64 -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_openmp_x86_64 -DUSE_OPENMP=ON
cmake .. -DTARGET_PLATFORM=ANDROID_x86_64 -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_openmp_x86_64 -DBUILD_TESTS=OFF -DUSE_OPENMP=ON
make -j4 && make install

- name: Build prover Linux
Expand Down Expand Up @@ -184,13 +184,13 @@ jobs:
if [[ ! -d "depends/gmp/package_macos_arm64" ]]; then ./build_gmp.sh macos_arm64; fi

mkdir -p build_prover_ios && cd build_prover_ios
cmake .. -GXcode -DTARGET_PLATFORM=IOS -DCMAKE_INSTALL_PREFIX=../package_ios
cmake .. -GXcode -DTARGET_PLATFORM=IOS -DCMAKE_INSTALL_PREFIX=../package_ios -DBUILD_TESTS=OFF
xcodebuild -destination 'generic/platform=iOS' -scheme rapidsnarkStatic -project rapidsnark.xcodeproj -configuration Release
cp ../depends/gmp/package_ios_arm64/lib/libgmp.a src/Release-iphoneos
cd ../

mkdir -p build_prover_ios_simulator && cd build_prover_ios_simulator
cmake .. -GXcode -DTARGET_PLATFORM=IOS -DCMAKE_INSTALL_PREFIX=../package_ios_simulator -DUSE_ASM=NO
mkdir -p build_prover_ios_simulator && cd build_prover_ios_simulator
cmake .. -GXcode -DTARGET_PLATFORM=IOS -DCMAKE_INSTALL_PREFIX=../package_ios_simulator -DUSE_ASM=NO -DBUILD_TESTS=OFF
xcodebuild -destination 'generic/platform=iOS Simulator' -scheme rapidsnarkStatic -project rapidsnark.xcodeproj
cp ../depends/gmp/package_iphone_simulator/lib/libgmp.a src/Debug-iphonesimulator
cd ../
Expand Down
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ project(rapidsnark LANGUAGES CXX C ASM)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

message("BITS_PER_CHUNK=" ${BITS_PER_CHUNK})
message("USE_ASM=" ${USE_ASM})
message("USE_OPENMP=" ${USE_OPENMP})
message("CMAKE_CROSSCOMPILING=" ${CMAKE_CROSSCOMPILING})
Expand Down
2 changes: 1 addition & 1 deletion depends/ffiasm
17 changes: 12 additions & 5 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ if(USE_ASM)
endif()
endif()

if(DEFINED BITS_PER_CHUNK)
add_definitions(-DMSM_BITS_PER_CHUNK=${BITS_PER_CHUNK})
endif()

if(USE_ASM AND ARCH MATCHES "x86_64")

if (CMAKE_HOST_SYSTEM_NAME MATCHES "Darwin")
Expand Down Expand Up @@ -131,12 +135,15 @@ if(USE_SODIUM)
target_link_libraries(prover sodium)
endif()

option(BUILD_TESTS "Build the tests" ON)

enable_testing()
add_executable(test_public_size test_public_size.c)
target_link_libraries(test_public_size rapidsnarkStaticFrFq)
add_test(NAME test_public_size COMMAND test_public_size circuit_final.zkey 86
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/testdata)
if(BUILD_TESTS)
enable_testing()
add_executable(test_public_size test_public_size.c)
target_link_libraries(test_public_size rapidsnarkStaticFrFq pthread)
add_test(NAME test_public_size COMMAND test_public_size circuit_final.zkey 86
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/testdata)
endif()

if(OpenMP_CXX_FOUND)

Expand Down
173 changes: 72 additions & 101 deletions src/groth16.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "random_generator.hpp"
#include "logging.hpp"
#include <future>
#include "misc.hpp"
#include <vector>
#include <mutex>

namespace Groth16 {

Expand Down Expand Up @@ -46,114 +48,84 @@ std::unique_ptr<Prover<Engine>> makeProver(
template <typename Engine>
std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement *wtns) {

#ifdef USE_OPENMP
ThreadPool &threadPool = ThreadPool::defaultPool();

LOG_TRACE("Start Multiexp A");
uint32_t sW = sizeof(wtns[0]);
typename Engine::G1Point pi_a;
E.g1.multiMulByScalar(pi_a, pointsA, (uint8_t *)wtns, sW, nVars);
E.g1.multiMulByScalarMSM(pi_a, pointsA, (uint8_t *)wtns, sW, nVars);
std::ostringstream ss2;
ss2 << "pi_a: " << E.g1.toString(pi_a);
LOG_DEBUG(ss2);

LOG_TRACE("Start Multiexp B1");
typename Engine::G1Point pib1;
E.g1.multiMulByScalar(pib1, pointsB1, (uint8_t *)wtns, sW, nVars);
E.g1.multiMulByScalarMSM(pib1, pointsB1, (uint8_t *)wtns, sW, nVars);
std::ostringstream ss3;
ss3 << "pib1: " << E.g1.toString(pib1);
LOG_DEBUG(ss3);

LOG_TRACE("Start Multiexp B2");
typename Engine::G2Point pi_b;
E.g2.multiMulByScalar(pi_b, pointsB2, (uint8_t *)wtns, sW, nVars);
E.g2.multiMulByScalarMSM(pi_b, pointsB2, (uint8_t *)wtns, sW, nVars);
std::ostringstream ss4;
ss4 << "pi_b: " << E.g2.toString(pi_b);
LOG_DEBUG(ss4);

LOG_TRACE("Start Multiexp C");
typename Engine::G1Point pi_c;
E.g1.multiMulByScalar(pi_c, pointsC, (uint8_t *)((uint64_t)wtns + (nPublic +1)*sW), sW, nVars-nPublic-1);
E.g1.multiMulByScalarMSM(pi_c, pointsC, (uint8_t *)((uint64_t)wtns + (nPublic +1)*sW), sW, nVars-nPublic-1);
std::ostringstream ss5;
ss5 << "pi_c: " << E.g1.toString(pi_c);
LOG_DEBUG(ss5);
#else
LOG_TRACE("Start Multiexp A");
uint32_t sW = sizeof(wtns[0]);
typename Engine::G1Point pi_a;
auto pA_future = std::async([&]() {
E.g1.multiMulByScalar(pi_a, pointsA, (uint8_t *)wtns, sW, nVars);
});

LOG_TRACE("Start Multiexp B1");
typename Engine::G1Point pib1;
auto pB1_future = std::async([&]() {
E.g1.multiMulByScalar(pib1, pointsB1, (uint8_t *)wtns, sW, nVars);
});

LOG_TRACE("Start Multiexp B2");
typename Engine::G2Point pi_b;
auto pB2_future = std::async([&]() {
E.g2.multiMulByScalar(pi_b, pointsB2, (uint8_t *)wtns, sW, nVars);
});

LOG_TRACE("Start Multiexp C");
typename Engine::G1Point pi_c;
auto pC_future = std::async([&]() {
E.g1.multiMulByScalar(pi_c, pointsC, (uint8_t *)((uint64_t)wtns + (nPublic +1)*sW), sW, nVars-nPublic-1);
});
#endif

LOG_TRACE("Start Initializing a b c A");
auto a = new typename Engine::FrElement[domainSize];
auto b = new typename Engine::FrElement[domainSize];
auto c = new typename Engine::FrElement[domainSize];

#pragma omp parallel for
for (u_int32_t i=0; i<domainSize; i++) {
E.fr.copy(a[i], E.fr.zero());
E.fr.copy(b[i], E.fr.zero());
}
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
for (u_int32_t i=begin; i<end; i++) {
E.fr.copy(a[i], E.fr.zero());
E.fr.copy(b[i], E.fr.zero());
}
});

LOG_TRACE("Processing coefs");
#ifdef _OPENMP
#define NLOCKS 1024
omp_lock_t locks[NLOCKS];
for (int i=0; i<NLOCKS; i++) omp_init_lock(&locks[i]);
#pragma omp parallel for
#endif
for (u_int64_t i=0; i<nCoefs; i++) {
typename Engine::FrElement *ab = (coefs[i].m == 0) ? a : b;
typename Engine::FrElement aux;

E.fr.mul(
aux,
wtns[coefs[i].s],
coefs[i].coef
);
#ifdef _OPENMP
omp_set_lock(&locks[coefs[i].c % NLOCKS]);
#endif
E.fr.add(
ab[coefs[i].c],
ab[coefs[i].c],
aux
);
#ifdef _OPENMP
omp_unset_lock(&locks[coefs[i].c % NLOCKS]);
#endif
}
#ifdef _OPENMP
for (int i=0; i<NLOCKS; i++) omp_destroy_lock(&locks[i]);
#endif

#define NLOCKS 1024
std::vector<std::mutex> locks(NLOCKS);

threadPool.parallelFor(0, nCoefs, [&] (int begin, int end, int numThread) {
for (u_int64_t i=begin; i<end; i++) {
typename Engine::FrElement *ab = (coefs[i].m == 0) ? a : b;
typename Engine::FrElement aux;

E.fr.mul(
aux,
wtns[coefs[i].s],
coefs[i].coef
);

std::lock_guard<std::mutex> guard(locks[coefs[i].c % NLOCKS]);

E.fr.add(
ab[coefs[i].c],
ab[coefs[i].c],
aux
);
}
});
LOG_TRACE("Calculating c");
#pragma omp parallel for
for (u_int32_t i=0; i<domainSize; i++) {
E.fr.mul(
c[i],
a[i],
b[i]
);
}
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(
c[i],
a[i],
b[i]
);
}
});

LOG_TRACE("Initializing fft");
u_int32_t domainPower = fft->log2(domainSize);
Expand All @@ -164,10 +136,13 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(a[0]).c_str());
LOG_DEBUG(E.fr.toString(a[1]).c_str());
LOG_TRACE("Start Shift A");
#pragma omp parallel for
for (u_int64_t i=0; i<domainSize; i++) {
E.fr.mul(a[i], a[i], fft->root(domainPower+1, i));
}

threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(a[i], a[i], fft->root(domainPower+1, i));
}
});

LOG_TRACE("a After shift:");
LOG_DEBUG(E.fr.toString(a[0]).c_str());
LOG_DEBUG(E.fr.toString(a[1]).c_str());
Expand All @@ -182,10 +157,11 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(b[0]).c_str());
LOG_DEBUG(E.fr.toString(b[1]).c_str());
LOG_TRACE("Start Shift B");
#pragma omp parallel for
for (u_int64_t i=0; i<domainSize; i++) {
E.fr.mul(b[i], b[i], fft->root(domainPower+1, i));
}
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(b[i], b[i], fft->root(domainPower+1, i));
}
});
LOG_TRACE("b After shift:");
LOG_DEBUG(E.fr.toString(b[0]).c_str());
LOG_DEBUG(E.fr.toString(b[1]).c_str());
Expand All @@ -201,10 +177,11 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(c[0]).c_str());
LOG_DEBUG(E.fr.toString(c[1]).c_str());
LOG_TRACE("Start Shift C");
#pragma omp parallel for
for (u_int64_t i=0; i<domainSize; i++) {
E.fr.mul(c[i], c[i], fft->root(domainPower+1, i));
}
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(c[i], c[i], fft->root(domainPower+1, i));
}
});
LOG_TRACE("c After shift:");
LOG_DEBUG(E.fr.toString(c[0]).c_str());
LOG_DEBUG(E.fr.toString(c[1]).c_str());
Expand All @@ -215,12 +192,13 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(c[1]).c_str());

LOG_TRACE("Start ABC");
#pragma omp parallel for
for (u_int64_t i=0; i<domainSize; i++) {
E.fr.mul(a[i], a[i], b[i]);
E.fr.sub(a[i], a[i], c[i]);
E.fr.fromMontgomery(a[i], a[i]);
}
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(a[i], a[i], b[i]);
E.fr.sub(a[i], a[i], c[i]);
E.fr.fromMontgomery(a[i], a[i]);
}
});
LOG_TRACE("abc:");
LOG_DEBUG(E.fr.toString(a[0]).c_str());
LOG_DEBUG(E.fr.toString(a[1]).c_str());
Expand All @@ -230,7 +208,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement

LOG_TRACE("Start Multiexp H");
typename Engine::G1Point pih;
E.g1.multiMulByScalar(pih, pointsH, (uint8_t *)a, sizeof(a[0]), domainSize);
E.g1.multiMulByScalarMSM(pih, pointsH, (uint8_t *)a, sizeof(a[0]), domainSize);
std::ostringstream ss1;
ss1 << "pih: " << E.g1.toString(pih);
LOG_DEBUG(ss1);
Expand All @@ -247,13 +225,6 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
randombytes_buf((void *)&(r.v[0]), sizeof(r)-1);
randombytes_buf((void *)&(s.v[0]), sizeof(s)-1);

#ifndef USE_OPENMP
pA_future.get();
pB1_future.get();
pB2_future.get();
pC_future.get();
#endif

typename Engine::G1Point p1;
typename Engine::G2Point p2;

Expand Down