diff --git a/arrowexec/execution/expression.go b/arrowexec/execution/expression.go
index 58f55484..404052d7 100644
--- a/arrowexec/execution/expression.go
+++ b/arrowexec/execution/expression.go
@@ -18,6 +18,7 @@ type RecordVariable struct {
 }
 
 func (r *RecordVariable) Evaluate(ctx Context, record Record) (arrow.Array, error) {
+	// TODO: Retain array?
 	return record.Column(r.index), nil
 }
 
diff --git a/arrowexec/nodes/filter_test.go b/arrowexec/nodes/filter_test.go
index 1fa6dfcf..ee9fcdbf 100644
--- a/arrowexec/nodes/filter_test.go
+++ b/arrowexec/nodes/filter_test.go
@@ -68,9 +68,9 @@ func BenchmarkNaiveFilter(b *testing.B) {
 		Node: &GroupBy{
 			OutSchema: schema,
 			Source:    node,
-			KeyExprs:              []int{0},
+			KeyColumns:            []int{0},
 			AggregateConstructors: []func(dt arrow.DataType) Aggregate{MakeCount},
-			AggregateExprs:        []int{1},
+			AggregateColumns:      []int{1},
 		},
 		Schema: schema,
 	}
@@ -132,9 +132,9 @@ func BenchmarkRebatchingFilter(b *testing.B) {
 		Node: &GroupBy{
 			OutSchema: schema,
 			Source:    node,
-			KeyExprs:              []int{0},
+			KeyColumns:            []int{0},
 			AggregateConstructors: []func(dt arrow.DataType) Aggregate{MakeCount},
-			AggregateExprs:        []int{1},
+			AggregateColumns:      []int{1},
 		},
 		Schema: schema,
 	}
diff --git a/arrowexec/nodes/group_by.go b/arrowexec/nodes/group_by.go
index a1995e73..5391dca0 100644
--- a/arrowexec/nodes/group_by.go
+++ b/arrowexec/nodes/group_by.go
@@ -16,9 +16,11 @@ type GroupBy struct {
 	OutSchema *arrow.Schema
 	Source    execution.NodeWithMeta
 
-	KeyExprs              []int // For now, let's just use indices here.
+	// Both keys and aggregate columns have to be calculated by a preceding map.
+
+	KeyColumns            []int
 	AggregateConstructors []func(dt arrow.DataType) Aggregate
-	AggregateExprs        []int // For now, let's just use indices here.
+	AggregateColumns      []int
 }
 
 func (g *GroupBy) Run(ctx execution.Context, produce execution.ProduceFunc) error {
@@ -26,27 +28,27 @@ func (g *GroupBy) Run(ctx execution.Context, produce execution.ProduceFunc) erro
 	entryIndices := intintmap.New(16, 0.6)
 	aggregates := make([]Aggregate, len(g.AggregateConstructors))
 	for i := range aggregates {
-		aggregates[i] = g.AggregateConstructors[i](g.Source.Schema.Field(g.AggregateExprs[i]).Type)
+		aggregates[i] = g.AggregateConstructors[i](g.Source.Schema.Field(g.AggregateColumns[i]).Type)
 	}
-	key := make([]Key, len(g.KeyExprs))
+	key := make([]Key, len(g.KeyColumns))
 	for i := range key {
-		key[i] = MakeKey(g.Source.Schema.Field(g.KeyExprs[i]).Type)
+		key[i] = MakeKey(g.Source.Schema.Field(g.KeyColumns[i]).Type)
 	}
 
 	if err := g.Source.Node.Run(ctx, func(ctx execution.ProduceContext, record execution.Record) error {
-		getKeyHash := MakeKeyHasher(g.Source.Schema, record, g.KeyExprs)
+		getKeyHash := MakeKeyHasher(g.Source.Schema, record, g.KeyColumns)
 		aggColumnConsumers := make([]func(entryIndex uint, rowIndex uint), len(aggregates))
 		for i := range aggColumnConsumers {
-			aggColumnConsumers[i] = aggregates[i].MakeColumnConsumer(record.Column(g.AggregateExprs[i]))
+			aggColumnConsumers[i] = aggregates[i].MakeColumnConsumer(record.Column(g.AggregateColumns[i]))
 		}
 		newKeyAdders := make([]func(rowIndex uint), len(key))
 		for i := range newKeyAdders {
-			newKeyAdders[i] = key[i].MakeNewKeyAdder(record.Column(g.KeyExprs[i]))
+			newKeyAdders[i] = key[i].MakeNewKeyAdder(record.Column(g.KeyColumns[i]))
 		}
 		keyEqualityCheckers := make([]func(entryIndex uint, rowIndex uint) bool, len(key))
 		for i := range keyEqualityCheckers {
-			keyEqualityCheckers[i] = key[i].MakeKeyEqualityChecker(record.Column(g.KeyExprs[i]))
+			keyEqualityCheckers[i] = key[i].MakeKeyEqualityChecker(record.Column(g.KeyColumns[i]))
 		}
 
 		rows := record.NumRows()
diff --git a/arrowexec/nodes/group_by_test.go b/arrowexec/nodes/group_by_test.go
index 8a3d1cd5..1138701d 100644
--- a/arrowexec/nodes/group_by_test.go
+++ b/arrowexec/nodes/group_by_test.go
@@ -13,6 +13,7 @@ import (
 )
 
 func TestGroupBy(t *testing.T) {
+	ctx := context.Background()
 	allocator := memory.NewGoAllocator()
 
 	schema := arrow.NewSchema(
@@ -56,15 +57,15 @@ func TestGroupBy(t *testing.T) {
 			),
 			Source: node,
 
-			KeyExprs:       []int{0},
-			AggregateExprs: []int{1},
+			KeyColumns:       []int{0},
+			AggregateColumns: []int{1},
 			AggregateConstructors: []func(dt arrow.DataType) Aggregate{
 				MakeSum,
 			},
 		},
 	}
 
-	if err := node.Node.Run(context.Background(), func(ctx execution.ProduceContext, record execution.Record) error {
+	if err := node.Node.Run(execution.Context{Context: ctx}, func(ctx execution.ProduceContext, record execution.Record) error {
 		log.Println(record)
 		return nil
 	}); err != nil {
@@ -92,13 +93,13 @@ func BenchmarkGroupBy(b *testing.B) {
 		aBuilder := array.NewInt64Builder(allocator)
 		bBuilder := array.NewInt64Builder(allocator)
 
-		for i := 0; i < execution.BatchSize; i++ {
-			aBuilder.Append(int64((arrayIndex*execution.BatchSize + i) % groups))
-			bBuilder.Append(int64(arrayIndex*execution.BatchSize + i))
+		for i := 0; i < execution.IdealBatchSize; i++ {
+			aBuilder.Append(int64((arrayIndex*execution.IdealBatchSize + i) % groups))
+			bBuilder.Append(int64(arrayIndex*execution.IdealBatchSize + i))
 		}
 
 		records = append(records, execution.Record{
-			Record: array.NewRecord(schema, []arrow.Array{aBuilder.NewArray(), bBuilder.NewArray()}, execution.BatchSize),
+			Record: array.NewRecord(schema, []arrow.Array{aBuilder.NewArray(), bBuilder.NewArray()}, execution.IdealBatchSize),
 		})
 	}
 
@@ -124,8 +125,8 @@
 			),
 			Source: node,
 
-			KeyExprs:       []int{0},
-			AggregateExprs: []int{1},
+			KeyColumns:       []int{0},
+			AggregateColumns: []int{1},
 			AggregateConstructors: []func(dt arrow.DataType) Aggregate{
 				MakeSum,
 			},
@@ -133,7 +134,7 @@
 	}
 	b.StartTimer()
 
-	if err := node.Node.Run(ctx, func(ctx execution.ProduceContext, record execution.Record) error {
+	if err := node.Node.Run(execution.Context{Context: ctx}, func(ctx execution.ProduceContext, record execution.Record) error {
 		outArrays = append(outArrays, record.Record.Columns()[0])
 		return nil
 	}); err != nil {
@@ -161,7 +162,7 @@ func BenchmarkGroupByString(b *testing.B) {
 	for arrayIndex := 0; arrayIndex < rounds; arrayIndex++ {
 		aBuilder := array.NewStringBuilder(allocator)
 
-		for i := 0; i < execution.BatchSize; i++ {
+		for i := 0; i < execution.IdealBatchSize; i++ {
 			switch rand.Intn(4) {
 			case 0:
 				aBuilder.Append("aaa")
@@ -175,7 +176,7 @@
 		}
 
 		records = append(records, execution.Record{
-			Record: array.NewRecord(schema, []arrow.Array{aBuilder.NewArray()}, execution.BatchSize),
+			Record: array.NewRecord(schema, []arrow.Array{aBuilder.NewArray()}, execution.IdealBatchSize),
 		})
 	}
 
@@ -201,8 +202,8 @@
 			),
 			Source: node,
 
-			KeyExprs:       []int{0},
-			AggregateExprs: []int{0},
+			KeyColumns:       []int{0},
+			AggregateColumns: []int{0},
 			AggregateConstructors: []func(dt arrow.DataType) Aggregate{
 				MakeCount,
 			},
@@ -210,7 +211,7 @@
 	}
 	b.StartTimer()
 
-	if err := node.Node.Run(ctx, func(ctx execution.ProduceContext, record execution.Record) error {
+	if err := node.Node.Run(execution.Context{Context: ctx}, func(ctx execution.ProduceContext, record execution.Record) error {
 		outArrays = append(outArrays, record.Record.Columns()[0])
 		return nil
 	}); err != nil {
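
Note on usage: after this rename, GroupBy takes plain column indices into the source schema (per the new comment in group_by.go, any key or aggregate expressions are expected to be materialized by a preceding map node), and Run now takes an execution.Context wrapping a plain context.Context. A minimal wiring sketch mirroring the tests above; the outSchema and source values are assumed to be built as in TestGroupBy and are not part of this diff:

	groupBy := execution.NodeWithMeta{
		Node: &GroupBy{
			OutSchema: outSchema,
			Source:    source,

			KeyColumns:            []int{0},                                     // group by column 0 of the source schema
			AggregateConstructors: []func(dt arrow.DataType) Aggregate{MakeSum}, // one constructor per aggregate column
			AggregateColumns:      []int{1},                                     // sum column 1
		},
		Schema: outSchema,
	}

	// Run takes an execution.Context wrapping a context.Context,
	// rather than a bare context.Context as before.
	if err := groupBy.Node.Run(execution.Context{Context: context.Background()}, func(ctx execution.ProduceContext, record execution.Record) error {
		log.Println(record) // each produced record is one batch of grouped output
		return nil
	}); err != nil {
		log.Fatal(err)
	}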
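On the "// TODO: Retain array?" comment in expression.go: arrays in the apache/arrow Go library are reference-counted via the standard Retain/Release methods, so if an evaluated column may outlive the record it came from, the expression would Retain it and leave the matching Release to the caller. A sketch of that retaining variant, hypothetical and not part of this diff, using the Context and Record types of this package:

	func (r *RecordVariable) Evaluate(ctx Context, record Record) (arrow.Array, error) {
		arr := record.Column(r.index)
		// Retain bumps the reference count so arr stays valid even after the
		// producing record is released; the caller then owns one reference and
		// must call arr.Release() when done.
		arr.Retain()
		return arr, nil
	}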