diff --git a/dh/curve4q/curve4Q.go b/dh/curve4q/curve4Q.go index 53ab8b776..1479ce9c1 100644 --- a/dh/curve4q/curve4Q.go +++ b/dh/curve4q/curve4Q.go @@ -20,8 +20,10 @@ func KeyGen(public, secret *Key) { func Shared(shared, secret, public *Key) bool { var P, Q fourq.Point ok := P.Unmarshal((*[Size]byte)(public)) + if !ok { + return false + } Q.ScalarMult((*[Size]byte)(secret), &P) Q.Marshal((*[Size]byte)(shared)) - ok = ok && Q.IsOnCurve() - return ok + return !Q.IsIdentity() && Q.IsOnCurve() } diff --git a/dh/curve4q/curve4Q_test.go b/dh/curve4q/curve4Q_test.go index 0f3cea154..b40434a45 100644 --- a/dh/curve4q/curve4Q_test.go +++ b/dh/curve4q/curve4Q_test.go @@ -7,6 +7,7 @@ import ( "io" "testing" + "github.com/cloudflare/circl/ecc/fourq" "github.com/cloudflare/circl/internal/test" ) @@ -37,6 +38,60 @@ func TestDH(t *testing.T) { } } +func TestDHLowOrder(t *testing.T) { + var secretAlice, validPublicAlice, invalidPublicAlice, sharedAlice Key + var secretBob, publicBob, sharedBob Key + + t.Run("zeroPoint", func(t *testing.T) { + testTimes := 1 << 10 + + for i := 0; i < testTimes; i++ { + _, _ = rand.Read(secretAlice[:]) + _, _ = rand.Read(secretBob[:]) + + KeyGen(&validPublicAlice, &secretAlice) + KeyGen(&publicBob, &secretBob) + + zeroPoint := fourq.Point{} + zeroPoint.SetIdentity() + zeroPoint.Marshal((*[Size]byte)(&invalidPublicAlice)) + + ok := Shared(&sharedAlice, &secretAlice, &publicBob) + test.CheckOk(ok, "shared must not fail", t) + + ok = Shared(&sharedBob, &secretBob, &validPublicAlice) + test.CheckOk(ok, "shared must not fail", t) + + invalid := Shared(&sharedBob, &secretBob, &invalidPublicAlice) + test.CheckOk(!invalid, "shared must fail", t) + } + }) + + t.Run("lowOrderPoint", func(t *testing.T) { + KeyGen(&validPublicAlice, &secretAlice) + KeyGen(&publicBob, &secretBob) + + // Point of order 56 + lowOrderPoint := fourq.Point{ + X: fourq.Fq{ + fourq.Fp{0xc0, 0xe5, 0x21, 0x04, 0xaa, 0xe1, 0x93, 0xd8, 0x9b, 0x50, 0x42, 0x54, 0xd6, 0x46, 0x86, 0x74}, + fourq.Fp{0x21, 0x25, 0x4d, 0x9a, 0xda, 0x8f, 0xad, 0x28, 0xa2, 0x3d, 0xfd, 0x02, 0x13, 0xea, 0xd2, 0x56}, + }, + Y: fourq.Fq{ + fourq.Fp{0xaf, 0x71, 0xe4, 0x3b, 0x22, 0x21, 0x41, 0xef, 0x12, 0xba, 0x67, 0x02, 0x57, 0x1, 0xe5, 0x58}, + fourq.Fp{0x0e, 0x1a, 0xf5, 0xe5, 0xb8, 0x24, 0x9c, 0xe0, 0xed, 0xc3, 0xc4, 0x69, 0x7, 0x32, 0x8e, 0x2c}, + }, + } + + ok := lowOrderPoint.IsOnCurve() + test.CheckOk(ok, "point is on curve", t) + + lowOrderPoint.Marshal((*[Size]byte)(&invalidPublicAlice)) + invalid := Shared(&sharedBob, &secretBob, &invalidPublicAlice) + test.CheckOk(!invalid, "shared must fail", t) + }) +} + func BenchmarkDH(b *testing.B) { var secret, public, shared Key _, _ = rand.Read(secret[:]) diff --git a/ecc/fourq/curve.go b/ecc/fourq/curve.go index 5ba629552..cfbf12f78 100644 --- a/ecc/fourq/curve.go +++ b/ecc/fourq/curve.go @@ -12,6 +12,8 @@ const Size = 32 // Point represents an affine point of the curve. The identity is (0,1). type Point struct{ X, Y Fq } +func (P Point) String() string { return "(x: " + P.X.String() + ", y: " + P.Y.String() + ")" } + // CurveParams contains the parameters of the elliptic curve. type CurveParams struct { Name string // The canonical name of the curve. diff --git a/ecc/fourq/point.go b/ecc/fourq/point.go index a0a73392d..94974a454 100644 --- a/ecc/fourq/point.go +++ b/ecc/fourq/point.go @@ -249,59 +249,78 @@ func (P *Point) Marshal(out *[Size]byte) { // Unmarshal retrieves a point P from the input buffer. On success, returns true. func (P *Point) Unmarshal(in *[Size]byte) bool { + var Q Point s := in[Size-1] >> 7 in[Size-1] &= 0x7F - if ok := P.Y.fromBytes(in[:]); !ok { - return ok - } + ok := Q.Y.fromBytes(in[:]) in[Size-1] |= s << 7 + if !ok { + return false + } t0, t1, one := &Fq{}, &Fq{}, &Fq{} one.setOne() - fqSqr(t0, &P.Y) // t0 = y^2 + fqSqr(t0, &Q.Y) // t0 = y^2 fqMul(t1, t0, ¶mD) // t1 = d*y^2 fqSub(t0, t0, one) // t0 = y^2 - 1 fqAdd(t1, t1, one) // t1 = d*y^2 + 1 - fqSqrt(&P.X, t0, t1, 1-2*int(s)) // x = sqrt(t0/t1) + fqSqrt(&Q.X, t0, t1, 1-2*int(s)) // x = sqrt(t0/t1) - if !P.IsOnCurve() { - fpNeg(&P.X[1], &P.X[1]) + if !Q.IsOnCurve() { + fpNeg(&Q.X[1], &Q.X[1]) } + if !Q.IsOnCurve() { + return false + } + + *P = Q return true } func (P *pointR1) IsOnCurve() bool { t0, lhs, rhs := &Fq{}, &Fq{}, &Fq{} + // Check z != 0 + eq0 := !P.Z.isZero() + // Check Eq 1: -X^2 + Y^2 == Z^2 + dT^2 fqAdd(t0, &P.Y, &P.X) // t0 = y + x fqSub(lhs, &P.Y, &P.X) // lhs = y - x fqMul(lhs, lhs, t0) // lhs = y^2 - x^2 - fqMul(rhs, &P.X, &P.Y) // rhs = xy - fqSqr(rhs, rhs) // rhs = x^2y^2 - fqMul(rhs, rhs, ¶mD) // rhs = dx^2y^2 - t0.setOne() // t0 = 1 - fqAdd(rhs, rhs, t0) // rhs = 1 + dx^2y^2 - fqSub(t0, lhs, rhs) // t0 = -x^2 + y^2 - (1 + dx^2y^2) - return t0.isZero() + fqMul(rhs, &P.Ta, &P.Tb) // rhs = T = Ta * Tb + fqSqr(rhs, rhs) // rhs = T^2 + fqMul(rhs, rhs, ¶mD) // rhs = dT^2 + fqSqr(t0, &P.Z) // t0 = Z^2 + fqAdd(rhs, rhs, t0) // rhs = Z^2 + dT^2 + fqSub(t0, lhs, rhs) // t0 = (-X^2 + Y^2) - (Z^2 + dT^2) + eq1 := t0.isZero() + + // Check Eq 2: (Ta*Tb)*Z == X*Y + fqMul(lhs, &P.Ta, &P.Tb) // lhs = Ta*Tb = T + fqMul(lhs, lhs, &P.Z) // lhs = T * Z + fqMul(rhs, &P.X, &P.Y) // rhs = X * Y + fqSub(t0, lhs, rhs) // t0 = Ta*Tb*Z - X*Y + eq2 := t0.isZero() + + return eq0 && eq1 && eq2 } func (P *pointR1) isEqual(Q *pointR1) bool { l, r := &Fq{}, &Fq{} - fqMul(l, &P.X, &Q.Z) - fqMul(r, &Q.X, &P.Z) - fqSub(l, l, r) + fqMul(l, &P.X, &Q.Z) // l = X1*Z2 + fqMul(r, &Q.X, &P.Z) // r = X2*Z1 + fqSub(l, l, r) // l = l-r b := l.isZero() - fqMul(l, &P.Y, &Q.Z) - fqMul(r, &Q.Y, &P.Z) - fqSub(l, l, r) + fqMul(l, &P.Y, &Q.Z) // l = Y1*Z2 + fqMul(r, &Q.Y, &P.Z) // r = Y2*Z1 + fqSub(l, l, r) // l = l-r b = b && l.isZero() - fqMul(l, &P.Ta, &P.Tb) - fqMul(l, l, &Q.Z) - fqMul(r, &Q.Ta, &Q.Tb) - fqMul(r, r, &P.Z) - fqSub(l, l, r) + fqMul(l, &P.Ta, &P.Tb) // l = T1 = Ta1*Tb1 + fqMul(l, l, &Q.Z) // l = T1*Z2 + fqMul(r, &Q.Ta, &Q.Tb) // r = T2 = Ta2*Tb2 + fqMul(r, r, &P.Z) // r = T2*Z1 + fqSub(l, l, r) // l = l-r b = b && l.isZero() - return b + return b && !P.Z.isZero() && !Q.Z.isZero() } func (P *pointR1) ClearCofactor() { diff --git a/ecc/fourq/point_test.go b/ecc/fourq/point_test.go index 7bbfd2f97..93be63877 100644 --- a/ecc/fourq/point_test.go +++ b/ecc/fourq/point_test.go @@ -15,6 +15,61 @@ func (P *pointR1) random() { P.ScalarBaseMult(&k) } +func TestPoint(t *testing.T) { + const testTimes = 1 << 10 + t.Run("IsOnCurve(ok)", func(t *testing.T) { + var gen Point + var goodGen pointR1 + gen.SetGenerator() + gen.toR1(&goodGen) + test.CheckOk(goodGen.IsOnCurve(), "valid point should pass", t) + }) + + t.Run("IsOnCurve(zero)", func(t *testing.T) { + var allZeros pointR1 + test.CheckOk(!allZeros.IsOnCurve(), "invalid point should be detected", t) + }) + + t.Run("IsOnCurve(bad)", func(t *testing.T) { + var badGen pointR1 + badGen.X = genX + badGen.Y = genY + test.CheckOk(!badGen.IsOnCurve(), "invalid point should be detected", t) + }) + + t.Run("IsEqual", func(t *testing.T) { + var badGen pointR1 + badGen.X = genX + badGen.Y = genY + var gen Point + var goodGen pointR1 + gen.SetGenerator() + gen.toR1(&goodGen) + test.CheckOk(!badGen.isEqual(&goodGen), "invalid point shouldn't match generator", t) + test.CheckOk(!goodGen.isEqual(&badGen), "invalid point shouldn't match generator", t) + test.CheckOk(goodGen.isEqual(&goodGen), "valid point should match generator", t) + test.CheckOk(!badGen.isEqual(&badGen), "invalid point shouldn't match anything", t) + }) + + t.Run("isEqual(fail-w/random)", func(t *testing.T) { + var badG pointR1 + badG.X = genX + badG.Y = genY + test.CheckOk(!badG.IsOnCurve(), "invalid point should be detected", t) + + var k [Size]byte + var got, want pointR1 + for i := 0; i < testTimes; i++ { + _, _ = rand.Read(k[:]) + got.ScalarMult(&k, &badG) + want.random() + if got.isEqual(&want) { + test.ReportError(t, got, want, k) + } + } + }) +} + func TestPointAddition(t *testing.T) { const testTimes = 1 << 10 var P, Q pointR1 @@ -55,6 +110,7 @@ func TestOddMultiples(t *testing.T) { Q.add(&Tab[j]) } // R = (2^6)P == 64P + R = P for j := 0; j < 6; j++ { R.double() } @@ -68,7 +124,7 @@ func TestOddMultiples(t *testing.T) { func TestScalarMult(t *testing.T) { const testTimes = 1 << 10 - var P, Q, G pointR1 + var P, Q pointR1 var k [Size]byte t.Run("0P=0", func(t *testing.T) { @@ -108,11 +164,13 @@ func TestScalarMult(t *testing.T) { } }) t.Run("mult", func(t *testing.T) { - G.X = genX - G.Y = genY + var G Point + G.SetGenerator() + var gen pointR1 + G.toR1(&gen) for i := 0; i < testTimes; i++ { _, _ = rand.Read(k[:]) - P.ScalarMult(&k, &G) + P.ScalarMult(&k, &gen) Q.ScalarBaseMult(&k) got := Q.isEqual(&P) want := true @@ -121,6 +179,33 @@ func TestScalarMult(t *testing.T) { } } }) + t.Run("mult-non_curve_point_issue", func(t *testing.T) { + for i := 0; i < testTimes; i++ { + _, _ = rand.Read(k[:]) + Q.random() + P.ScalarMult(&k, &Q) + if !P.IsOnCurve() { + t.Fatalf("Point is not on curve: %X\n", P) + } + } + }) + t.Run("unmarshal-faulty-point", func(t *testing.T) { + // This test demonstrates that it is possible to find points which are unmarshalled + // successfully, but are not on the curve. + var marshalledPoint [Size]byte + for i := 0; i < testTimes; i++ { + _, _ = rand.Read(marshalledPoint[:]) + unmarshalledP := Point{} + ok := unmarshalledP.Unmarshal(&marshalledPoint) + isOnCurve := unmarshalledP.IsOnCurve() + switch true { + case ok && !isOnCurve: + t.Fatalf("unmarshal ok, but not on curve: %v\n", unmarshalledP) + case !ok && isOnCurve: + t.Fatalf("unmarshal failed with a point on curve: %v\n", unmarshalledP) + } + } + }) } func TestScalar(t *testing.T) { diff --git a/ecc/goldilocks/curve.go b/ecc/goldilocks/curve.go index 5a939100d..1f165141a 100644 --- a/ecc/goldilocks/curve.go +++ b/ecc/goldilocks/curve.go @@ -18,6 +18,9 @@ func (Curve) Identity() *Point { func (Curve) IsOnCurve(P *Point) bool { x2, y2, t, t2, z2 := &fp.Elt{}, &fp.Elt{}, &fp.Elt{}, &fp.Elt{}, &fp.Elt{} rhs, lhs := &fp.Elt{}, &fp.Elt{} + // Check z != 0 + eq0 := !fp.IsZero(&P.z) + fp.Mul(t, &P.ta, &P.tb) // t = ta*tb fp.Sqr(x2, &P.x) // x^2 fp.Sqr(y2, &P.y) // y^2 @@ -27,13 +30,14 @@ func (Curve) IsOnCurve(P *Point) bool { fp.Mul(rhs, t2, ¶mD) // dt^2 fp.Add(rhs, rhs, z2) // z^2 + dt^2 fp.Sub(lhs, lhs, rhs) // x^2 + y^2 - (z^2 + dt^2) - eq0 := fp.IsZero(lhs) + eq1 := fp.IsZero(lhs) fp.Mul(lhs, &P.x, &P.y) // xy fp.Mul(rhs, t, &P.z) // tz fp.Sub(lhs, lhs, rhs) // xy - tz - eq1 := fp.IsZero(lhs) - return eq0 && eq1 + eq2 := fp.IsZero(lhs) + + return eq0 && eq1 && eq2 } // Generator returns the generator point. diff --git a/ecc/goldilocks/point_test.go b/ecc/goldilocks/point_test.go index a25bb5e0f..d320054e5 100644 --- a/ecc/goldilocks/point_test.go +++ b/ecc/goldilocks/point_test.go @@ -15,6 +15,19 @@ func randomPoint() *goldilocks.Point { return goldilocks.Curve{}.ScalarBaseMult(&k) } +func TestPoint(t *testing.T) { + c := goldilocks.Curve{} + t.Run("IsOnCurve(ok)", func(t *testing.T) { + goodGen := c.Generator() + test.CheckOk(c.IsOnCurve(goodGen), "valid point should pass", t) + }) + + t.Run("IsOnCurve(zero)", func(t *testing.T) { + var allZeros goldilocks.Point + test.CheckOk(!c.IsOnCurve(&allZeros), "invalid point should be detected", t) + }) +} + func TestPointAdd(t *testing.T) { const testTimes = 1 << 10 var e goldilocks.Curve diff --git a/sign/ed25519/point.go b/sign/ed25519/point.go index 374a69503..d1c3b146b 100644 --- a/sign/ed25519/point.go +++ b/sign/ed25519/point.go @@ -164,7 +164,7 @@ func (P *pointR1) isEqual(Q *pointR1) bool { fp.Mul(r, r, &P.z) fp.Sub(l, l, r) b = b && fp.IsZero(l) - return b + return b && !fp.IsZero(&P.z) && !fp.IsZero(&Q.z) } func (P *pointR3) neg() { diff --git a/sign/ed25519/point_test.go b/sign/ed25519/point_test.go index e2317950b..c5da133e6 100644 --- a/sign/ed25519/point_test.go +++ b/sign/ed25519/point_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/cloudflare/circl/internal/test" + "github.com/cloudflare/circl/math/fp25519" ) func randomPoint(P *pointR1) { @@ -17,6 +18,17 @@ func randomPoint(P *pointR1) { func TestPoint(t *testing.T) { const testTimes = 1 << 10 + t.Run("isEqual", func(t *testing.T) { + var valid, invalid pointR1 + randomPoint(&valid) + randomPoint(&invalid) + invalid.z = fp25519.Elt{} + test.CheckOk(!valid.isEqual(&invalid), "valid point shouldn't match invalid point", t) + test.CheckOk(!invalid.isEqual(&valid), "invalid point shouldn't match valid point", t) + test.CheckOk(valid.isEqual(&valid), "valid point should match valid point", t) + test.CheckOk(!invalid.isEqual(&invalid), "invalid point shouldn't match anything", t) + }) + t.Run("add", func(t *testing.T) { var P pointR1 var Q pointR1