From 813c6215948c37b5a829c601042b23143d2909c2 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Sun, 12 Feb 2023 18:07:37 -0600 Subject: [PATCH] feat: add Vector support to ecc marshal encoder (#336) * fix: vector marshal methods on pointer type * feat: support for Vector encode / decode in Marshal --- ecc/bls12-377/fp/vector.go | 10 +- ecc/bls12-377/fr/vector.go | 10 +- ecc/bls12-377/marshal.go | 163 ++++++------------ ecc/bls12-377/marshal_test.go | 12 +- ecc/bls12-378/fp/vector.go | 10 +- ecc/bls12-378/fr/vector.go | 10 +- ecc/bls12-378/marshal.go | 163 ++++++------------ ecc/bls12-378/marshal_test.go | 12 +- ecc/bls12-381/fp/vector.go | 10 +- ecc/bls12-381/fr/vector.go | 10 +- ecc/bls12-381/marshal.go | 163 ++++++------------ ecc/bls12-381/marshal_test.go | 12 +- ecc/bls24-315/fp/vector.go | 10 +- ecc/bls24-315/fr/vector.go | 10 +- ecc/bls24-315/marshal.go | 163 ++++++------------ ecc/bls24-315/marshal_test.go | 12 +- ecc/bls24-317/fp/vector.go | 10 +- ecc/bls24-317/fr/vector.go | 10 +- ecc/bls24-317/marshal.go | 163 ++++++------------ ecc/bls24-317/marshal_test.go | 12 +- ecc/bn254/fp/vector.go | 10 +- ecc/bn254/fr/vector.go | 10 +- ecc/bn254/marshal.go | 163 ++++++------------ ecc/bn254/marshal_test.go | 12 +- ecc/bw6-633/fp/vector.go | 10 +- ecc/bw6-633/fr/vector.go | 10 +- ecc/bw6-633/marshal.go | 163 ++++++------------ ecc/bw6-633/marshal_test.go | 12 +- ecc/bw6-756/fp/vector.go | 10 +- ecc/bw6-756/fr/vector.go | 10 +- ecc/bw6-756/marshal.go | 163 ++++++------------ ecc/bw6-756/marshal_test.go | 12 +- ecc/bw6-761/fp/vector.go | 10 +- ecc/bw6-761/fr/vector.go | 10 +- ecc/bw6-761/marshal.go | 163 ++++++------------ ecc/bw6-761/marshal_test.go | 12 +- ecc/secp256k1/fp/vector.go | 10 +- ecc/secp256k1/fr/vector.go | 10 +- ecc/stark-curve/fp/vector.go | 10 +- ecc/stark-curve/fr/vector.go | 10 +- .../internal/templates/element/vector.go | 10 +- field/goldilocks/vector.go | 10 +- .../generator/ecc/template/marshal.go.tmpl | 110 ++++-------- .../ecc/template/tests/marshal.go.tmpl | 12 +- 44 files changed, 769 insertions(+), 1168 deletions(-) diff --git a/ecc/bls12-377/fp/vector.go b/ecc/bls12-377/fp/vector.go index 0c11ec134..e75796169 100644 --- a/ecc/bls12-377/fp/vector.go +++ b/ecc/bls12-377/fp/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bls12-377/fr/vector.go b/ecc/bls12-377/fr/vector.go index 6d23664e8..c85a3c134 100644 --- a/ecc/bls12-377/fr/vector.go +++ b/ecc/bls12-377/fr/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bls12-377/marshal.go b/ecc/bls12-377/marshal.go index b39f62428..b71e1c2dc 100644 --- a/ecc/bls12-377/marshal.go +++ b/ecc/bls12-377/marshal.go @@ -86,9 +86,16 @@ func (dec *Decoder) Decode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap // in particular, careful attention must be given to usage of Bytes() method on Elements and Points - // that return an array (not a slice) of bytes. Using this is beneficial to minimize memallocs + // that return an array (not a slice) of bytes. Using this is beneficial to minimize memory allocations // in very large (de)serialization upstream in gnark. - // (but detrimental to code lisibility here) + // (but detrimental to code readability here) + + var read64 int64 + if vf, ok := v.(io.ReaderFrom); ok { + read64, err = vf.ReadFrom(dec.r) + dec.n += read64 + return + } var buf [SizeOfG2AffineUncompressed]byte var read int @@ -111,46 +118,12 @@ func (dec *Decoder) Decode(v interface{}) (err error) { err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fr.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fr.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { - return - } - } + read64, err = (*fr.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *[]fp.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fp.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { - return - } - } + read64, err = (*fp.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *G1Affine: // we start by reading compressed point size, if metadata tells us it is uncompressed, we read more. @@ -400,7 +373,15 @@ func (enc *Encoder) encode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -422,41 +403,22 @@ func (enc *Encoder) encode(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) @@ -514,7 +476,15 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -536,41 +506,22 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) diff --git a/ecc/bls12-377/marshal_test.go b/ecc/bls12-377/marshal_test.go index 4210f29be..c1b60a0d2 100644 --- a/ecc/bls12-377/marshal_test.go +++ b/ecc/bls12-377/marshal_test.go @@ -21,6 +21,7 @@ import ( "io" "math/big" "math/rand" + "reflect" "testing" "github.com/leanovate/gopter" @@ -50,6 +51,7 @@ func TestEncoder(t *testing.T) { var inH []G2Affine var inI []fp.Element var inJ []fr.Element + var inK fr.Vector // set values of inputs inA = rand.Uint64() @@ -64,12 +66,14 @@ func TestEncoder(t *testing.T) { inI = make([]fp.Element, 3) inI[2] = inD.X inJ = make([]fr.Element, 0) + inK = make(fr.Vector, 42) + inK[41].SetUint64(42) // encode them, compressed and raw var buf, bufRaw bytes.Buffer enc := NewEncoder(&buf) encRaw := NewEncoder(&bufRaw, RawEncoding()) - toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ} + toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ, inK} for _, v := range toEncode { if err := enc.Encode(v); err != nil { t.Fatal(err) @@ -93,8 +97,9 @@ func TestEncoder(t *testing.T) { var outH []G2Affine var outI []fp.Element var outJ []fr.Element + var outK fr.Vector - toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ} + toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ, &outK} for _, v := range toDecode { if err := dec.Decode(v); err != nil { t.Fatal(err) @@ -131,6 +136,9 @@ func TestEncoder(t *testing.T) { t.Fatal("decode(encode(slice(elements))) failed") } } + if !reflect.DeepEqual(inK, outK) { + t.Fatal("decode(encode(vector)) failed") + } if n != dec.BytesRead() { t.Fatal("bytes read don't match bytes written") } diff --git a/ecc/bls12-378/fp/vector.go b/ecc/bls12-378/fp/vector.go index 0c11ec134..e75796169 100644 --- a/ecc/bls12-378/fp/vector.go +++ b/ecc/bls12-378/fp/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bls12-378/fr/vector.go b/ecc/bls12-378/fr/vector.go index 6d23664e8..c85a3c134 100644 --- a/ecc/bls12-378/fr/vector.go +++ b/ecc/bls12-378/fr/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bls12-378/marshal.go b/ecc/bls12-378/marshal.go index 1120fa850..2b12b7bbb 100644 --- a/ecc/bls12-378/marshal.go +++ b/ecc/bls12-378/marshal.go @@ -86,9 +86,16 @@ func (dec *Decoder) Decode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap // in particular, careful attention must be given to usage of Bytes() method on Elements and Points - // that return an array (not a slice) of bytes. Using this is beneficial to minimize memallocs + // that return an array (not a slice) of bytes. Using this is beneficial to minimize memory allocations // in very large (de)serialization upstream in gnark. - // (but detrimental to code lisibility here) + // (but detrimental to code readability here) + + var read64 int64 + if vf, ok := v.(io.ReaderFrom); ok { + read64, err = vf.ReadFrom(dec.r) + dec.n += read64 + return + } var buf [SizeOfG2AffineUncompressed]byte var read int @@ -111,46 +118,12 @@ func (dec *Decoder) Decode(v interface{}) (err error) { err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fr.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fr.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { - return - } - } + read64, err = (*fr.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *[]fp.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fp.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { - return - } - } + read64, err = (*fp.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *G1Affine: // we start by reading compressed point size, if metadata tells us it is uncompressed, we read more. @@ -400,7 +373,15 @@ func (enc *Encoder) encode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -422,41 +403,22 @@ func (enc *Encoder) encode(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) @@ -514,7 +476,15 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -536,41 +506,22 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) diff --git a/ecc/bls12-378/marshal_test.go b/ecc/bls12-378/marshal_test.go index 88d895280..ac3963924 100644 --- a/ecc/bls12-378/marshal_test.go +++ b/ecc/bls12-378/marshal_test.go @@ -21,6 +21,7 @@ import ( "io" "math/big" "math/rand" + "reflect" "testing" "github.com/leanovate/gopter" @@ -50,6 +51,7 @@ func TestEncoder(t *testing.T) { var inH []G2Affine var inI []fp.Element var inJ []fr.Element + var inK fr.Vector // set values of inputs inA = rand.Uint64() @@ -64,12 +66,14 @@ func TestEncoder(t *testing.T) { inI = make([]fp.Element, 3) inI[2] = inD.X inJ = make([]fr.Element, 0) + inK = make(fr.Vector, 42) + inK[41].SetUint64(42) // encode them, compressed and raw var buf, bufRaw bytes.Buffer enc := NewEncoder(&buf) encRaw := NewEncoder(&bufRaw, RawEncoding()) - toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ} + toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ, inK} for _, v := range toEncode { if err := enc.Encode(v); err != nil { t.Fatal(err) @@ -93,8 +97,9 @@ func TestEncoder(t *testing.T) { var outH []G2Affine var outI []fp.Element var outJ []fr.Element + var outK fr.Vector - toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ} + toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ, &outK} for _, v := range toDecode { if err := dec.Decode(v); err != nil { t.Fatal(err) @@ -131,6 +136,9 @@ func TestEncoder(t *testing.T) { t.Fatal("decode(encode(slice(elements))) failed") } } + if !reflect.DeepEqual(inK, outK) { + t.Fatal("decode(encode(vector)) failed") + } if n != dec.BytesRead() { t.Fatal("bytes read don't match bytes written") } diff --git a/ecc/bls12-381/fp/vector.go b/ecc/bls12-381/fp/vector.go index 0c11ec134..e75796169 100644 --- a/ecc/bls12-381/fp/vector.go +++ b/ecc/bls12-381/fp/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bls12-381/fr/vector.go b/ecc/bls12-381/fr/vector.go index 6d23664e8..c85a3c134 100644 --- a/ecc/bls12-381/fr/vector.go +++ b/ecc/bls12-381/fr/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bls12-381/marshal.go b/ecc/bls12-381/marshal.go index 2b0e584f4..fad5940db 100644 --- a/ecc/bls12-381/marshal.go +++ b/ecc/bls12-381/marshal.go @@ -86,9 +86,16 @@ func (dec *Decoder) Decode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap // in particular, careful attention must be given to usage of Bytes() method on Elements and Points - // that return an array (not a slice) of bytes. Using this is beneficial to minimize memallocs + // that return an array (not a slice) of bytes. Using this is beneficial to minimize memory allocations // in very large (de)serialization upstream in gnark. - // (but detrimental to code lisibility here) + // (but detrimental to code readability here) + + var read64 int64 + if vf, ok := v.(io.ReaderFrom); ok { + read64, err = vf.ReadFrom(dec.r) + dec.n += read64 + return + } var buf [SizeOfG2AffineUncompressed]byte var read int @@ -111,46 +118,12 @@ func (dec *Decoder) Decode(v interface{}) (err error) { err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fr.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fr.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { - return - } - } + read64, err = (*fr.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *[]fp.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fp.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { - return - } - } + read64, err = (*fp.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *G1Affine: // we start by reading compressed point size, if metadata tells us it is uncompressed, we read more. @@ -400,7 +373,15 @@ func (enc *Encoder) encode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -422,41 +403,22 @@ func (enc *Encoder) encode(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) @@ -514,7 +476,15 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -536,41 +506,22 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) diff --git a/ecc/bls12-381/marshal_test.go b/ecc/bls12-381/marshal_test.go index f2c414206..d21e2910f 100644 --- a/ecc/bls12-381/marshal_test.go +++ b/ecc/bls12-381/marshal_test.go @@ -21,6 +21,7 @@ import ( "io" "math/big" "math/rand" + "reflect" "testing" "github.com/leanovate/gopter" @@ -50,6 +51,7 @@ func TestEncoder(t *testing.T) { var inH []G2Affine var inI []fp.Element var inJ []fr.Element + var inK fr.Vector // set values of inputs inA = rand.Uint64() @@ -64,12 +66,14 @@ func TestEncoder(t *testing.T) { inI = make([]fp.Element, 3) inI[2] = inD.X inJ = make([]fr.Element, 0) + inK = make(fr.Vector, 42) + inK[41].SetUint64(42) // encode them, compressed and raw var buf, bufRaw bytes.Buffer enc := NewEncoder(&buf) encRaw := NewEncoder(&bufRaw, RawEncoding()) - toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ} + toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ, inK} for _, v := range toEncode { if err := enc.Encode(v); err != nil { t.Fatal(err) @@ -93,8 +97,9 @@ func TestEncoder(t *testing.T) { var outH []G2Affine var outI []fp.Element var outJ []fr.Element + var outK fr.Vector - toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ} + toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ, &outK} for _, v := range toDecode { if err := dec.Decode(v); err != nil { t.Fatal(err) @@ -131,6 +136,9 @@ func TestEncoder(t *testing.T) { t.Fatal("decode(encode(slice(elements))) failed") } } + if !reflect.DeepEqual(inK, outK) { + t.Fatal("decode(encode(vector)) failed") + } if n != dec.BytesRead() { t.Fatal("bytes read don't match bytes written") } diff --git a/ecc/bls24-315/fp/vector.go b/ecc/bls24-315/fp/vector.go index 0c11ec134..e75796169 100644 --- a/ecc/bls24-315/fp/vector.go +++ b/ecc/bls24-315/fp/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bls24-315/fr/vector.go b/ecc/bls24-315/fr/vector.go index 6d23664e8..c85a3c134 100644 --- a/ecc/bls24-315/fr/vector.go +++ b/ecc/bls24-315/fr/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bls24-315/marshal.go b/ecc/bls24-315/marshal.go index 944af6694..79fbab810 100644 --- a/ecc/bls24-315/marshal.go +++ b/ecc/bls24-315/marshal.go @@ -86,9 +86,16 @@ func (dec *Decoder) Decode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap // in particular, careful attention must be given to usage of Bytes() method on Elements and Points - // that return an array (not a slice) of bytes. Using this is beneficial to minimize memallocs + // that return an array (not a slice) of bytes. Using this is beneficial to minimize memory allocations // in very large (de)serialization upstream in gnark. - // (but detrimental to code lisibility here) + // (but detrimental to code readability here) + + var read64 int64 + if vf, ok := v.(io.ReaderFrom); ok { + read64, err = vf.ReadFrom(dec.r) + dec.n += read64 + return + } var buf [SizeOfG2AffineUncompressed]byte var read int @@ -111,46 +118,12 @@ func (dec *Decoder) Decode(v interface{}) (err error) { err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fr.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fr.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { - return - } - } + read64, err = (*fr.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *[]fp.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fp.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { - return - } - } + read64, err = (*fp.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *G1Affine: // we start by reading compressed point size, if metadata tells us it is uncompressed, we read more. @@ -400,7 +373,15 @@ func (enc *Encoder) encode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -422,41 +403,22 @@ func (enc *Encoder) encode(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) @@ -514,7 +476,15 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -536,41 +506,22 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) diff --git a/ecc/bls24-315/marshal_test.go b/ecc/bls24-315/marshal_test.go index 3690d497c..a18756f96 100644 --- a/ecc/bls24-315/marshal_test.go +++ b/ecc/bls24-315/marshal_test.go @@ -21,6 +21,7 @@ import ( "io" "math/big" "math/rand" + "reflect" "testing" "github.com/leanovate/gopter" @@ -50,6 +51,7 @@ func TestEncoder(t *testing.T) { var inH []G2Affine var inI []fp.Element var inJ []fr.Element + var inK fr.Vector // set values of inputs inA = rand.Uint64() @@ -64,12 +66,14 @@ func TestEncoder(t *testing.T) { inI = make([]fp.Element, 3) inI[2] = inD.X inJ = make([]fr.Element, 0) + inK = make(fr.Vector, 42) + inK[41].SetUint64(42) // encode them, compressed and raw var buf, bufRaw bytes.Buffer enc := NewEncoder(&buf) encRaw := NewEncoder(&bufRaw, RawEncoding()) - toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ} + toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ, inK} for _, v := range toEncode { if err := enc.Encode(v); err != nil { t.Fatal(err) @@ -93,8 +97,9 @@ func TestEncoder(t *testing.T) { var outH []G2Affine var outI []fp.Element var outJ []fr.Element + var outK fr.Vector - toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ} + toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ, &outK} for _, v := range toDecode { if err := dec.Decode(v); err != nil { t.Fatal(err) @@ -131,6 +136,9 @@ func TestEncoder(t *testing.T) { t.Fatal("decode(encode(slice(elements))) failed") } } + if !reflect.DeepEqual(inK, outK) { + t.Fatal("decode(encode(vector)) failed") + } if n != dec.BytesRead() { t.Fatal("bytes read don't match bytes written") } diff --git a/ecc/bls24-317/fp/vector.go b/ecc/bls24-317/fp/vector.go index 0c11ec134..e75796169 100644 --- a/ecc/bls24-317/fp/vector.go +++ b/ecc/bls24-317/fp/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bls24-317/fr/vector.go b/ecc/bls24-317/fr/vector.go index 6d23664e8..c85a3c134 100644 --- a/ecc/bls24-317/fr/vector.go +++ b/ecc/bls24-317/fr/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bls24-317/marshal.go b/ecc/bls24-317/marshal.go index ddd960555..8f50f8e17 100644 --- a/ecc/bls24-317/marshal.go +++ b/ecc/bls24-317/marshal.go @@ -86,9 +86,16 @@ func (dec *Decoder) Decode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap // in particular, careful attention must be given to usage of Bytes() method on Elements and Points - // that return an array (not a slice) of bytes. Using this is beneficial to minimize memallocs + // that return an array (not a slice) of bytes. Using this is beneficial to minimize memory allocations // in very large (de)serialization upstream in gnark. - // (but detrimental to code lisibility here) + // (but detrimental to code readability here) + + var read64 int64 + if vf, ok := v.(io.ReaderFrom); ok { + read64, err = vf.ReadFrom(dec.r) + dec.n += read64 + return + } var buf [SizeOfG2AffineUncompressed]byte var read int @@ -111,46 +118,12 @@ func (dec *Decoder) Decode(v interface{}) (err error) { err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fr.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fr.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { - return - } - } + read64, err = (*fr.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *[]fp.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fp.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { - return - } - } + read64, err = (*fp.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *G1Affine: // we start by reading compressed point size, if metadata tells us it is uncompressed, we read more. @@ -400,7 +373,15 @@ func (enc *Encoder) encode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -422,41 +403,22 @@ func (enc *Encoder) encode(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) @@ -514,7 +476,15 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -536,41 +506,22 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) diff --git a/ecc/bls24-317/marshal_test.go b/ecc/bls24-317/marshal_test.go index c1e0c7f3a..97874cb5b 100644 --- a/ecc/bls24-317/marshal_test.go +++ b/ecc/bls24-317/marshal_test.go @@ -21,6 +21,7 @@ import ( "io" "math/big" "math/rand" + "reflect" "testing" "github.com/leanovate/gopter" @@ -50,6 +51,7 @@ func TestEncoder(t *testing.T) { var inH []G2Affine var inI []fp.Element var inJ []fr.Element + var inK fr.Vector // set values of inputs inA = rand.Uint64() @@ -64,12 +66,14 @@ func TestEncoder(t *testing.T) { inI = make([]fp.Element, 3) inI[2] = inD.X inJ = make([]fr.Element, 0) + inK = make(fr.Vector, 42) + inK[41].SetUint64(42) // encode them, compressed and raw var buf, bufRaw bytes.Buffer enc := NewEncoder(&buf) encRaw := NewEncoder(&bufRaw, RawEncoding()) - toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ} + toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ, inK} for _, v := range toEncode { if err := enc.Encode(v); err != nil { t.Fatal(err) @@ -93,8 +97,9 @@ func TestEncoder(t *testing.T) { var outH []G2Affine var outI []fp.Element var outJ []fr.Element + var outK fr.Vector - toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ} + toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ, &outK} for _, v := range toDecode { if err := dec.Decode(v); err != nil { t.Fatal(err) @@ -131,6 +136,9 @@ func TestEncoder(t *testing.T) { t.Fatal("decode(encode(slice(elements))) failed") } } + if !reflect.DeepEqual(inK, outK) { + t.Fatal("decode(encode(vector)) failed") + } if n != dec.BytesRead() { t.Fatal("bytes read don't match bytes written") } diff --git a/ecc/bn254/fp/vector.go b/ecc/bn254/fp/vector.go index 0c11ec134..e75796169 100644 --- a/ecc/bn254/fp/vector.go +++ b/ecc/bn254/fp/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bn254/fr/vector.go b/ecc/bn254/fr/vector.go index 6d23664e8..c85a3c134 100644 --- a/ecc/bn254/fr/vector.go +++ b/ecc/bn254/fr/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bn254/marshal.go b/ecc/bn254/marshal.go index 4ae011cf1..f980f6f16 100644 --- a/ecc/bn254/marshal.go +++ b/ecc/bn254/marshal.go @@ -80,9 +80,16 @@ func (dec *Decoder) Decode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap // in particular, careful attention must be given to usage of Bytes() method on Elements and Points - // that return an array (not a slice) of bytes. Using this is beneficial to minimize memallocs + // that return an array (not a slice) of bytes. Using this is beneficial to minimize memory allocations // in very large (de)serialization upstream in gnark. - // (but detrimental to code lisibility here) + // (but detrimental to code readability here) + + var read64 int64 + if vf, ok := v.(io.ReaderFrom); ok { + read64, err = vf.ReadFrom(dec.r) + dec.n += read64 + return + } var buf [SizeOfG2AffineUncompressed]byte var read int @@ -105,46 +112,12 @@ func (dec *Decoder) Decode(v interface{}) (err error) { err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fr.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fr.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { - return - } - } + read64, err = (*fr.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *[]fp.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fp.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { - return - } - } + read64, err = (*fp.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *G1Affine: // we start by reading compressed point size, if metadata tells us it is uncompressed, we read more. @@ -394,7 +367,15 @@ func (enc *Encoder) encode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -416,41 +397,22 @@ func (enc *Encoder) encode(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) @@ -508,7 +470,15 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -530,41 +500,22 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) diff --git a/ecc/bn254/marshal_test.go b/ecc/bn254/marshal_test.go index 9efb0365b..45504bdf5 100644 --- a/ecc/bn254/marshal_test.go +++ b/ecc/bn254/marshal_test.go @@ -21,6 +21,7 @@ import ( "io" "math/big" "math/rand" + "reflect" "testing" "github.com/leanovate/gopter" @@ -50,6 +51,7 @@ func TestEncoder(t *testing.T) { var inH []G2Affine var inI []fp.Element var inJ []fr.Element + var inK fr.Vector // set values of inputs inA = rand.Uint64() @@ -64,12 +66,14 @@ func TestEncoder(t *testing.T) { inI = make([]fp.Element, 3) inI[2] = inD.X inJ = make([]fr.Element, 0) + inK = make(fr.Vector, 42) + inK[41].SetUint64(42) // encode them, compressed and raw var buf, bufRaw bytes.Buffer enc := NewEncoder(&buf) encRaw := NewEncoder(&bufRaw, RawEncoding()) - toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ} + toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ, inK} for _, v := range toEncode { if err := enc.Encode(v); err != nil { t.Fatal(err) @@ -93,8 +97,9 @@ func TestEncoder(t *testing.T) { var outH []G2Affine var outI []fp.Element var outJ []fr.Element + var outK fr.Vector - toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ} + toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ, &outK} for _, v := range toDecode { if err := dec.Decode(v); err != nil { t.Fatal(err) @@ -131,6 +136,9 @@ func TestEncoder(t *testing.T) { t.Fatal("decode(encode(slice(elements))) failed") } } + if !reflect.DeepEqual(inK, outK) { + t.Fatal("decode(encode(vector)) failed") + } if n != dec.BytesRead() { t.Fatal("bytes read don't match bytes written") } diff --git a/ecc/bw6-633/fp/vector.go b/ecc/bw6-633/fp/vector.go index 0c11ec134..e75796169 100644 --- a/ecc/bw6-633/fp/vector.go +++ b/ecc/bw6-633/fp/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bw6-633/fr/vector.go b/ecc/bw6-633/fr/vector.go index 6d23664e8..c85a3c134 100644 --- a/ecc/bw6-633/fr/vector.go +++ b/ecc/bw6-633/fr/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bw6-633/marshal.go b/ecc/bw6-633/marshal.go index 88d800df2..78c1df948 100644 --- a/ecc/bw6-633/marshal.go +++ b/ecc/bw6-633/marshal.go @@ -86,9 +86,16 @@ func (dec *Decoder) Decode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap // in particular, careful attention must be given to usage of Bytes() method on Elements and Points - // that return an array (not a slice) of bytes. Using this is beneficial to minimize memallocs + // that return an array (not a slice) of bytes. Using this is beneficial to minimize memory allocations // in very large (de)serialization upstream in gnark. - // (but detrimental to code lisibility here) + // (but detrimental to code readability here) + + var read64 int64 + if vf, ok := v.(io.ReaderFrom); ok { + read64, err = vf.ReadFrom(dec.r) + dec.n += read64 + return + } var buf [SizeOfG2AffineUncompressed]byte var read int @@ -111,46 +118,12 @@ func (dec *Decoder) Decode(v interface{}) (err error) { err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fr.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fr.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { - return - } - } + read64, err = (*fr.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *[]fp.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fp.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { - return - } - } + read64, err = (*fp.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *G1Affine: // we start by reading compressed point size, if metadata tells us it is uncompressed, we read more. @@ -400,7 +373,15 @@ func (enc *Encoder) encode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -422,41 +403,22 @@ func (enc *Encoder) encode(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) @@ -514,7 +476,15 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -536,41 +506,22 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) diff --git a/ecc/bw6-633/marshal_test.go b/ecc/bw6-633/marshal_test.go index 93ff07df0..bd047f0d7 100644 --- a/ecc/bw6-633/marshal_test.go +++ b/ecc/bw6-633/marshal_test.go @@ -21,6 +21,7 @@ import ( "io" "math/big" "math/rand" + "reflect" "testing" "github.com/leanovate/gopter" @@ -50,6 +51,7 @@ func TestEncoder(t *testing.T) { var inH []G2Affine var inI []fp.Element var inJ []fr.Element + var inK fr.Vector // set values of inputs inA = rand.Uint64() @@ -64,12 +66,14 @@ func TestEncoder(t *testing.T) { inI = make([]fp.Element, 3) inI[2] = inD.X inJ = make([]fr.Element, 0) + inK = make(fr.Vector, 42) + inK[41].SetUint64(42) // encode them, compressed and raw var buf, bufRaw bytes.Buffer enc := NewEncoder(&buf) encRaw := NewEncoder(&bufRaw, RawEncoding()) - toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ} + toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ, inK} for _, v := range toEncode { if err := enc.Encode(v); err != nil { t.Fatal(err) @@ -93,8 +97,9 @@ func TestEncoder(t *testing.T) { var outH []G2Affine var outI []fp.Element var outJ []fr.Element + var outK fr.Vector - toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ} + toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ, &outK} for _, v := range toDecode { if err := dec.Decode(v); err != nil { t.Fatal(err) @@ -131,6 +136,9 @@ func TestEncoder(t *testing.T) { t.Fatal("decode(encode(slice(elements))) failed") } } + if !reflect.DeepEqual(inK, outK) { + t.Fatal("decode(encode(vector)) failed") + } if n != dec.BytesRead() { t.Fatal("bytes read don't match bytes written") } diff --git a/ecc/bw6-756/fp/vector.go b/ecc/bw6-756/fp/vector.go index 0c11ec134..e75796169 100644 --- a/ecc/bw6-756/fp/vector.go +++ b/ecc/bw6-756/fp/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bw6-756/fr/vector.go b/ecc/bw6-756/fr/vector.go index 6d23664e8..c85a3c134 100644 --- a/ecc/bw6-756/fr/vector.go +++ b/ecc/bw6-756/fr/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bw6-756/marshal.go b/ecc/bw6-756/marshal.go index 12283d885..8a3a01315 100644 --- a/ecc/bw6-756/marshal.go +++ b/ecc/bw6-756/marshal.go @@ -86,9 +86,16 @@ func (dec *Decoder) Decode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap // in particular, careful attention must be given to usage of Bytes() method on Elements and Points - // that return an array (not a slice) of bytes. Using this is beneficial to minimize memallocs + // that return an array (not a slice) of bytes. Using this is beneficial to minimize memory allocations // in very large (de)serialization upstream in gnark. - // (but detrimental to code lisibility here) + // (but detrimental to code readability here) + + var read64 int64 + if vf, ok := v.(io.ReaderFrom); ok { + read64, err = vf.ReadFrom(dec.r) + dec.n += read64 + return + } var buf [SizeOfG2AffineUncompressed]byte var read int @@ -111,46 +118,12 @@ func (dec *Decoder) Decode(v interface{}) (err error) { err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fr.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fr.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { - return - } - } + read64, err = (*fr.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *[]fp.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fp.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { - return - } - } + read64, err = (*fp.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *G1Affine: // we start by reading compressed point size, if metadata tells us it is uncompressed, we read more. @@ -400,7 +373,15 @@ func (enc *Encoder) encode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -422,41 +403,22 @@ func (enc *Encoder) encode(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) @@ -514,7 +476,15 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -536,41 +506,22 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) diff --git a/ecc/bw6-756/marshal_test.go b/ecc/bw6-756/marshal_test.go index 991e9cf1d..b17c560f4 100644 --- a/ecc/bw6-756/marshal_test.go +++ b/ecc/bw6-756/marshal_test.go @@ -21,6 +21,7 @@ import ( "io" "math/big" "math/rand" + "reflect" "testing" "github.com/leanovate/gopter" @@ -50,6 +51,7 @@ func TestEncoder(t *testing.T) { var inH []G2Affine var inI []fp.Element var inJ []fr.Element + var inK fr.Vector // set values of inputs inA = rand.Uint64() @@ -64,12 +66,14 @@ func TestEncoder(t *testing.T) { inI = make([]fp.Element, 3) inI[2] = inD.X inJ = make([]fr.Element, 0) + inK = make(fr.Vector, 42) + inK[41].SetUint64(42) // encode them, compressed and raw var buf, bufRaw bytes.Buffer enc := NewEncoder(&buf) encRaw := NewEncoder(&bufRaw, RawEncoding()) - toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ} + toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ, inK} for _, v := range toEncode { if err := enc.Encode(v); err != nil { t.Fatal(err) @@ -93,8 +97,9 @@ func TestEncoder(t *testing.T) { var outH []G2Affine var outI []fp.Element var outJ []fr.Element + var outK fr.Vector - toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ} + toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ, &outK} for _, v := range toDecode { if err := dec.Decode(v); err != nil { t.Fatal(err) @@ -131,6 +136,9 @@ func TestEncoder(t *testing.T) { t.Fatal("decode(encode(slice(elements))) failed") } } + if !reflect.DeepEqual(inK, outK) { + t.Fatal("decode(encode(vector)) failed") + } if n != dec.BytesRead() { t.Fatal("bytes read don't match bytes written") } diff --git a/ecc/bw6-761/fp/vector.go b/ecc/bw6-761/fp/vector.go index 0c11ec134..e75796169 100644 --- a/ecc/bw6-761/fp/vector.go +++ b/ecc/bw6-761/fp/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bw6-761/fr/vector.go b/ecc/bw6-761/fr/vector.go index 6d23664e8..c85a3c134 100644 --- a/ecc/bw6-761/fr/vector.go +++ b/ecc/bw6-761/fr/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/bw6-761/marshal.go b/ecc/bw6-761/marshal.go index 4def5de67..ad6160197 100644 --- a/ecc/bw6-761/marshal.go +++ b/ecc/bw6-761/marshal.go @@ -86,9 +86,16 @@ func (dec *Decoder) Decode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap // in particular, careful attention must be given to usage of Bytes() method on Elements and Points - // that return an array (not a slice) of bytes. Using this is beneficial to minimize memallocs + // that return an array (not a slice) of bytes. Using this is beneficial to minimize memory allocations // in very large (de)serialization upstream in gnark. - // (but detrimental to code lisibility here) + // (but detrimental to code readability here) + + var read64 int64 + if vf, ok := v.(io.ReaderFrom); ok { + read64, err = vf.ReadFrom(dec.r) + dec.n += read64 + return + } var buf [SizeOfG2AffineUncompressed]byte var read int @@ -111,46 +118,12 @@ func (dec *Decoder) Decode(v interface{}) (err error) { err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fr.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fr.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { - return - } - } + read64, err = (*fr.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *[]fp.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fp.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { - return - } - } + read64, err = (*fp.Vector)(t).ReadFrom(dec.r) + dec.n += read64 return case *G1Affine: // we start by reading compressed point size, if metadata tells us it is uncompressed, we read more. @@ -400,7 +373,15 @@ func (enc *Encoder) encode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -422,41 +403,22 @@ func (enc *Encoder) encode(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) @@ -514,7 +476,15 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -536,41 +506,22 @@ func (enc *Encoder) encodeRaw(v interface{}) (err error) { written, err = enc.w.Write(buf[:]) enc.n += int64(written) return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) diff --git a/ecc/bw6-761/marshal_test.go b/ecc/bw6-761/marshal_test.go index 7e2554349..fe66a6741 100644 --- a/ecc/bw6-761/marshal_test.go +++ b/ecc/bw6-761/marshal_test.go @@ -21,6 +21,7 @@ import ( "io" "math/big" "math/rand" + "reflect" "testing" "github.com/leanovate/gopter" @@ -50,6 +51,7 @@ func TestEncoder(t *testing.T) { var inH []G2Affine var inI []fp.Element var inJ []fr.Element + var inK fr.Vector // set values of inputs inA = rand.Uint64() @@ -64,12 +66,14 @@ func TestEncoder(t *testing.T) { inI = make([]fp.Element, 3) inI[2] = inD.X inJ = make([]fr.Element, 0) + inK = make(fr.Vector, 42) + inK[41].SetUint64(42) // encode them, compressed and raw var buf, bufRaw bytes.Buffer enc := NewEncoder(&buf) encRaw := NewEncoder(&bufRaw, RawEncoding()) - toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ} + toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ, inK} for _, v := range toEncode { if err := enc.Encode(v); err != nil { t.Fatal(err) @@ -93,8 +97,9 @@ func TestEncoder(t *testing.T) { var outH []G2Affine var outI []fp.Element var outJ []fr.Element + var outK fr.Vector - toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ} + toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ, &outK} for _, v := range toDecode { if err := dec.Decode(v); err != nil { t.Fatal(err) @@ -131,6 +136,9 @@ func TestEncoder(t *testing.T) { t.Fatal("decode(encode(slice(elements))) failed") } } + if !reflect.DeepEqual(inK, outK) { + t.Fatal("decode(encode(vector)) failed") + } if n != dec.BytesRead() { t.Fatal("bytes read don't match bytes written") } diff --git a/ecc/secp256k1/fp/vector.go b/ecc/secp256k1/fp/vector.go index 0c11ec134..e75796169 100644 --- a/ecc/secp256k1/fp/vector.go +++ b/ecc/secp256k1/fp/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/secp256k1/fr/vector.go b/ecc/secp256k1/fr/vector.go index 6d23664e8..c85a3c134 100644 --- a/ecc/secp256k1/fr/vector.go +++ b/ecc/secp256k1/fr/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/stark-curve/fp/vector.go b/ecc/stark-curve/fp/vector.go index 0c11ec134..e75796169 100644 --- a/ecc/stark-curve/fp/vector.go +++ b/ecc/stark-curve/fp/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/ecc/stark-curve/fr/vector.go b/ecc/stark-curve/fr/vector.go index 6d23664e8..c85a3c134 100644 --- a/ecc/stark-curve/fr/vector.go +++ b/ecc/stark-curve/fr/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/field/generator/internal/templates/element/vector.go b/field/generator/internal/templates/element/vector.go index 02bcff3e8..7a09d0bf2 100644 --- a/field/generator/internal/templates/element/vector.go +++ b/field/generator/internal/templates/element/vector.go @@ -20,7 +20,7 @@ import ( type Vector []{{.ElementName}} // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -39,17 +39,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded {{.ElementName}}. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/field/goldilocks/vector.go b/field/goldilocks/vector.go index 5cc7e810b..213504039 100644 --- a/field/goldilocks/vector.go +++ b/field/goldilocks/vector.go @@ -35,7 +35,7 @@ import ( type Vector []Element // MarshalBinary implements encoding.BinaryMarshaler -func (vector Vector) MarshalBinary() (data []byte, err error) { +func (vector *Vector) MarshalBinary() (data []byte, err error) { var buf bytes.Buffer if _, err = vector.WriteTo(&buf); err != nil { @@ -53,17 +53,17 @@ func (vector *Vector) UnmarshalBinary(data []byte) error { // WriteTo implements io.WriterTo and writes a vector of big endian encoded Element. // Length of the vector is encoded as a uint32 on the first 4 bytes. -func (vector Vector) WriteTo(w io.Writer) (int64, error) { +func (vector *Vector) WriteTo(w io.Writer) (int64, error) { // encode slice length - if err := binary.Write(w, binary.BigEndian, uint32(len(vector))); err != nil { + if err := binary.Write(w, binary.BigEndian, uint32(len(*vector))); err != nil { return 0, err } n := int64(4) var buf [Bytes]byte - for i := 0; i < len(vector); i++ { - BigEndian.PutElement(&buf, vector[i]) + for i := 0; i < len(*vector); i++ { + BigEndian.PutElement(&buf, (*vector)[i]) m, err := w.Write(buf[:]) n += int64(m) if err != nil { diff --git a/internal/generator/ecc/template/marshal.go.tmpl b/internal/generator/ecc/template/marshal.go.tmpl index f3aa89e0d..4f31c9c3e 100644 --- a/internal/generator/ecc/template/marshal.go.tmpl +++ b/internal/generator/ecc/template/marshal.go.tmpl @@ -92,9 +92,16 @@ func (dec *Decoder) Decode(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap // in particular, careful attention must be given to usage of Bytes() method on Elements and Points - // that return an array (not a slice) of bytes. Using this is beneficial to minimize memallocs + // that return an array (not a slice) of bytes. Using this is beneficial to minimize memory allocations // in very large (de)serialization upstream in gnark. - // (but detrimental to code lisibility here) + // (but detrimental to code readability here) + + var read64 int64 + if vf, ok := v.(io.ReaderFrom); ok { + read64, err = vf.ReadFrom(dec.r) + dec.n+=read64 + return + } var buf [SizeOfG2AffineUncompressed]byte var read int @@ -117,47 +124,13 @@ func (dec *Decoder) Decode(v interface{}) (err error) { err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fr.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fr.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { - return - } - } + read64, err = (*fr.Vector)(t).ReadFrom(dec.r) + dec.n+=read64 return case *[]fp.Element: - var sliceLen uint32 - sliceLen, err = dec.readUint32() - if err != nil { - return - } - if len(*t) != int(sliceLen) { - *t = make([]fp.Element, sliceLen) - } - - for i := 0; i < len(*t); i++ { - read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) - dec.n += int64(read) - if err != nil { - return - } - if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { - return - } - } - return + read64, err = (*fp.Vector)(t).ReadFrom(dec.r) + dec.n+=read64 + return case *G1Affine: // we start by reading compressed point size, if metadata tells us it is uncompressed, we read more. read, err = io.ReadFull(dec.r, buf[:SizeOfG1AffineCompressed]) @@ -419,7 +392,15 @@ func (enc *Encoder) encode{{- $.Raw}}(v interface{}) (err error) { // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap + var written64 int64 + if vw, ok := v.(io.WriterTo); ok { + written64, err = vw.WriteTo(enc.w) + enc.n += written64 + return + } + var written int + switch t := v.(type) { case *fr.Element: buf := t.Bytes() @@ -440,42 +421,23 @@ func (enc *Encoder) encode{{- $.Raw}}(v interface{}) (err error) { buf := t.{{- $.Raw}}Bytes() written, err = enc.w.Write(buf[:]) enc.n += int64(written) + return + case fr.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 + return + case fp.Vector: + written64, err = t.WriteTo(enc.w) + enc.n += written64 return case []fr.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fr.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil + written64, err = (*fr.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []fp.Element: - // write slice length - err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) - if err != nil { - return - } - enc.n += 4 - var buf [fp.Bytes]byte - for i := 0; i < len(t); i++ { - buf = t[i].Bytes() - written, err = enc.w.Write(buf[:]) - enc.n += int64(written) - if err != nil { - return - } - } - return nil - + written64, err = (*fp.Vector)(&t).WriteTo(enc.w) + enc.n += written64 + return case []G1Affine: // write slice length err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) diff --git a/internal/generator/ecc/template/tests/marshal.go.tmpl b/internal/generator/ecc/template/tests/marshal.go.tmpl index 809a1f7e0..ec69a15c1 100644 --- a/internal/generator/ecc/template/tests/marshal.go.tmpl +++ b/internal/generator/ecc/template/tests/marshal.go.tmpl @@ -12,6 +12,7 @@ import ( "math/big" "bytes" "io" + "reflect" "github.com/leanovate/gopter" "github.com/leanovate/gopter/prop" @@ -40,6 +41,7 @@ func TestEncoder(t *testing.T) { var inH []G2Affine var inI []fp.Element var inJ []fr.Element + var inK fr.Vector // set values of inputs inA = rand.Uint64() @@ -54,13 +56,15 @@ func TestEncoder(t *testing.T) { inI = make([]fp.Element, 3) inI[2] = inD.X inJ = make([]fr.Element, 0) + inK = make(fr.Vector, 42) + inK[41].SetUint64(42) // encode them, compressed and raw var buf, bufRaw bytes.Buffer enc := NewEncoder(&buf) encRaw := NewEncoder(&bufRaw, RawEncoding()) - toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ} + toEncode := []interface{}{inA, &inB, &inC, &inD, &inE, &inF, inG, inH, inI, inJ, inK} for _, v := range toEncode { if err := enc.Encode(v); err != nil { t.Fatal(err) @@ -85,8 +89,9 @@ func TestEncoder(t *testing.T) { var outH []G2Affine var outI []fp.Element var outJ []fr.Element + var outK fr.Vector - toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ} + toDecode := []interface{}{&outA, &outB, &outC, &outD, &outE, &outF, &outG, &outH, &outI, &outJ, &outK} for _, v := range toDecode { if err := dec.Decode(v); err != nil { t.Fatal(err) @@ -123,6 +128,9 @@ func TestEncoder(t *testing.T) { t.Fatal("decode(encode(slice(elements))) failed") } } + if !reflect.DeepEqual(inK, outK) { + t.Fatal("decode(encode(vector)) failed") + } if n != dec.BytesRead() { t.Fatal("bytes read don't match bytes written") }