Skip to content

Commit

Permalink
Merge pull request #243 from Consensys/193-support-fieldarray
Browse files Browse the repository at this point in the history
Support `FrArray`
  • Loading branch information
DavePearce authored Jul 16, 2024
2 parents c15df51 + 1a6fafc commit 3eddf5b
Show file tree
Hide file tree
Showing 20 changed files with 443 additions and 395 deletions.
3 changes: 1 addition & 2 deletions pkg/cmd/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ func checkTraceWithLoweringDefault(tr trace.Trace, hirSchema *hir.Schema, cfg ch
}

func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace, error) {
//
if cfg.expand {
// Clone to prevent interefence with subsequent checks
tr = tr.Clone()
Expand All @@ -186,12 +187,10 @@ func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace,
return tr, err
}
}

// Perform Alignment
if err := performAlignment(false, tr, schema, cfg); err != nil {
return tr, err
}

// Apply padding (as necessary)
for n := cfg.padding.Left; n <= cfg.padding.Right; n++ {
if ptr, err := padAndCheckTrace(n, tr, schema); err != nil {
Expand Down
14 changes: 4 additions & 10 deletions pkg/cmd/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"os"
"strings"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"

Expand Down Expand Up @@ -80,12 +79,7 @@ func filterColumns(tr trace.Trace, prefix string) trace.Trace {
if strings.HasPrefix(qName, prefix) {
ith := tr.Columns().Get(i)
// Copy column data
data := make([]*fr.Element, ith.Height())
//
for j := 0; j < int(ith.Height()); j++ {
data[j] = ith.Get(j)
}

data := ith.Data().Clone()
err := builder.Add(qName, ith.Padding(), data)
// Sanity check
if err != nil {
Expand All @@ -102,9 +96,9 @@ func listColumns(tr trace.Trace) {
tbl := util.NewTablePrinter(3, n)

for i := uint(0); i < n; i++ {
ith := tr.Columns().Get(i)
elems := fmt.Sprintf("%d rows", ith.Height())
bytes := fmt.Sprintf("%d bytes", ith.Width()*ith.Height())
ith := tr.Columns().Get(i).Data()
elems := fmt.Sprintf("%d rows", ith.Len())
bytes := fmt.Sprintf("(%d*%d) = %d bytes", ith.Len(), ith.ByteWidth(), ith.ByteWidth()*ith.Len())
tbl.SetRow(i, QualifiedColumnName(i, tr), elems, bytes)
}

Expand Down
13 changes: 7 additions & 6 deletions pkg/schema/assignment/byte_decomposition.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,25 @@ func (p *ByteDecomposition) ExpandTrace(tr trace.Trace) error {
// Identify source column
source := columns.Get(p.source)
// Construct byte column data
cols := make([][]*fr.Element, n)
cols := make([]util.FrArray, n)
// Initialise columns
for i := 0; i < n; i++ {
cols[i] = make([]*fr.Element, source.Height())
// Construct a byte column for ith byte
cols[i] = util.NewFrArray(source.Height(), 1)
}
// Decompose each row of each column
for i := 0; i < int(source.Height()); i = i + 1 {
ith := decomposeIntoBytes(source.Get(i), n)
for i := uint(0); i < source.Height(); i = i + 1 {
ith := decomposeIntoBytes(source.Get(int(i)), n)
for j := 0; j < n; j++ {
cols[j][i] = ith[j]
cols[j].Set(i, ith[j])
}
}
// Determine padding values
padding := decomposeIntoBytes(source.Padding(), n)
// Finally, add byte columns to trace
for i := 0; i < n; i++ {
ith := p.targets[i]
columns.Add(trace.NewFieldColumn(ith.Context(), ith.Name(), cols[i], padding[i]))
columns.Add(ith.Context(), ith.Name(), cols[i], padding[i])
}
// Done
return nil
Expand Down
12 changes: 6 additions & 6 deletions pkg/schema/assignment/computed_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,23 @@ func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error {
// Determine multiplied height
height := tr.Modules().Get(p.target.Context().Module()).Height() * multiplier
// Make space for computed data
data := make([]*fr.Element, height)
data := util.NewFrArray(height, 32)
// Expand the trace
for i := 0; i < len(data); i++ {
val := p.expr.EvalAt(i, tr)
for i := uint(0); i < data.Len(); i++ {
val := p.expr.EvalAt(int(i), tr)
if val != nil {
data[i] = val
data.Set(i, val)
} else {
zero := fr.NewElement(0)
data[i] = &zero
data.Set(i, &zero)
}
}
// Determine padding value. A negative row index is used here to ensure
// that all columns return their padding value which is then used to compute
// the padding value for *this* column.
padding := p.expr.EvalAt(-1, tr)
// Colunm needs to be expanded.
columns.Add(trace.NewFieldColumn(p.target.Context(), p.Name(), data, padding))
columns.Add(p.target.Context(), p.Name(), data, padding)
// Done
return nil
}
18 changes: 10 additions & 8 deletions pkg/schema/assignment/interleave.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ package assignment
import (
"fmt"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/trace"
tr "github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)
Expand Down Expand Up @@ -73,12 +71,16 @@ func (p *Interleaving) RequiredSpillage() uint {
func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
columns := tr.Columns()
ctx := p.target.Context()
// Byte width records the largest width of any column.
byte_width := uint(0)
// Ensure target column doesn't exist
for i := p.Columns(); i.HasNext(); {
name := i.Next().Name()
ith := i.Next()
// Update byte width
byte_width = max(byte_width, ith.Type().ByteWidth())
// Sanity check no column already exists with this name.
if _, ok := columns.IndexOf(ctx.Module(), name); ok {
return fmt.Errorf("interleaved column already exists ({%s})", name)
if _, ok := columns.IndexOf(ctx.Module(), ith.Name()); ok {
return fmt.Errorf("interleaved column already exists ({%s})", ith.Name())
}
}
// Determine interleaving width
Expand All @@ -90,7 +92,7 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
// the interleaved column)
height := tr.Modules().Get(ctx.Module()).Height() * multiplier
// Construct empty array
data := make([]*fr.Element, height*width)
data := util.NewFrArray(height*width, byte_width)
// Offset just gives the column index
offset := uint(0)
// Copy interleaved data
Expand All @@ -99,7 +101,7 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
col := tr.Columns().Get(p.sources[i])
// Copy over
for j := uint(0); j < height; j++ {
data[offset+(j*width)] = col.Get(int(j))
data.Set(offset+(j*width), col.Get(int(j)))
}

offset++
Expand All @@ -108,7 +110,7 @@ func (p *Interleaving) ExpandTrace(tr tr.Trace) error {
// column in the interleaving.
padding := columns.Get(0).Padding()
// Colunm needs to be expanded.
columns.Add(trace.NewFieldColumn(ctx, p.target.Name(), data, padding))
columns.Add(ctx, p.target.Name(), data, padding)
//
return nil
}
33 changes: 20 additions & 13 deletions pkg/schema/assignment/lexicographic_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,48 +80,55 @@ func (p *LexicographicSort) ExpandTrace(tr trace.Trace) error {
// Determine how many rows to be constrained.
nrows := tr.Modules().Get(p.context.Module()).Height() * multiplier
// Initialise new data columns
delta := make([]*fr.Element, nrows)
bit := make([][]*fr.Element, ncols)
bit := make([]util.FrArray, ncols)
// Byte width records the largest width of any column.
byte_width := uint(0)

for i := 0; i < ncols; i++ {
bit[i] = make([]*fr.Element, nrows)
// TODO: following can be optimised to use a single bit per element,
// rather than an entire byte.
bit[i] = util.NewFrArray(nrows, 1)
ith := columns.Get(p.sources[i])
byte_width = max(byte_width, ith.Data().ByteWidth())
}

for i := 0; i < int(nrows); i++ {
delta := util.NewFrArray(nrows, byte_width)

for i := uint(0); i < nrows; i++ {
set := false
// Initialise delta to zero
delta[i] = &zero
delta.Set(i, &zero)
// Decide which row is the winner (if any)
for j := 0; j < ncols; j++ {
prev := columns.Get(p.sources[j]).Get(i - 1)
curr := columns.Get(p.sources[j]).Get(i)
prev := columns.Get(p.sources[j]).Get(int(i - 1))
curr := columns.Get(p.sources[j]).Get(int(i))

if !set && prev != nil && prev.Cmp(curr) != 0 {
var diff fr.Element

bit[j][i] = &one
bit[j].Set(i, &one)
// Compute curr - prev
if p.signs[j] {
diff.Set(curr)
delta[i] = diff.Sub(&diff, prev)
delta.Set(i, diff.Sub(&diff, prev))
} else {
diff.Set(prev)
delta[i] = diff.Sub(&diff, curr)
delta.Set(i, diff.Sub(&diff, curr))
}

set = true
} else {
bit[j][i] = &zero
bit[j].Set(i, &zero)
}
}
}
// Add delta column data
first := p.targets[0]
columns.Add(trace.NewFieldColumn(first.Context(), first.Name(), delta, &zero))
columns.Add(first.Context(), first.Name(), delta, &zero)
// Add bit column data
for i := 0; i < ncols; i++ {
ith := p.targets[1+i]
columns.Add(trace.NewFieldColumn(ith.Context(), ith.Name(), bit[i], &zero))
columns.Add(ith.Context(), ith.Name(), bit[i], &zero)
}
// Done.
return nil
Expand Down
19 changes: 6 additions & 13 deletions pkg/schema/assignment/sorted_permutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package assignment
import (
"fmt"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/trace"
tr "github.com/consensys/go-corset/pkg/trace"
Expand Down Expand Up @@ -132,20 +131,14 @@ func (p *SortedPermutation) ExpandTrace(tr tr.Trace) error {
}
}

cols := make([][]*fr.Element, len(p.sources))
cols := make([]util.FrArray, len(p.sources))
// Construct target columns
for i := 0; i < len(p.sources); i++ {
src := p.sources[i]
// Read column data to initialise permutation.
col := columns.Get(src)
// Copy column data to initialise permutation.
copy := make([]*fr.Element, col.Height())
//
for j := 0; j < int(col.Height()); j++ {
copy[j] = col.Get(j)
}
// Copy over
cols[i] = copy
// Read column data
data := columns.Get(src).Data()
// Clone it to initialise permutation.
cols[i] = data.Clone()
}
// Sort target columns
util.PermutationSort(cols, p.signs)
Expand All @@ -156,7 +149,7 @@ func (p *SortedPermutation) ExpandTrace(tr tr.Trace) error {
ith := i.Next()
dstColName := ith.Name()
srcCol := tr.Columns().Get(p.sources[index])
columns.Add(trace.NewFieldColumn(ith.Context(), dstColName, cols[index], srcCol.Padding()))
columns.Add(ith.Context(), dstColName, cols[index], srcCol.Padding())
}
//
return nil
Expand Down
13 changes: 3 additions & 10 deletions pkg/schema/constraint/permutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"errors"
"fmt"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)
Expand Down Expand Up @@ -72,20 +71,14 @@ func (p *PermutationConstraint) String() string {
return fmt.Sprintf("(permutation (%s) (%s))", targets, sources)
}

func sliceColumns(columns []uint, tr trace.Trace) [][]*fr.Element {
func sliceColumns(columns []uint, tr trace.Trace) []util.FrArray {
// Allocate return array
cols := make([][]*fr.Element, len(columns))
cols := make([]util.FrArray, len(columns))
// Slice out the data
for i, n := range columns {
nth := tr.Columns().Get(n)
// Copy column data to initialise permutation.
copy := make([]*fr.Element, nth.Height())
//
for j := 0; j < int(nth.Height()); j++ {
copy[j] = nth.Get(j)
}
// Copy over
cols[i] = copy
cols[i] = nth.Data()
}
// Done
return cols
Expand Down
22 changes: 22 additions & 0 deletions pkg/schema/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ type Type interface {
// Accept checks whether a specific value is accepted by this type
Accept(*fr.Element) bool

// Return the number of bytes required represent any element of this type.
ByteWidth() uint

// Produce a string representation of this type.
String() string
}
Expand Down Expand Up @@ -59,6 +62,19 @@ func (p *UintType) AsField() *FieldType {
return nil
}

// ByteWidth returns the number of bytes required represent any element of this
// type.
func (p *UintType) ByteWidth() uint {
m := p.nbits / 8
n := p.nbits % 8
// Check for even division
if n == 0 {
return m
}
//
return m + 1
}

// Accept determines whether a given value is an element of this type. For
// example, 123 is an element of the type u8 whilst 256 is not.
func (p *UintType) Accept(val *fr.Element) bool {
Expand Down Expand Up @@ -104,6 +120,12 @@ func (p *FieldType) AsField() *FieldType {
return p
}

// ByteWidth returns the number of bytes required represent any element of this
// type.
func (p *FieldType) ByteWidth() uint {
return 32
}

// Accept determines whether a given value is an element of this type. In
// fact, all field elements are members of this type.
func (p *FieldType) Accept(val *fr.Element) bool {
Expand Down
Loading

0 comments on commit 3eddf5b

Please sign in to comment.