Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use ec_add & ec_double from compiled circuits #709

Merged
merged 4 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 7 additions & 152 deletions python/cairo-ec/src/cairo_ec/ec_ops.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ from starkware.cairo.common.registers import get_label_location
from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc

from cairo_ec.curve.g1_point import G1Point
from cairo_ec.circuits.ec_ops_compiled import ec_add as ec_add_unchecked, ec_double
from cairo_ec.uint384 import uint384_is_neg_mod_p, uint384_eq_mod_p, felt_to_uint384
from cairo_ec.circuits.ec_ops_compiled import assert_is_on_curve

Expand Down Expand Up @@ -57,96 +58,13 @@ func get_random_point{
return get_random_point(seed=seed, a=a, b=b, g=g, p=p);
}

// Add Double an EC point. Doesn't check if the input is on curve nor if it's the point at infinity.
func ec_double{range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*}(
p: G1Point, g: UInt384, a: UInt384, modulus: UInt384
) -> G1Point {
alloc_locals;

let add_mod_n = 6;
let (add_offsets_ptr) = get_label_location(ec_double_add_offsets_label);
let mul_mod_n = 5;
let (mul_offsets_ptr) = get_label_location(ec_double_mul_offsets_label);

let input: UInt384* = cast(range_check96_ptr, UInt384*);
assert input[0] = g;
assert input[1] = p.x;
assert input[2] = p.y;
assert input[3] = a;

assert add_mod_ptr[0] = ModBuiltin(
p=modulus, values_ptr=input, offsets_ptr=add_offsets_ptr, n=add_mod_n
);
assert mul_mod_ptr[0] = ModBuiltin(
p=modulus, values_ptr=input, offsets_ptr=mul_offsets_ptr, n=mul_mod_n
);

%{
from starkware.cairo.lang.builtins.modulo.mod_builtin_runner import ModBuiltinRunner
assert builtin_runners["add_mod_builtin"].instance_def.batch_size == 1
assert builtin_runners["mul_mod_builtin"].instance_def.batch_size == 1

ModBuiltinRunner.fill_memory(
memory=memory,
add_mod=(ids.add_mod_ptr.address_, builtin_runners["add_mod_builtin"], ids.add_mod_n),
mul_mod=(ids.mul_mod_ptr.address_, builtin_runners["mul_mod_builtin"], ids.mul_mod_n),
)
%}

let add_mod_ptr = &add_mod_ptr[add_mod_n];
let mul_mod_ptr = &mul_mod_ptr[mul_mod_n];
let res = G1Point(
x=[cast(range_check96_ptr + 44, UInt384*)], y=[cast(range_check96_ptr + 56, UInt384*)]
);
let range_check96_ptr = range_check96_ptr + 60; // 56 is the last start index in the offset_ptr array

return res;

ec_double_add_offsets_label:
dw 20;
dw 12;
dw 24;
dw 8;
dw 8;
dw 28;
dw 4;
dw 40;
dw 36;
dw 4;
dw 44;
dw 40;
dw 44;
dw 48;
dw 4;
dw 8;
dw 56;
dw 52;

ec_double_mul_offsets_label:
dw 4;
dw 4;
dw 16;
dw 0;
dw 16;
dw 20;
dw 28;
dw 32;
dw 24;
dw 32;
dw 32;
dw 36;
dw 32;
dw 48;
dw 52;
}

