Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add assert_is_on_curve #693

Merged
merged 2 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cairo/src/utils/signature.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ from starkware.cairo.common.alloc import alloc

from cairo_core.maths import unsigned_div_rem, assert_uint256_le
from cairo_ec.uint384 import uint384_to_uint256, uint256_to_uint384

from cairo_ec.curve.secp256k1 import (
secp256k1,
try_recover_public_key,
Expand Down
5 changes: 3 additions & 2 deletions python/cairo-addons/src/cairo_addons/hints/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def compute_y_from_x_hint(ids: VmConsts, segments: MemorySegmentManager):
x = uint384_to_int(ids.x.d0, ids.x.d1, ids.x.d2, ids.x.d3)
rhs = (x**3 + a * x + b) % p

ids.is_on_curve = is_quad_residue(rhs, p)
if ids.is_on_curve == 1:
is_on_curve = is_quad_residue(rhs, p)
if is_on_curve == 1:
zmalatrax marked this conversation as resolved.
Show resolved Hide resolved
square_root = sqrt_mod(rhs, p)
if ids.v % 2 == square_root % 2:
pass
Expand All @@ -72,6 +72,7 @@ def compute_y_from_x_hint(ids: VmConsts, segments: MemorySegmentManager):
square_root = sqrt_mod(rhs * g, p)

segments.load_data(ids.y_try.address_, int_to_uint384(square_root))
segments.load_data(ids.is_on_curve.address_, int_to_uint384(is_on_curve))


@register_hint
Expand Down
17 changes: 17 additions & 0 deletions python/cairo-ec/src/cairo_ec/circuits/ec_ops.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,20 @@ func ec_double(x0: felt, y0: felt, a: felt) -> (felt, felt) {

end:
}

// @dev Assert that a point is, or is not, on the curve by checking that either y is actually the square root of rhs
// (is_on_curve = True => y^2 = rhs) or y is the square root of rhs * g (is_on_curve = False => y^2 = rhs * g),
// which mean that rhs is not a quadratic residue because g * rhs is, and so that x is not on the curve.
// @param x The x coordinate of the point
// @param y The y coordinate of the point
// @param g The generator point
// @param is_on_curve True if the point is on the curve, False otherwise
func assert_is_on_curve(x: felt, y: felt, a: felt, b: felt, g: felt, is_on_curve: felt) {
assert is_on_curve * (1 - is_on_curve) = 0;
tempvar rhs = x * x * x + a * x + b;
assert y * y = rhs * is_on_curve + g * rhs * (1 - is_on_curve);

return ();

end:
}
117 changes: 117 additions & 0 deletions python/cairo-ec/src/cairo_ec/circuits/ec_ops_compiled.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,120 @@
dw 64;
dw 68;
}

