Skip to content

Commit

Permalink
Merge pull request ethereum#87 from zama-ai/louis-test-types
Browse files Browse the repository at this point in the history
feat(tests): add tests for all types
  • Loading branch information
tremblaythibaultl authored May 2, 2023
2 parents 170373b + 98b1daf commit 0c50eaa
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 8 deletions.
1 change: 1 addition & 0 deletions core/vm/tfhe.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ func (ct *tfheCiphertext) encrypt(value uint64, t fheUintType) {
case FheUint32:
ct.setPtr(C.client_key_encrypt_fhe_uint32(cks, C.uint(value)))
}
ct.fheUintType = t
ct.value = &value
}

Expand Down
244 changes: 236 additions & 8 deletions core/vm/tfhe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
// TODO: Don't rely on global keys that are loaded from disk in init(). Instead,
// generate keys on demand in the test.

func TestTfheCksEncryptDecrypt(t *testing.T) {
func TestTfheCksEncryptDecrypt8(t *testing.T) {
val := uint64(2)
ct := new(tfheCiphertext)
ct.encrypt(val, FheUint8)
Expand All @@ -34,7 +34,7 @@ func TestTfheCksEncryptDecrypt(t *testing.T) {
}
}

func TestTfheSerializeDeserialize(t *testing.T) {
func TestTfheSerializeDeserialize8(t *testing.T) {
val := uint64(2)
ctBytes := clientKeyEncrypt(val, FheUint8)
ct := new(tfheCiphertext)
Expand All @@ -48,15 +48,15 @@ func TestTfheSerializeDeserialize(t *testing.T) {
}
}

func TestTfheDeserializeFailure(t *testing.T) {
func TestTfheDeserializeFailure8(t *testing.T) {
ct := new(tfheCiphertext)
err := ct.deserialize(make([]byte, 10), FheUint8)
if err == nil {
t.Fatalf("deserialization must have failed")
}
}

func TestTfheAdd(t *testing.T) {
func TestTfheAdd8(t *testing.T) {
a := uint64(1)
b := uint64(1)
expected := uint64(2)
Expand All @@ -71,7 +71,7 @@ func TestTfheAdd(t *testing.T) {
}
}

func TestTfheSub(t *testing.T) {
func TestTfheSub8(t *testing.T) {
a := uint64(2)
b := uint64(1)
expected := uint64(1)
Expand All @@ -86,7 +86,7 @@ func TestTfheSub(t *testing.T) {
}
}

func TestTfheMul(t *testing.T) {
func TestTfheMul8(t *testing.T) {
a := uint64(2)
b := uint64(1)
expected := uint64(2)
Expand All @@ -101,7 +101,7 @@ func TestTfheMul(t *testing.T) {
}
}

func TestTfheLte(t *testing.T) {
func TestTfheLte8(t *testing.T) {
a := uint64(2)
b := uint64(1)
ctA := new(tfheCiphertext)
Expand All @@ -119,7 +119,7 @@ func TestTfheLte(t *testing.T) {
t.Fatalf("%d != %d", 0, res2)
}
}
func TestTfheLt(t *testing.T) {
func TestTfheLt8(t *testing.T) {
a := uint64(2)
b := uint64(1)
ctA := new(tfheCiphertext)
Expand All @@ -138,6 +138,234 @@ func TestTfheLt(t *testing.T) {
}
}

func TestTfheCksEncryptDecrypt16(t *testing.T) {
val := uint64(2)
ct := new(tfheCiphertext)
ct.encrypt(val, FheUint16)
res := ct.decrypt()
if res != val {
t.Fatalf("%d != %d", val, res)
}
}

func TestTfheSerializeDeserialize16(t *testing.T) {
val := uint64(2)
ctBytes := clientKeyEncrypt(val, FheUint16)
ct := new(tfheCiphertext)
err := ct.deserialize(ctBytes, FheUint16)
if err != nil {
t.Fatalf("deserialization failed")
}
serialized := ct.serialize()
if !bytes.Equal(serialized, ctBytes) {
t.Fatalf("serialization failed")
}
}

func TestTfheDeserializeFailure16(t *testing.T) {
ct := new(tfheCiphertext)
err := ct.deserialize(make([]byte, 10), FheUint16)
if err == nil {
t.Fatalf("deserialization must have failed")
}
}

func TestTfheAdd16(t *testing.T) {
a := uint64(1)
b := uint64(1)
expected := uint64(2)
ctA := new(tfheCiphertext)
ctA.encrypt(a, FheUint16)
ctB := new(tfheCiphertext)
ctB.encrypt(b, FheUint16)
ctRes, _ := ctA.add(ctB)
res := ctRes.decrypt()
if res != expected {
t.Fatalf("%d != %d", expected, res)
}
}

func TestTfheSub16(t *testing.T) {
a := uint64(2)
b := uint64(1)
expected := uint64(1)
ctA := new(tfheCiphertext)
ctA.encrypt(a, FheUint16)
ctB := new(tfheCiphertext)
ctB.encrypt(b, FheUint16)
ctRes, _ := ctA.sub(ctB)
res := ctRes.decrypt()
if res != expected {
t.Fatalf("%d != %d", expected, res)
}
}

func TestTfheMul16(t *testing.T) {
a := uint64(2)
b := uint64(1)
expected := uint64(2)
ctA := new(tfheCiphertext)
ctA.encrypt(a, FheUint16)
ctB := new(tfheCiphertext)
ctB.encrypt(b, FheUint16)
ctRes, _ := ctA.mul(ctB)
res := ctRes.decrypt()
if res != expected {
t.Fatalf("%d != %d", expected, res)
}
}

func TestTfheLte16(t *testing.T) {
a := uint64(2)
b := uint64(1)
ctA := new(tfheCiphertext)
ctA.encrypt(a, FheUint16)
ctB := new(tfheCiphertext)
ctB.encrypt(b, FheUint16)
ctRes1, _ := ctA.lte(ctB)
ctRes2, _ := ctB.lte(ctA)
res1 := ctRes1.decrypt()
res2 := ctRes2.decrypt()
if res1 != 0 {
t.Fatalf("%d != %d", 0, res1)
}
if res2 != 1 {
t.Fatalf("%d != %d", 0, res2)
}
}
func TestTfheLt16(t *testing.T) {
a := uint64(2)
b := uint64(1)
ctA := new(tfheCiphertext)
ctA.encrypt(a, FheUint16)
ctB := new(tfheCiphertext)
ctB.encrypt(b, FheUint16)
ctRes1, _ := ctA.lte(ctB)
ctRes2, _ := ctB.lte(ctA)
res1 := ctRes1.decrypt()
res2 := ctRes2.decrypt()
if res1 != 0 {
t.Fatalf("%d != %d", 0, res1)
}
if res2 != 1 {
t.Fatalf("%d != %d", 0, res2)
}
}

func TestTfheCksEncryptDecrypt32(t *testing.T) {
val := uint64(2)
ct := new(tfheCiphertext)
ct.encrypt(val, FheUint32)
res := ct.decrypt()
if res != val {
t.Fatalf("%d != %d", val, res)
}
}

func TestTfheSerializeDeserialize32(t *testing.T) {
val := uint64(2)
ctBytes := clientKeyEncrypt(val, FheUint32)
ct := new(tfheCiphertext)
err := ct.deserialize(ctBytes, FheUint32)
if err != nil {
t.Fatalf("deserialization failed")
}
serialized := ct.serialize()
if !bytes.Equal(serialized, ctBytes) {
t.Fatalf("serialization failed")
}
}

func TestTfheDeserializeFailure32(t *testing.T) {
ct := new(tfheCiphertext)
err := ct.deserialize(make([]byte, 10), FheUint32)
if err == nil {
t.Fatalf("deserialization must have failed")
}
}

func TestTfheAdd32(t *testing.T) {
a := uint64(1)
b := uint64(1)
expected := uint64(2)
ctA := new(tfheCiphertext)
ctA.encrypt(a, FheUint32)
ctB := new(tfheCiphertext)
ctB.encrypt(b, FheUint32)
ctRes, _ := ctA.add(ctB)
res := ctRes.decrypt()
if res != expected {
t.Fatalf("%d != %d", expected, res)
}
}

func TestTfheSub32(t *testing.T) {
a := uint64(2)
b := uint64(1)
expected := uint64(1)
ctA := new(tfheCiphertext)
ctA.encrypt(a, FheUint32)
ctB := new(tfheCiphertext)
ctB.encrypt(b, FheUint32)
ctRes, _ := ctA.sub(ctB)
res := ctRes.decrypt()
if res != expected {
t.Fatalf("%d != %d", expected, res)
}
}

func TestTfheMul32(t *testing.T) {
a := uint64(2)
b := uint64(1)
expected := uint64(2)
ctA := new(tfheCiphertext)
ctA.encrypt(a, FheUint32)
ctB := new(tfheCiphertext)
ctB.encrypt(b, FheUint32)
ctRes, _ := ctA.mul(ctB)
res := ctRes.decrypt()
if res != expected {
t.Fatalf("%d != %d", expected, res)
}
}

func TestTfheLte32(t *testing.T) {
a := uint64(2)
b := uint64(1)
ctA := new(tfheCiphertext)
ctA.encrypt(a, FheUint32)
ctB := new(tfheCiphertext)
ctB.encrypt(b, FheUint32)
ctRes1, _ := ctA.lte(ctB)
ctRes2, _ := ctB.lte(ctA)
res1 := ctRes1.decrypt()
res2 := ctRes2.decrypt()
if res1 != 0 {
t.Fatalf("%d != %d", 0, res1)
}
if res2 != 1 {
t.Fatalf("%d != %d", 0, res2)
}
}
func TestTfheLt32(t *testing.T) {
a := uint64(2)
b := uint64(1)
ctA := new(tfheCiphertext)
ctA.encrypt(a, FheUint32)
ctB := new(tfheCiphertext)
ctB.encrypt(b, FheUint32)
ctRes1, _ := ctA.lte(ctB)
ctRes2, _ := ctB.lte(ctA)
res1 := ctRes1.decrypt()
res2 := ctRes2.decrypt()
if res1 != 0 {
t.Fatalf("%d != %d", 0, res1)
}
if res2 != 1 {
t.Fatalf("%d != %d", 0, res2)
}
}

// func TestTfheTrivialEncryptDecrypt(t *testing.T) {
// val := uint64(2)
// ct := new(tfheCiphertext)
Expand Down

0 comments on commit 0c50eaa

Please sign in to comment.