// Add two EC points. Doesn't check if the inputs are on curve nor if they are the point at infinity.
func ec_add{range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*}(
p: G1Point, q: G1Point, g: UInt384, a: UInt384, modulus: UInt384
) -> G1Point {
alloc_locals;
let same_x = uint384_eq_mod_p(p.x, q.x, modulus);

let (__fp__, __pc__) = get_fp_and_pc();
if (same_x != 0) {
let opposite_y = uint384_is_neg_mod_p(p.y, q.y, modulus);
if (opposite_y != 0) {
Expand All @@ -155,77 +73,14 @@ func ec_add{range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: Mod
return res;
}

return ec_double(p, g, a, modulus);
let (res_x, res_y) = ec_double(&p.x, &p.y, &a, &modulus);
let res = G1Point(x=[res_x], y=[res_y]);
return res;
}

let add_mod_n = 6;
let (add_offsets_ptr) = get_label_location(ec_add_add_offsets_label);
let mul_mod_n = 3;
let (mul_offsets_ptr) = get_label_location(ec_add_mul_offsets_label);
let input: UInt384* = cast(range_check96_ptr, UInt384*);
assert input[0] = p.x;
assert input[1] = p.y;
assert input[2] = q.x;
assert input[3] = q.y;

assert add_mod_ptr[0] = ModBuiltin(
p=modulus, values_ptr=input, offsets_ptr=add_offsets_ptr, n=add_mod_n
);
assert mul_mod_ptr[0] = ModBuiltin(
p=modulus, values_ptr=input, offsets_ptr=mul_offsets_ptr, n=mul_mod_n
);

%{
from starkware.cairo.lang.builtins.modulo.mod_builtin_runner import ModBuiltinRunner
assert builtin_runners["add_mod_builtin"].instance_def.batch_size == 1
assert builtin_runners["mul_mod_builtin"].instance_def.batch_size == 1

ModBuiltinRunner.fill_memory(
memory=memory,
add_mod=(ids.add_mod_ptr.address_, builtin_runners["add_mod_builtin"], ids.add_mod_n),
mul_mod=(ids.mul_mod_ptr.address_, builtin_runners["mul_mod_builtin"], ids.mul_mod_n),
)
%}

let add_mod_ptr = &add_mod_ptr[add_mod_n];
let mul_mod_ptr = &mul_mod_ptr[mul_mod_n];
let range_check96_ptr = range_check96_ptr + 52; // 48 is the last start index in the offset_ptr array

let res = G1Point(
x=[cast(cast(input, felt*) + 36, UInt384*)], y=[cast(cast(input, felt*) + 48, UInt384*)]
);
let (res_x, res_y) = ec_add_unchecked(&p.x, &p.y, &q.x, &q.y, &modulus);
let res = G1Point(x=[res_x], y=[res_y]);
return res;

ec_add_add_offsets_label:
dw 12;
dw 16;
dw 4;
dw 8;
dw 20;
dw 0;
dw 0;
dw 32;
dw 28;
dw 8;
dw 36;
dw 32;
dw 36;
dw 40;
dw 0;
dw 4;
dw 48;
dw 44;

ec_add_mul_offsets_label:
dw 20;
dw 24;
dw 16;
dw 24;
dw 24;
dw 28;
dw 24;
dw 40;
dw 44;
}

// Multiply an EC point by a scalar. Doesn't check if the input is on curve nor if it's the point at infinity.
Expand Down
26 changes: 1 addition & 25 deletions python/cairo-ec/tests/test_ec_ops.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.cairo_builtins import ModBuiltin, PoseidonBuiltin, UInt384
from starkware.cairo.common.uint256 import Uint256

from cairo_ec.ec_ops import ec_double, ec_add, try_get_point_from_x, get_random_point
from cairo_ec.ec_ops import ec_add, try_get_point_from_x, get_random_point
from cairo_ec.curve.g1_point import G1Point

func test__try_get_point_from_x{
Expand Down Expand Up @@ -59,30 +59,6 @@ func test__get_random_point{
return point_ptr;
}

func test__ec_double{range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*}(
) -> G1Point* {
alloc_locals;
let (p_ptr: G1Point*) = alloc();

let (a_ptr: UInt384*) = alloc();
let (g_ptr: UInt384*) = alloc();
let (modulus_ptr: UInt384*) = alloc();
%{
segments.write_arg(ids.p_ptr.address_, program_input["p"])
segments.write_arg(ids.a_ptr.address_, program_input["a"])
segments.write_arg(ids.g_ptr.address_, program_input["g"])
segments.write_arg(ids.modulus_ptr.address_, program_input["modulus"])
%}

let res = ec_double([p_ptr], [g_ptr], [a_ptr], [modulus_ptr]);

tempvar res_ptr = new G1Point(
UInt384(res.x.d0, res.x.d1, res.x.d2, res.x.d3),
UInt384(res.y.d0, res.y.d1, res.y.d2, res.y.d3),
);
return res_ptr;
}

func test__ec_add{range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*}(
) -> G1Point* {
alloc_locals;
Expand Down
76 changes: 25 additions & 51 deletions python/cairo-ec/tests/test_ec_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,69 +65,43 @@ def test_should_return_a_point_on_the_curve(self, cairo_run, seed, curve):
x**3 + curve.A * x + curve.B
) % curve.FIELD.PRIME == y**2 % curve.FIELD.PRIME

class TestEcDouble:
@given(seed=felt, curve=curve)
def test_ec_double(self, cairo_run, seed, curve):
point = cairo_run(
"test__get_random_point",
seed=seed,
class TestEcAdd:
@given(curve=curve)
def test_ec_add(self, cairo_run, curve):
p = curve.random_point()
q = curve.random_point()
res = cairo_run(
"test__ec_add",
p=[*int_to_uint384(int(p.x)), *int_to_uint384(int(p.y))],
q=[*int_to_uint384(int(q.x)), *int_to_uint384(int(q.y))],
a=int_to_uint384(int(curve.A)),
b=int_to_uint384(int(curve.B)),
g=int_to_uint384(int(curve.G)),
p=int_to_uint384(int(curve.FIELD.PRIME)),
)
x = curve.FIELD(
uint384_to_int(
point["x"]["d0"],
point["x"]["d1"],
point["x"]["d2"],
point["x"]["d3"],
)
modulus=int_to_uint384(int(curve.FIELD.PRIME)),
)
y = curve.FIELD(
uint384_to_int(
point["y"]["d0"],
point["y"]["d1"],
point["y"]["d2"],
point["y"]["d3"],
)
assert p + q == curve(
*[curve.FIELD(uint384_to_int(**i)) for i in res.values()]
)
double = cairo_run(
"test__ec_double",
p=(
point["x"]["d0"],
point["x"]["d1"],
point["x"]["d2"],
point["x"]["d3"],
point["y"]["d0"],
point["y"]["d1"],
point["y"]["d2"],
point["y"]["d3"],
),

@given(curve=curve)
def test_ec_add_equal(self, cairo_run, curve):
p = curve.random_point()
q = curve(p.x, p.y)
res = cairo_run(
"test__ec_add",
p=[*int_to_uint384(int(p.x)), *int_to_uint384(int(p.y))],
q=[*int_to_uint384(int(q.x)), *int_to_uint384(int(q.y))],
a=int_to_uint384(int(curve.A)),
g=int_to_uint384(int(curve.G)),
modulus=int_to_uint384(int(curve.FIELD.PRIME)),
)
assert curve(x, y).double() == curve(
uint384_to_int(
double["x"]["d0"],
double["x"]["d1"],
double["x"]["d2"],
double["x"]["d3"],
),
uint384_to_int(
double["y"]["d0"],
double["y"]["d1"],
double["y"]["d2"],
double["y"]["d3"],
),
assert p + q == curve(
*[curve.FIELD(uint384_to_int(**i)) for i in res.values()]
)

class TestEcAdd:
@given(curve=curve)
def test_ec_add(self, cairo_run, curve):
def test_ec_add_opposite(self, cairo_run, curve):
p = curve.random_point()
q = curve.random_point()
q = curve(p.x, -p.y)
res = cairo_run(
"test__ec_add",
p=[*int_to_uint384(int(p.x)), *int_to_uint384(int(p.y))],
Expand Down