func assert_is_on_curve{
range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*
}(
x: UInt384*,
y: UInt384*,
a: UInt384*,
b: UInt384*,
g: UInt384*,
is_on_curve: UInt384*,
p: UInt384*,
) {
let (_, pc) = get_fp_and_pc();

pc_label:
let add_mod_offsets_ptr = pc + (add_offsets - pc_label);
let mul_mod_offsets_ptr = pc + (mul_offsets - pc_label);

assert [range_check96_ptr + 0] = 1;
assert [range_check96_ptr + 1] = 0;
assert [range_check96_ptr + 2] = 0;
assert [range_check96_ptr + 3] = 0;
assert [range_check96_ptr + 4] = 0;
assert [range_check96_ptr + 5] = 0;
assert [range_check96_ptr + 6] = 0;
assert [range_check96_ptr + 7] = 0;

assert [range_check96_ptr + 8] = x.d0;
assert [range_check96_ptr + 9] = x.d1;
assert [range_check96_ptr + 10] = x.d2;
assert [range_check96_ptr + 11] = x.d3;
assert [range_check96_ptr + 12] = y.d0;
assert [range_check96_ptr + 13] = y.d1;
assert [range_check96_ptr + 14] = y.d2;
assert [range_check96_ptr + 15] = y.d3;
assert [range_check96_ptr + 16] = a.d0;
assert [range_check96_ptr + 17] = a.d1;
assert [range_check96_ptr + 18] = a.d2;
assert [range_check96_ptr + 19] = a.d3;
assert [range_check96_ptr + 20] = b.d0;
assert [range_check96_ptr + 21] = b.d1;
assert [range_check96_ptr + 22] = b.d2;
assert [range_check96_ptr + 23] = b.d3;
assert [range_check96_ptr + 24] = g.d0;
assert [range_check96_ptr + 25] = g.d1;
assert [range_check96_ptr + 26] = g.d2;
assert [range_check96_ptr + 27] = g.d3;
assert [range_check96_ptr + 28] = is_on_curve.d0;
assert [range_check96_ptr + 29] = is_on_curve.d1;
assert [range_check96_ptr + 30] = is_on_curve.d2;
assert [range_check96_ptr + 31] = is_on_curve.d3;

run_mod_p_circuit(
p=[p],
values_ptr=cast(range_check96_ptr, UInt384*),
add_mod_offsets_ptr=add_mod_offsets_ptr,
add_mod_n=8,
mul_mod_offsets_ptr=mul_mod_offsets_ptr,
mul_mod_n=8,
);

let range_check96_ptr = range_check96_ptr + 88;

return ();

add_offsets:
dw 4;
dw 0;
dw 32;
dw 36;
dw 28;
dw 32;
dw 4;
dw 4;
dw 40;
dw 48;
dw 52;
dw 56;
dw 56;
dw 20;
dw 60;
dw 4;
dw 0;
dw 72;
dw 76;
dw 28;
dw 72;
dw 64;
dw 80;
dw 84;

Check warning on line 273 in python/cairo-ec/src/cairo_ec/circuits/ec_ops_compiled.cairo

View check run for this annotation

Codecov / codecov/patch

python/cairo-ec/src/cairo_ec/circuits/ec_ops_compiled.cairo#L250-L273

Added lines #L250 - L273 were not covered by tests

mul_offsets:
dw 28;
dw 36;
dw 40;
dw 8;
dw 8;
dw 44;
dw 44;
dw 8;
dw 48;
dw 16;
dw 8;
dw 52;
dw 60;
dw 28;
dw 64;
dw 24;
dw 60;
dw 68;
dw 68;
dw 76;
dw 80;
dw 12;
dw 12;
dw 84;

Check warning on line 299 in python/cairo-ec/src/cairo_ec/circuits/ec_ops_compiled.cairo

View check run for this annotation

Codecov / codecov/patch

python/cairo-ec/src/cairo_ec/circuits/ec_ops_compiled.cairo#L276-L299

Added lines #L276 - L299 were not covered by tests
}
27 changes: 20 additions & 7 deletions python/cairo-ec/src/cairo_ec/curve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,18 @@ def __init__(self, x: Union[int, F], y: Union[int, F]):
super().__init__(self.FIELD(x), self.FIELD(y))

@classmethod
def random_point(cls, x=None) -> "EllipticCurve":
"""Generate a random point on the curve.
def random_point(cls, x=None, retry=True) -> "EllipticCurve":
"""Generate a random point.

If retry is True, the returned point is guaranteed to be on the curve.
Otherwise, it just returns the first point it finds, which might not be on the curve.

