Skip to content

Commit

Permalink
[impl] adding sumcheck-cuda (#131)
Browse files Browse the repository at this point in the history
  • Loading branch information
slzk authored Oct 28, 2024
1 parent 89b7452 commit a08d51a
Show file tree
Hide file tree
Showing 53 changed files with 6,536 additions and 0 deletions.
15 changes: 15 additions & 0 deletions sumcheck/cuda/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# IDE
.idea
.vscode

# CMake
.cmake
CMakeCache.txt
CMakeFiles
cmake_install.cmake
*.bin
Testing

# Artifact
*.log
*.csv
661 changes: 661 additions & 0 deletions sumcheck/cuda/LICENSE

Large diffs are not rendered by default.

47 changes: 47 additions & 0 deletions sumcheck/cuda/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Build script for the sumcheck CUDA binary.
#
# Usage:
#   make                        build $(TARGET) with the default field
#   make USE_FIELD=useBN254     build with another field (passed to nvcc as -D<field>)
#   make clean                  remove the built binary

# Compiler and flags
NVCC := nvcc
NVCC_FLAGS := -O3 -arch=native -std=c++17
INCLUDE_FLAGS := -Iinclude -Iicicle

# Target executable
TARGET := sumcheck.bin

# Source file
SRC := src/sumcheck_cuda.cu

# Default field to use
USE_FIELD ?= useM31ext3

# Set the minimum required version; its major/minor parts are derived from it
# so the check and the error message can never disagree.
NVCC_MIN_VERSION := 12.5
NVCC_MIN_MAJOR := $(word 1,$(subst ., ,$(NVCC_MIN_VERSION)))
NVCC_MIN_MINOR := $(word 2,$(subst ., ,$(NVCC_MIN_VERSION)))

# Get NVCC version (left empty when nvcc is missing or its output cannot be parsed)
NVCC_VERSION := $(shell $(NVCC) --version 2>/dev/null | grep -oP 'release \K[0-9]+\.[0-9]+' | head -n 1)
NVCC_MAJOR := $(word 1,$(subst ., ,$(NVCC_VERSION)))
NVCC_MINOR := $(word 2,$(subst ., ,$(NVCC_VERSION)))

# Phony targets
.PHONY: all check_nvcc clean

# Default target
all: check_nvcc $(TARGET)

# Check NVCC version.
# Fails when nvcc is absent, when its version cannot be determined, or when it
# is older than NVCC_MIN_VERSION. The explicit empty-version branch matters:
# without it the numeric tests would error out and fall through to "sufficient".
check_nvcc:
	@if ! command -v $(NVCC) > /dev/null 2>&1; then \
		echo "Error: $(NVCC) is not installed."; \
		exit 1; \
	elif [ -z "$(NVCC_MAJOR)" ] || [ -z "$(NVCC_MINOR)" ]; then \
		echo "Error: could not determine the $(NVCC) version (need >= $(NVCC_MIN_VERSION))."; \
		exit 1; \
	elif [ "$(NVCC_MAJOR)" -lt "$(NVCC_MIN_MAJOR)" ] || { [ "$(NVCC_MAJOR)" -eq "$(NVCC_MIN_MAJOR)" ] && [ "$(NVCC_MINOR)" -lt "$(NVCC_MIN_MINOR)" ]; }; then \
		echo "Error: $(NVCC) version must be >= $(NVCC_MIN_VERSION). Current version: $(NVCC_VERSION)"; \
		exit 1; \
	else \
		echo "$(NVCC) version $(NVCC_VERSION) is sufficient."; \
	fi

# Build target
$(TARGET): $(SRC)
	$(NVCC) $(NVCC_FLAGS) $(INCLUDE_FLAGS) -D$(USE_FIELD) -o $@ $<

# Clean build
clean:
	rm -f $(TARGET)
52 changes: 52 additions & 0 deletions sumcheck/cuda/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Sumcheck GPU Acceleration

This project implements GPU acceleration for the sumcheck protocol. The core computation leverages CUDA, and users can choose between CPU and GPU modes for computation. Field operations over BN254 and the M31 extension field (`M31ext3`) are supported; `M31ext3` is the default field.

## Installation

Make sure you have CUDA installed on your system. The `Makefile` checks for `nvcc` version 12.5 or newer and aborts the build otherwise.

### Compile the Project

To compile the project, simply run:

```bash
make clean && make
```

This will clean any existing binaries and generate a new one: `sumcheck.bin`.

## Usage

To run the program, use the following syntax:

```bash
./sumcheck.bin -m [cpu|gpu] -p [2^(size) of circuit] [-v]
```

For example, to run a sumcheck of size 2^23 on the GPU, use:

```bash
./sumcheck.bin -m gpu -p 23
```

### Options:
- `-m [cpu|gpu]`: Choose the computation mode. Default is `cpu`.
- `-p [circuit size]`: Specify the circuit size as a power-of-two exponent, i.e. `-p 23` means a circuit of size 2^23. Default is 20.
- `-v`: Enable verbose mode for detailed output.

## Field Support

The project supports different field operations based on compile-time flags:
- **BN254**: We use Ingonyama's Icicle as the underlying implementation for BN254 field operations.
- **M31ext3**: Default mode uses M31ext3 extension field.

To switch between fields, set the `USE_FIELD` variable, either by editing its default in the `Makefile` or by overriding it on the `make` command line. For example, to use BN254:

```bash
make clean && make USE_FIELD=useBN254
```

## Acknowledgments

We would like to express our sincere thanks to Ingonyama for providing the [Icicle framework](https://github.com/ingonyama-zk/icicle), which is used as the underlying implementation for BN254 field operations.
47 changes: 47 additions & 0 deletions sumcheck/cuda/icicle/curves/affine.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#pragma once

#include "gpu-utils/sharedmem.cuh"
#include "gpu-utils/modifiers.cuh"
#include <iostream>

/**
 * Point in affine form: a pair of coordinates (x, y), each an element of the field type FF.
 *
 * All operations are static helpers or friends; the class stores nothing besides the two coordinates.
 * FF must supply `neg`, `zero`, `to_montgomery`, `from_montgomery`, `operator==` and stream insertion.
 */
template <class FF>
class Affine
{
public:
  FF x;
  FF y;

  /** Returns a copy of `point` with y replaced by `FF::neg(y)`; x is left unchanged. */
  static HOST_DEVICE_INLINE Affine neg(const Affine& point)
  {
    return Affine{point.x, FF::neg(point.y)};
  }

  /** Returns the point whose two coordinates are both `FF::zero()`. */
  static HOST_DEVICE_INLINE Affine zero()
  {
    return Affine{FF::zero(), FF::zero()};
  }

  /** Applies `FF::to_montgomery` to each coordinate of `point`. */
  static HOST_DEVICE_INLINE Affine to_montgomery(const Affine& point)
  {
    return Affine{FF::to_montgomery(point.x), FF::to_montgomery(point.y)};
  }

  /** Applies `FF::from_montgomery` to each coordinate of `point`. */
  static HOST_DEVICE_INLINE Affine from_montgomery(const Affine& point)
  {
    return Affine{FF::from_montgomery(point.x), FF::from_montgomery(point.y)};
  }

  /** Coordinate-wise equality: true only when both x and y compare equal. */
  friend HOST_DEVICE_INLINE bool operator==(const Affine& xs, const Affine& ys)
  {
    // x is compared first; y is only inspected when the x coordinates agree.
    if (!(xs.x == ys.x)) { return false; }
    return xs.y == ys.y;
  }

  /** Host-only pretty printer; writes the point as "x: <x>; y: <y>". */
  friend HOST_INLINE std::ostream& operator<<(std::ostream& os, const Affine& point)
  {
    os << "x: " << point.x
       << "; y: " << point.y;
    return os;
  }
};

/**
 * Specialization of the SharedMemory helper for arrays of Affine<FF> points.
 *
 * getPointer() hands device code a pointer to the block's `extern __shared__` array, typed as Affine<FF>.
 * The array is declared without a size here; with `extern __shared__` the size comes from the dynamic
 * shared-memory amount supplied when the kernel is launched.
 */
template <class FF>
struct SharedMemory<Affine<FF>> {
// Device-only accessor: returns the start of the dynamically sized shared-memory array.
__device__ Affine<FF>* getPointer()
{
// Unsized declaration: no storage is reserved by this line itself.
extern __shared__ Affine<FF> s_affine_[];
return s_affine_;
}
};
34 changes: 34 additions & 0 deletions sumcheck/cuda/icicle/curves/curve_config.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once
#ifndef CURVE_CONFIG_H
#define CURVE_CONFIG_H

#include "fields/id.h"
#include "curves/projective.cuh"

/**
 * @namespace curve_config
 * Namespace alias for the type definitions of one short Weierstrass pairing-friendly [elliptic
 * curve](https://hyperelliptic.org/EFD/g1p/auto-shortw.html). The curve is selected at compile time by comparing
 * the `CURVE_ID` preprocessor macro against the curve identifiers below (presumably defined in "fields/id.h",
 * with `CURVE_ID` derived from the `-DCURVE` build option -- TODO confirm).
 *
 * NOTE(review): the chain has no `#else` branch, so when `CURVE_ID` matches none of the identifiers no
 * `curve_config` alias is declared and no curve parameter header is included.
 */
#if CURVE_ID == BN254
#include "curves/params/bn254.cuh"
namespace curve_config = bn254;

#elif CURVE_ID == BLS12_381
#include "curves/params/bls12_381.cuh"
namespace curve_config = bls12_381;

#elif CURVE_ID == BLS12_377
#include "curves/params/bls12_377.cuh"
namespace curve_config = bls12_377;

#elif CURVE_ID == BW6_761
#include "curves/params/bw6_761.cuh"
namespace curve_config = bw6_761;

#elif CURVE_ID == GRUMPKIN
#include "curves/params/grumpkin.cuh"
namespace curve_config = grumpkin;
#endif
#endif
42 changes: 42 additions & 0 deletions sumcheck/cuda/icicle/curves/macro.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#pragma once
#ifndef CURVE_MACRO_H
#define CURVE_MACRO_H

#define CURVE_DEFINITIONS \
  /* Base field of the G1 curve (always a prime field): Field instantiated with fq_config. */ \
  using point_field_t = Field<fq_config>; \
  \
  /* Generator coordinates and Weierstrass coefficient b, wrapped from the limb constants */ \
  /* g1_gen_x, g1_gen_y and weierstrass_b that must be visible where the macro is expanded. */ \
  static constexpr point_field_t generator_x = point_field_t{g1_gen_x}; \
  static constexpr point_field_t generator_y = point_field_t{g1_gen_y}; \
  static constexpr point_field_t b = point_field_t{weierstrass_b}; \
  \
  /* [Projective representation](https://hyperelliptic.org/EFD/g1p/auto-shortw-projective.html) */ \
  /* of the G1 curve: three coordinates of type point_field_t, scalars of type scalar_t. */ \
  using projective_t = Projective<point_field_t, scalar_t, b, generator_x, generator_y>; \
  \
  /* Affine representation of the G1 curve: two coordinates of type point_field_t. */ \
  using affine_t = Affine<point_field_t>;

#define G2_CURVE_DEFINITIONS \
  /* Coordinate field of the G2 curve: ExtensionField over point_field_t, parameterised by fq_config. */ \
  using g2_point_field_t = ExtensionField<fq_config, point_field_t>; \
  \
  /* G2 generator coordinates and Weierstrass coefficient, each assembled from a (re, im) pair of */ \
  /* limb constants that must be visible where the macro is expanded. */ \
  static constexpr g2_point_field_t g2_generator_x = \
    g2_point_field_t{point_field_t{g2_gen_x_re}, point_field_t{g2_gen_x_im}}; \
  static constexpr g2_point_field_t g2_generator_y = \
    g2_point_field_t{point_field_t{g2_gen_y_re}, point_field_t{g2_gen_y_im}}; \
  static constexpr g2_point_field_t g2_b = \
    g2_point_field_t{point_field_t{weierstrass_b_g2_re}, point_field_t{weierstrass_b_g2_im}}; \
  \
  /* [Projective representation](https://hyperelliptic.org/EFD/g1p/auto-shortw-projective.html) */ \
  /* of the G2 curve. */ \
  using g2_projective_t = Projective<g2_point_field_t, scalar_t, g2_b, g2_generator_x, g2_generator_y>; \
  \
  /* Affine representation of the G2 curve: two coordinates of type g2_point_field_t. */ \
  using g2_affine_t = Affine<g2_point_field_t>;

#endif
48 changes: 48 additions & 0 deletions sumcheck/cuda/icicle/curves/params/bls12_377.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#pragma once
#ifndef BLS12_377_PARAMS_H
#define BLS12_377_PARAMS_H

#include "fields/storage.cuh"

#include "curves/macro.h"
#include "curves/projective.cuh"
#include "fields/snark_fields/bls12_377_base.cuh"
#include "fields/snark_fields/bls12_377_scalar.cuh"
#include "fields/quadratic_extension.cuh"

/**
 * Curve parameters for BLS12-377.
 *
 * Every constant is a `storage<fq_config::limbs_count>` limb array (twelve 32-bit words here). The names are the
 * ones the CURVE_DEFINITIONS and G2_CURVE_DEFINITIONS macros from "curves/macro.h" read when they are expanded at
 * the end of this namespace.
 * NOTE(review): limbs appear to be listed least-significant word first -- confirm against storage.cuh.
 */
namespace bls12_377 {
// G1 and G2 generators
// G1 generator coordinates (x, y).
static constexpr storage<fq_config::limbs_count> g1_gen_x = {0xb21be9ef, 0xeab9b16e, 0xffcd394e, 0xd5481512,
0xbd37cb5c, 0x188282c8, 0xaa9d41bb, 0x85951e2c,
0xbf87ff54, 0xc8fc6225, 0xfe740a67, 0x008848de};
static constexpr storage<fq_config::limbs_count> g1_gen_y = {0x559c8ea6, 0xfd82de55, 0x34a9591a, 0xc2fe3d36,
0x4fb82305, 0x6d182ad4, 0xca3e52d9, 0xbd7fb348,
0x30afeec4, 0x1f674f5d, 0xc5102eff, 0x01914a69};
// G2 generator coordinates; each coordinate is split into a `_re` and an `_im` component.
static constexpr storage<fq_config::limbs_count> g2_gen_x_re = {0x7c005196, 0x74e3e48f, 0xbb535402, 0x71889f52,
0x57db6b9b, 0x7ea501f5, 0x203e5031, 0xc565f071,
0xa3841d01, 0xc89630a2, 0x71c785fe, 0x018480be};
static constexpr storage<fq_config::limbs_count> g2_gen_x_im = {0x6ea16afe, 0xb26bfefa, 0xbff76fe6, 0x5cf89984,
0x0799c9de, 0xe7223ece, 0x6651cecb, 0x532777ee,
0xb1b140d5, 0x70dc5a51, 0xe7004031, 0x00ea6040};
static constexpr storage<fq_config::limbs_count> g2_gen_y_re = {0x09fd4ddf, 0xf0940944, 0x6d8c7c2e, 0xf2cf8888,
0xf832d204, 0xe458c282, 0x74b49a58, 0xde03ed72,
0xcbb2efb4, 0xd960736b, 0x5d446f7b, 0x00690d66};
static constexpr storage<fq_config::limbs_count> g2_gen_y_im = {0x85eb8f93, 0xd9a1cdd1, 0x5e52270b, 0x4279b83f,
0xcee304c2, 0x2463b01a, 0x3d591bf1, 0x61ef11ac,
0x151a70aa, 0x9e549da3, 0xd2835518, 0x00f8169f};

// Weierstrass coefficient b for G1, followed by the (re, im) components of the G2 coefficient.
static constexpr storage<fq_config::limbs_count> weierstrass_b = {0x00000001, 0x00000000, 0x00000000, 0x00000000,
0x00000000, 0x00000000, 0x00000000, 0x00000000,
0x00000000, 0x00000000, 0x00000000, 0x00000000};
static constexpr storage<fq_config::limbs_count> weierstrass_b_g2_re = {
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000};
static constexpr storage<fq_config::limbs_count> weierstrass_b_g2_im = {
0x9999999a, 0x1c9ed999, 0x1ccccccd, 0x0dd39e5c, 0x3c6bf800, 0x129207b6,
0xcd5fd889, 0xdc7b4f91, 0x7460c589, 0x43bd0373, 0xdb0fd6f3, 0x010222f6};

// Expand the shared type/constant definitions (point_field_t, projective_t, affine_t and their G2 counterparts).
CURVE_DEFINITIONS
G2_CURVE_DEFINITIONS
} // namespace bls12_377

#endif
48 changes: 48 additions & 0 deletions sumcheck/cuda/icicle/curves/params/bls12_381.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#pragma once
#ifndef BLS12_381_PARAMS_H
#define BLS12_381_PARAMS_H

#include "fields/storage.cuh"

#include "curves/macro.h"
#include "curves/projective.cuh"
#include "fields/snark_fields/bls12_381_base.cuh"
#include "fields/snark_fields/bls12_381_scalar.cuh"
#include "fields/quadratic_extension.cuh"

/**
 * Curve parameters for BLS12-381.
 *
 * Every constant is a `storage<fq_config::limbs_count>` limb array (twelve 32-bit words here). The names are the
 * ones the CURVE_DEFINITIONS and G2_CURVE_DEFINITIONS macros from "curves/macro.h" read when they are expanded at
 * the end of this namespace.
 * NOTE(review): limbs appear to be listed least-significant word first -- confirm against storage.cuh.
 */
namespace bls12_381 {
// G1 and G2 generators
// G1 generator coordinates (x, y).
static constexpr storage<fq_config::limbs_count> g1_gen_x = {0xdb22c6bb, 0xfb3af00a, 0xf97a1aef, 0x6c55e83f,
0x171bac58, 0xa14e3a3f, 0x9774b905, 0xc3688c4f,
0x4fa9ac0f, 0x2695638c, 0x3197d794, 0x17f1d3a7};
static constexpr storage<fq_config::limbs_count> g1_gen_y = {0x46c5e7e1, 0x0caa2329, 0xa2888ae4, 0xd03cc744,
0x2c04b3ed, 0x00db18cb, 0xd5d00af6, 0xfcf5e095,
0x741d8ae4, 0xa09e30ed, 0xe3aaa0f1, 0x08b3f481};
// G2 generator coordinates; each coordinate is split into a `_re` and an `_im` component.
static constexpr storage<fq_config::limbs_count> g2_gen_x_re = {0xc121bdb8, 0xd48056c8, 0xa805bbef, 0x0bac0326,
0x7ae3d177, 0xb4510b64, 0xfa403b02, 0xc6e47ad4,
0x2dc51051, 0x26080527, 0xf08f0a91, 0x024aa2b2};
static constexpr storage<fq_config::limbs_count> g2_gen_x_im = {0x5d042b7e, 0xe5ac7d05, 0x13945d57, 0x334cf112,
0xdc7f5049, 0xb5da61bb, 0x9920b61a, 0x596bd0d0,
0x88274f65, 0x7dacd3a0, 0x52719f60, 0x13e02b60};
static constexpr storage<fq_config::limbs_count> g2_gen_y_re = {0x08b82801, 0xe1935486, 0x3baca289, 0x923ac9cc,
0x5160d12c, 0x6d429a69, 0x8cbdd3a7, 0xadfd9baa,
0xda2e351a, 0x8cc9cdc6, 0x727d6e11, 0x0ce5d527};
static constexpr storage<fq_config::limbs_count> g2_gen_y_im = {0xf05f79be, 0xaaa9075f, 0x5cec1da1, 0x3f370d27,
0x572e99ab, 0x267492ab, 0x85a763af, 0xcb3e287e,
0x2bc28b99, 0x32acd2b0, 0x2ea734cc, 0x0606c4a0};

// Weierstrass coefficient b for G1, followed by the (re, im) components of the G2 coefficient.
static constexpr storage<fq_config::limbs_count> weierstrass_b = {0x00000004, 0x00000000, 0x00000000, 0x00000000,
0x00000000, 0x00000000, 0x00000000, 0x00000000,
0x00000000, 0x00000000, 0x00000000, 0x00000000};
static constexpr storage<fq_config::limbs_count> weierstrass_b_g2_re = {
0x00000004, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000};
static constexpr storage<fq_config::limbs_count> weierstrass_b_g2_im = {
0x00000004, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000};

// Expand the shared type/constant definitions (point_field_t, projective_t, affine_t and their G2 counterparts).
CURVE_DEFINITIONS
G2_CURVE_DEFINITIONS
} // namespace bls12_381

#endif
39 changes: 39 additions & 0 deletions sumcheck/cuda/icicle/curves/params/bn254.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#pragma once
#ifndef BN254_PARAMS_H
#define BN254_PARAMS_H

#include "fields/storage.cuh"

#include "curves/macro.h"
#include "curves/projective.cuh"
#include "fields/snark_fields/bn254_base.cuh"
#include "fields/snark_fields/bn254_scalar.cuh"
#include "fields/quadratic_extension.cuh"

/**
 * Curve parameters for BN254.
 *
 * Every constant is a `storage<fq_config::limbs_count>` limb array (eight 32-bit words here). The names are the
 * ones the CURVE_DEFINITIONS and G2_CURVE_DEFINITIONS macros from "curves/macro.h" read when they are expanded at
 * the end of this namespace.
 * NOTE(review): limbs appear to be listed least-significant word first -- confirm against storage.cuh.
 */
namespace bn254 {
// G1 and G2 generators
// G1 generator coordinates (x, y).
static constexpr storage<fq_config::limbs_count> g1_gen_x = {0x00000001, 0x00000000, 0x00000000, 0x00000000,
0x00000000, 0x00000000, 0x00000000, 0x00000000};
static constexpr storage<fq_config::limbs_count> g1_gen_y = {0x00000002, 0x00000000, 0x00000000, 0x00000000,
0x00000000, 0x00000000, 0x00000000, 0x00000000};
// G2 generator coordinates; each coordinate is split into a `_re` and an `_im` component.
static constexpr storage<fq_config::limbs_count> g2_gen_x_re = {0xd992f6ed, 0x46debd5c, 0xf75edadd, 0x674322d4,
0x5e5c4479, 0x426a0066, 0x121f1e76, 0x1800deef};
static constexpr storage<fq_config::limbs_count> g2_gen_x_im = {0xaef312c2, 0x97e485b7, 0x35a9e712, 0xf1aa4933,
0x31fb5d25, 0x7260bfb7, 0x920d483a, 0x198e9393};
static constexpr storage<fq_config::limbs_count> g2_gen_y_re = {0x66fa7daa, 0x4ce6cc01, 0x0c43d37b, 0xe3d1e769,
0x8dcb408f, 0x4aab7180, 0xdb8c6deb, 0x12c85ea5};
static constexpr storage<fq_config::limbs_count> g2_gen_y_im = {0xd122975b, 0x55acdadc, 0x70b38ef3, 0xbc4b3133,
0x690c3395, 0xec9e99ad, 0x585ff075, 0x090689d0};

// Weierstrass coefficient b for G1, followed by the (re, im) components of the G2 coefficient.
static constexpr storage<fq_config::limbs_count> weierstrass_b = {0x00000003, 0x00000000, 0x00000000, 0x00000000,
0x00000000, 0x00000000, 0x00000000, 0x00000000};
static constexpr storage<fq_config::limbs_count> weierstrass_b_g2_re = {
0x24a138e5, 0x3267e6dc, 0x59dbefa3, 0xb5b4c5e5, 0x1be06ac3, 0x81be1899, 0xceb8aaae, 0x2b149d40};
static constexpr storage<fq_config::limbs_count> weierstrass_b_g2_im = {
0x85c315d2, 0xe4a2bd06, 0xe52d1852, 0xa74fa084, 0xeed8fdf4, 0xcd2cafad, 0x3af0fed4, 0x009713b0};

// Expand the shared type/constant definitions (point_field_t, projective_t, affine_t and their G2 counterparts).
CURVE_DEFINITIONS
G2_CURVE_DEFINITIONS
} // namespace bn254

#endif
Loading

0 comments on commit a08d51a

Please sign in to comment.