Returns a random point (x,y) satisfying y² = x³ + ax + b.
Uses try-and-increment method:
1. Pick random x
2. Compute x³ + ax + b
3. If it's a quadratic residue, compute y
4. If not, try another x
4. If not, and retry is True, try another x
5. If not, and retry is False, return (x, sqrt(x³ + ax + b) * g)
"""
while True:
# Random x in the field
Expand All @@ -50,9 +53,19 @@ def random_point(cls, x=None) -> "EllipticCurve":
# Randomly choose between y and -y
if randint(0, 1):
y = -y
return cls(cls.FIELD(x), cls.FIELD(y))
else:
x = cls.G * x
return cls.__new__(cls, cls.FIELD(x), cls.FIELD(y))
if not retry:
y = sqrt_mod(rhs * cls.G, cls.FIELD.PRIME)
return cls.__new__(cls, cls.FIELD(x), cls.FIELD(y))

x = cls.G * x

@classmethod
def is_on_curve(cls, x: int, y: int) -> bool:
"""Check if a point is on the curve."""
y = cls.FIELD(y)
x = cls.FIELD(x)
return y**2 == x**3 + cls.A * x + cls.B


class Secp256k1P(PrimeField):
Expand Down
4 changes: 2 additions & 2 deletions python/cairo-ec/src/cairo_ec/curve/secp256k1.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,11 @@ func try_recover_public_key{
local g: UInt384 = UInt384(secp256k1.G0, secp256k1.G1, secp256k1.G2, secp256k1.G3);
local p: UInt384 = UInt384(secp256k1.P0, secp256k1.P1, secp256k1.P2, secp256k1.P3);

let (y, is_on_curve) = try_get_point_from_x(x=r, v=y_parity, a=&a, b=&b, g=&g, p=&p);
let (y, is_on_curve) = try_get_point_from_x(x=&r, v=y_parity, a=&a, b=&b, g=&g, p=&p);
if (is_on_curve == 0) {
return (public_key_point=G1Point(x=UInt384(0, 0, 0, 0), y=UInt384(0, 0, 0, 0)), success=0);
}
let r_point = G1Point(x=r, y=y);
let r_point = G1Point(x=r, y=[y]);
// The result is given by
// -(msg_hash / r) * gen + (s / r) * r_point
// where the division by r is modulo N.
Expand Down
101 changes: 16 additions & 85 deletions python/cairo-ec/src/cairo_ec/ec_ops.cairo
Original file line number Diff line number Diff line change
@@ -1,103 +1,33 @@
from starkware.cairo.common.cairo_builtins import UInt384, ModBuiltin, PoseidonBuiltin
from starkware.cairo.common.registers import get_label_location
from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc

from cairo_ec.curve.g1_point import G1Point
from cairo_ec.uint384 import uint384_is_neg_mod_p, uint384_eq_mod_p, felt_to_uint384
from cairo_ec.circuits.ec_ops_compiled import assert_is_on_curve

// @notice Try to get the point from x.
// @return y The y point such that (x, y) is on the curve if success is 1, otherwise (g*h, y) is on the curve
// @return is_on_curve 1 if the point is on the curve, 0 otherwise
// @dev g is the generator point and h is the hash of the message
func try_get_point_from_x{
range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*
}(x: UInt384, v: felt, a: UInt384*, b: UInt384*, g: UInt384*, p: UInt384*) -> (
y: UInt384, is_on_curve: felt
}(x: UInt384*, v: felt, a: UInt384*, b: UInt384*, g: UInt384*, p: UInt384*) -> (
y: UInt384*, is_on_curve: felt
) {
alloc_locals;
let add_mod_n = 5;
let (add_offsets_ptr) = get_label_location(add_offsets_ptr_loc);
let mul_mod_n = 7;
let (mul_offsets_ptr) = get_label_location(mul_offsets_ptr_loc);

local is_on_curve: felt;
let (__fp__, __pc__) = get_fp_and_pc();
local is_on_curve: UInt384;
local y_try: UInt384;
%{ compute_y_from_x_hint %}

assert 0 = is_on_curve * (1 - is_on_curve); // assert it's a bool
let input: UInt384* = cast(range_check96_ptr, UInt384*);
assert input[0] = UInt384(1, 0, 0, 0); // constant
assert input[1] = UInt384(0, 0, 0, 0); // constant
assert input[2] = x;
assert input[3] = [a];
assert input[4] = [b];
assert input[5] = [g];
assert input[6] = y_try;
assert input[7] = UInt384(is_on_curve, 0, 0, 0);

assert add_mod_ptr[0] = ModBuiltin(
p=[p], values_ptr=input, offsets_ptr=add_offsets_ptr, n=add_mod_n
);
assert mul_mod_ptr[0] = ModBuiltin(
p=[p], values_ptr=input, offsets_ptr=mul_offsets_ptr, n=mul_mod_n
);
assert_is_on_curve(x=x, y=&y_try, a=a, b=b, g=g, is_on_curve=&is_on_curve, p=p);
assert is_on_curve.d3 = 0;
assert is_on_curve.d2 = 0;
assert is_on_curve.d1 = 0;
// TODO: Add a check for v
zmalatrax marked this conversation as resolved.
Show resolved Hide resolved

%{
from starkware.cairo.lang.builtins.modulo.mod_builtin_runner import ModBuiltinRunner
assert builtin_runners["add_mod_builtin"].instance_def.batch_size == 1
assert builtin_runners["mul_mod_builtin"].instance_def.batch_size == 1

ModBuiltinRunner.fill_memory(
memory=memory,
add_mod=(ids.add_mod_ptr.address_, builtin_runners["add_mod_builtin"], ids.add_mod_n),
mul_mod=(ids.mul_mod_ptr.address_, builtin_runners["mul_mod_builtin"], ids.mul_mod_n),
)
%}

let add_mod_ptr = &add_mod_ptr[add_mod_n];
let mul_mod_ptr = &mul_mod_ptr[mul_mod_n];
let range_check96_ptr = range_check96_ptr + 76; // 72 is the last start index in the offset_ptr array

return (y=y_try, is_on_curve=is_on_curve);

add_offsets_ptr_loc:
dw 40; // ax
dw 16; // b
dw 44; // ax + b
dw 36; // x^3
dw 44; // ax + b
dw 48; // x^3 + ax + b (:= rhs)
dw 28; // is_on_curve
dw 60; // 1 - is_on_curve
dw 0; // 1
dw 56; // is_on_curve * rhs
dw 64; // (1 - is_on_curve) * g * rhs
dw 68; // is_on_curve * rhs + (1-is_on_curve) * g * rhs
dw 4; // 0
dw 72; // y_try^2
dw 68; // is_on_curve * rhs + (1-is_on_curve) * g * rhs

mul_offsets_ptr_loc:
dw 8; // x
dw 8; // x
dw 32; // x^2
dw 8; // x
dw 32; // x^2
dw 36; // x^3
dw 12; // a
dw 8; // x
dw 40; // ax
dw 20; // g
dw 48; // rhs
dw 52; // g * rhs
dw 28; // is_on_curve
dw 48; // rhs
dw 56; // is_on_curve * rhs
dw 60; // 1 - is_on_curve
dw 52; // g * rhs
dw 64; // (1 - is_on_curve) * g * rhs
dw 24; // y_try
dw 24; // y_try
dw 72; // y_try^2
return (y=&y_try, is_on_curve=is_on_curve.d0);
}

// @notice Get a random point from x
Expand All @@ -108,12 +38,13 @@ func get_random_point{
poseidon_ptr: PoseidonBuiltin*,
}(seed: felt, a: UInt384*, b: UInt384*, g: UInt384*, p: UInt384*) -> G1Point {
alloc_locals;
let (__fp__, __pc__) = get_fp_and_pc();
let x_384 = felt_to_uint384(seed);

let (y, is_on_curve) = try_get_point_from_x(x=x_384, v=0, a=a, b=b, g=g, p=p);
tempvar x = new x_384;
let (y, is_on_curve) = try_get_point_from_x(x=x, v=0, a=a, b=b, g=g, p=p);

if (is_on_curve != 0) {
let point = G1Point(x=x_384, y=y);
let point = G1Point(x=x_384, y=[y]);
return point;
}

Expand Down
3 changes: 2 additions & 1 deletion python/cairo-ec/tests/circuits/test_circuits.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ from cairo_ec.circuits.mod_ops_compiled import (
inv as inv_compiled,
assert_is_quad_residue as assert_is_quad_residue_compiled,
)
from cairo_ec.circuits.ec_ops import ec_add, ec_double
from cairo_ec.circuits.ec_ops import ec_add, ec_double, assert_is_on_curve
from cairo_ec.circuits.ec_ops_compiled import (
ec_add as ec_add_compiled,
ec_double as ec_double_compiled,
assert_is_on_curve as assert_is_on_curve_compiled,
)

func test__circuit{range_check96_ptr: felt*, add_mod_ptr: ModBuiltin*, mul_mod_ptr: ModBuiltin*}(
Expand Down
Loading