Commit: Continued.

cube2222 committed Oct 2, 2023
1 parent 81c2fe6 commit a62cfbc
Showing 5 changed files with 138 additions and 59 deletions.
25 changes: 13 additions & 12 deletions arrowexec/helpers/equality_checker.go
@@ -1,37 +1,38 @@
package helpers

import (
"fmt"

"github.com/apache/arrow/go/v13/arrow"
"github.com/apache/arrow/go/v13/arrow/array"
"github.com/cube2222/octosql/arrowexec/execution"
)

-func MakeKeyEqualityChecker(leftRecord, rightRecord execution.Record, leftKeyIndices, rightKeyIndices []int) func(leftRowIndex, rightRowIndex int) bool {
-if len(leftKeyIndices) != len(rightKeyIndices) {
-panic("key column count mismatch in equality checker")
+func MakeRowEqualityChecker(leftKeys, rightKeys []arrow.Array) func(leftRowIndex, rightRowIndex int) bool {
+if len(leftKeys) != len(rightKeys) {
+panic(fmt.Errorf("key column count mismatch in equality checker: %d != %d", len(leftKeys), len(rightKeys)))
}
-keyColumnCount := len(leftKeyIndices)
+keyColumnCount := len(leftKeys)

columnEqualityCheckers := make([]func(leftRowIndex, rightRowIndex int) bool, keyColumnCount)
for i := 0; i < keyColumnCount; i++ {
-switch leftRecord.Column(leftKeyIndices[i]).DataType().ID() {
+switch leftKeys[i].DataType().ID() {
case arrow.INT64:
// TODO: Handle nulls.
-leftTypedArr := leftRecord.Column(leftKeyIndices[i]).(*array.Int64).Int64Values()
-rightTypedArr := rightRecord.Column(rightKeyIndices[i]).(*array.Int64).Int64Values()
+leftTypedArr := leftKeys[i].(*array.Int64).Int64Values()
+rightTypedArr := rightKeys[i].(*array.Int64).Int64Values()
columnEqualityCheckers[i] = func(leftRowIndex, rightRowIndex int) bool {
return leftTypedArr[leftRowIndex] == rightTypedArr[rightRowIndex]
}
case arrow.FLOAT64:
-leftTypedArr := leftRecord.Column(leftKeyIndices[i]).(*array.Float64).Float64Values()
-rightTypedArr := rightRecord.Column(rightKeyIndices[i]).(*array.Float64).Float64Values()
+leftTypedArr := leftKeys[i].(*array.Float64).Float64Values()
+rightTypedArr := rightKeys[i].(*array.Float64).Float64Values()
columnEqualityCheckers[i] = func(leftRowIndex, rightRowIndex int) bool {
return leftTypedArr[leftRowIndex] == rightTypedArr[rightRowIndex]
}
case arrow.STRING:
// TODO: Move to large string array.
-leftTypedArr := leftRecord.Column(leftKeyIndices[i]).(*array.String)
-rightTypedArr := rightRecord.Column(rightKeyIndices[i]).(*array.String)
+leftTypedArr := leftKeys[i].(*array.String)
+rightTypedArr := rightKeys[i].(*array.String)
columnEqualityCheckers[i] = func(leftRowIndex, rightRowIndex int) bool {
return leftTypedArr.Value(leftRowIndex) == rightTypedArr.Value(rightRowIndex)
}
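
For illustration, a minimal usage sketch of the new signature (not part of the commit; it assumes the function's truncated tail combines the per-column checkers with a logical AND):

package main

import (
	"fmt"

	"github.com/apache/arrow/go/v13/arrow"
	"github.com/apache/arrow/go/v13/arrow/array"
	"github.com/apache/arrow/go/v13/arrow/memory"

	"github.com/cube2222/octosql/arrowexec/helpers"
)

func main() {
	mem := memory.NewGoAllocator()

	lb := array.NewInt64Builder(mem)
	lb.AppendValues([]int64{1, 2, 3}, nil)
	left := lb.NewInt64Array()

	rb := array.NewInt64Builder(mem)
	rb.AppendValues([]int64{3, 2, 1}, nil)
	right := rb.NewInt64Array()

	// One key column per side; the checker compares rows by position.
	eq := helpers.MakeRowEqualityChecker([]arrow.Array{left}, []arrow.Array{right})
	fmt.Println(eq(1, 1)) // true: 2 == 2
	fmt.Println(eq(0, 0)) // false: 1 != 3
}
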
22 changes: 11 additions & 11 deletions arrowexec/helpers/rewriter.go
@@ -11,19 +11,19 @@ func MakeColumnRewriter(builder array.Builder, arr arrow.Array) func(rowIndex in
// TODO: Should this operate on row ranges instead of single rows? Would make low-selectivity workloads faster, as well as nested types.
switch builder.Type().ID() {
case arrow.INT64:
-typedBuilder := builder.(*array.Int64Builder)
-typedArr := arr.(*array.Int64)
-return func(rowIndex int) {
-typedBuilder.Append(typedArr.Value(rowIndex))
-}
+return rewriterForType[int64](builder.(*array.Int64Builder), arr.(*array.Int64))
case arrow.FLOAT64:
-typedBuilder := builder.(*array.Float64Builder)
-typedArr := arr.(*array.Float64)
-return func(rowIndex int) {
-typedBuilder.Append(typedArr.Value(rowIndex))
-}
+return rewriterForType[float64](builder.(*array.Float64Builder), arr.(*array.Float64))
+case arrow.STRING:
+return rewriterForType[string](builder.(*array.StringBuilder), arr.(*array.String))
// TODO: Add more types.
default:
panic(fmt.Errorf("unsupported type for filtering: %v", builder.Type().ID()))
panic(fmt.Errorf("unsupported type for rewriting: %v", builder.Type().ID()))
}
}

+func rewriterForType[T any, BuilderType interface{ Append(v T) }, ArrayType interface{ Value(i int) T }](builder BuilderType, arr ArrayType) func(rowIndex int) {
+return func(rowIndex int) {
+builder.Append(arr.Value(rowIndex))
+}
+}
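
The per-type closures collapse into one generic helper. A standalone sketch of the pattern (illustrative, not part of the commit): any builder exposing Append(T) and any array exposing Value(int) T satisfy the structural constraints, and only T needs to be spelled out at the call site.

package main

import (
	"fmt"

	"github.com/apache/arrow/go/v13/arrow/array"
	"github.com/apache/arrow/go/v13/arrow/memory"
)

func rewriterForType[T any, BuilderType interface{ Append(v T) }, ArrayType interface{ Value(i int) T }](builder BuilderType, arr ArrayType) func(rowIndex int) {
	return func(rowIndex int) {
		builder.Append(arr.Value(rowIndex))
	}
}

func main() {
	mem := memory.NewGoAllocator()

	src := array.NewInt64Builder(mem)
	src.AppendValues([]int64{10, 20, 30}, nil)
	arr := src.NewInt64Array()

	out := array.NewInt64Builder(mem)
	// BuilderType and ArrayType are inferred from the arguments.
	rewrite := rewriterForType[int64](out, arr)
	rewrite(2)
	rewrite(0)
	fmt.Println(out.NewInt64Array()) // [30 10]
}
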
71 changes: 46 additions & 25 deletions arrowexec/nodes/hashtable/join_hashtable.go
@@ -9,45 +9,39 @@ import (
"github.com/apache/arrow/go/v13/arrow/memory"
"github.com/brentp/intintmap"
"github.com/cube2222/octosql/arrowexec/execution"
helpers2 "github.com/cube2222/octosql/arrowexec/helpers"
"github.com/cube2222/octosql/arrowexec/helpers"
"github.com/twotwotwo/sorts"
"golang.org/x/sync/errgroup"
)

type JoinTable struct {
partitions []JoinTablePartition

-keyIndices, joinedKeyIndices []int
-tableIsLeftSide bool
+tableIsLeftSide bool
}

type JoinTablePartition struct {
hashStartIndices *intintmap.Map
hashes *array.Uint64
+keys []arrow.Array
values execution.Record
}

-func BuildJoinTable(records []execution.Record, keyIndices, joinedKeyIndices []int, tableIsLeftSide bool) *JoinTable {
-if len(keyIndices) != len(joinedKeyIndices) {
-panic("table key and joined key indices don't have the same length")
-}
-
-partitions := buildJoinTablePartitions(records, keyIndices)
+func BuildJoinTable(records []execution.Record, tableKeyColumns [][]arrow.Array, tableIsLeftSide bool) *JoinTable {
+partitions := buildJoinTablePartitions(records, tableKeyColumns)
return &JoinTable{
-partitions: partitions,
-keyIndices: keyIndices,
-joinedKeyIndices: joinedKeyIndices,
-tableIsLeftSide: tableIsLeftSide,
+partitions: partitions,
+tableIsLeftSide: tableIsLeftSide,
}
}

-func buildJoinTablePartitions(records []execution.Record, keyIndices []int) []JoinTablePartition {
+func buildJoinTablePartitions(records []execution.Record, keyColumns [][]arrow.Array) []JoinTablePartition {
partitions := 7 // TODO: Make it the first prime number larger than the core count.

// TODO: Handle case where there are 0 records.
keyHashers := make([]func(rowIndex uint) uint64, len(records))
-for i, record := range records {
-keyHashers[i] = helpers2.MakeRecordKeyHasher(record, keyIndices)
+for i := range records {
+keyHashers[i] = helpers.MakeRowHasher(keyColumns[i])
}

var overallRowCount int
@@ -85,11 +79,12 @@ func buildJoinTablePartitions(records []execution.Record, keyIndices []int) []Jo

hashIndex := buildHashIndex(hashPositionsOrderedPartition)
hashesArray := buildHashesArray(hashPositionsOrderedPartition)
-record := buildRecords(records, hashPositionsOrderedPartition)
+record, keys := buildRecords(records, keyColumns, hashPositionsOrderedPartition)

joinTablePartitions[part] = JoinTablePartition{
hashStartIndices: hashIndex,
hashes: hashesArray,
+keys: keys,
values: execution.Record{Record: record},
}
wg.Done()
@@ -129,10 +124,16 @@ func buildHashesArray(hashPositionsOrdered []hashRowPosition) *array.Uint64 {
return hashesBuilder.NewUint64Array()
}

-func buildRecords(records []execution.Record, hashPositionsOrdered []hashRowPosition) arrow.Record {
+func buildRecords(records []execution.Record, keyColumns [][]arrow.Array, hashPositionsOrdered []hashRowPosition) (arrow.Record, []arrow.Array) {
// TODO: Get allocator from argument.
+// TODO: Fix when 0 records given.
recordBuilder := array.NewRecordBuilder(memory.NewGoAllocator(), records[0].Schema())
recordBuilder.Reserve(len(hashPositionsOrdered))
+keyColumnsBuilder := make([]array.Builder, len(keyColumns[0]))
+for recordIndex := range keyColumnsBuilder {
+keyColumnsBuilder[recordIndex] = array.NewBuilder(memory.NewGoAllocator(), keyColumns[recordIndex][0].DataType())
+keyColumnsBuilder[recordIndex].Reserve(len(hashPositionsOrdered))
+}

var g errgroup.Group
g.SetLimit(runtime.GOMAXPROCS(0))
@@ -141,7 +142,21 @@ func buildRecords(records []execution.Record, hashPositionsOrdered []hashRowPosi
for columnIndex := 0; columnIndex < columnCount; columnIndex++ {
columnRewriters := make([]func(rowIndex int), len(records))
for recordIndex, record := range records {
-columnRewriters[recordIndex] = helpers2.MakeColumnRewriter(recordBuilder.Field(columnIndex), record.Column(columnIndex))
+columnRewriters[recordIndex] = helpers.MakeColumnRewriter(recordBuilder.Field(columnIndex), record.Column(columnIndex))
}

g.Go(func() error {
for _, hashPosition := range hashPositionsOrdered {
columnRewriters[hashPosition.recordIndex](hashPosition.rowIndex)
}
return nil
})
}
+// TODO: Fix when 0 records given.
+for columnIndex := 0; columnIndex < len(keyColumns[0]); columnIndex++ {
+columnRewriters := make([]func(rowIndex int), len(records))
+for recordIndex := range records {
+columnRewriters[recordIndex] = helpers.MakeColumnRewriter(keyColumnsBuilder[columnIndex], keyColumns[recordIndex][columnIndex])
+}
+
+g.Go(func() error {
@@ -152,8 +167,14 @@ func buildRecords(records []execution.Record, hashPositionsOrdered []hashRowPosi
})
}
g.Wait()

record := recordBuilder.NewRecord()
-return record
+keyColumnArrays := make([]arrow.Array, len(keyColumnsBuilder))
+for i := range keyColumnArrays {
+keyColumnArrays[i] = keyColumnsBuilder[i].NewArray()
+}
+
+return record, keyColumnArrays
}

type SortHashPosition []hashRowPosition
@@ -174,7 +195,7 @@ func (h SortHashPosition) Key(i int) uint64 {
return h[i].hash
}

-func (t *JoinTable) JoinWithRecord(record execution.Record, produce func(execution.Record)) {
+func (t *JoinTable) JoinWithRecord(record execution.Record, keys []arrow.Array, produce func(execution.Record)) {
var outFields []arrow.Field
if t.tableIsLeftSide {
outFields = append(outFields, t.partitions[0].values.Schema().Fields()...)
@@ -185,11 +206,11 @@ func (t *JoinTable) JoinWithRecord(record execution.Record, produce func(executi
}
outSchema := arrow.NewSchema(outFields, nil)

-recordKeyHasher := helpers2.MakeRecordKeyHasher(record, t.joinedKeyIndices)
+recordKeyHasher := helpers.MakeRowHasher(keys)

partitionKeyEqualityCheckers := make([]func(joinedRowIndex int, tableRowIndex int) bool, len(t.partitions))
for partitionIndex := range t.partitions {
-partitionKeyEqualityCheckers[partitionIndex] = helpers2.MakeKeyEqualityChecker(record, t.partitions[partitionIndex].values, t.keyIndices, t.joinedKeyIndices)
+partitionKeyEqualityCheckers[partitionIndex] = helpers.MakeRowEqualityChecker(keys, t.partitions[partitionIndex].keys)
}

recordBuilder := array.NewRecordBuilder(memory.NewGoAllocator(), outSchema)
@@ -247,12 +268,12 @@ func (t *JoinTable) makeRecordRewriterForPartition(joinedRecord execution.Record

joinedRecordColumnRewriters := make([]func(rowIndex int), len(joinedRecord.Columns()))
for columnIndex := range joinedRecord.Columns() {
-joinedRecordColumnRewriters[columnIndex] = helpers2.MakeColumnRewriter(recordBuilder.Field(joinedRecordColumnOffset+columnIndex), joinedRecord.Column(columnIndex))
+joinedRecordColumnRewriters[columnIndex] = helpers.MakeColumnRewriter(recordBuilder.Field(joinedRecordColumnOffset+columnIndex), joinedRecord.Column(columnIndex))
}

tableColumnRewriters := make([]func(rowIndex int), len(partition.values.Columns()))
for columnIndex := range partition.values.Columns() {
-tableColumnRewriters[columnIndex] = helpers2.MakeColumnRewriter(recordBuilder.Field(tableColumnOffset+columnIndex), partition.values.Column(columnIndex))
+tableColumnRewriters[columnIndex] = helpers.MakeColumnRewriter(recordBuilder.Field(tableColumnOffset+columnIndex), partition.values.Column(columnIndex))
}

return func(joinedRowIndex int, tableRowIndex int) {
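
For readers following the hash-table layout: build-side rows are hashed, bucketed by hash modulo the partition count, and each partition is sorted by hash so equal hashes form contiguous runs that hashStartIndices can point into. A self-contained sketch of that scheme (illustrative names; the commit uses the twotwotwo/sorts radix sort and an intintmap rather than sort.Slice and a plain map):

package main

import (
	"fmt"
	"sort"
)

type hashRowPosition struct {
	hash        uint64
	recordIndex int
	rowIndex    int
}

// partitionByHash buckets rows by hash % partitionCount and sorts each
// bucket by hash, so equal hashes end up adjacent within a partition.
func partitionByHash(rows []hashRowPosition, partitionCount int) [][]hashRowPosition {
	parts := make([][]hashRowPosition, partitionCount)
	for _, r := range rows {
		p := int(r.hash % uint64(partitionCount))
		parts[p] = append(parts[p], r)
	}
	for _, part := range parts {
		sort.Slice(part, func(i, j int) bool { return part[i].hash < part[j].hash })
	}
	return parts
}

// buildHashIndex records the first position of every hash run, standing
// in for the intintmap-based hashStartIndices.
func buildHashIndex(part []hashRowPosition) map[uint64]int {
	index := make(map[uint64]int, len(part))
	for i, r := range part {
		if _, ok := index[r.hash]; !ok {
			index[r.hash] = i
		}
	}
	return index
}

func main() {
	rows := []hashRowPosition{{hash: 9}, {hash: 4}, {hash: 18}, {hash: 4}}
	parts := partitionByHash(rows, 7)
	for i, part := range parts {
		if len(part) > 0 {
			fmt.Println("partition", i, "start indices:", buildHashIndex(part))
		}
	}
}
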
45 changes: 34 additions & 11 deletions arrowexec/nodes/join.go
@@ -1,15 +1,17 @@
package nodes

import (
"fmt"
"runtime"

"github.com/apache/arrow/go/v13/arrow"
"github.com/cube2222/octosql/arrowexec/execution"
"github.com/cube2222/octosql/arrowexec/nodes/hashtable"
)

type StreamJoin struct {
-Left, Right *execution.NodeWithMeta
-LeftKeyIndices, RightKeyIndices []int
+Left, Right *execution.NodeWithMeta
+LeftKey, RightKey []execution.Expression
}

func (s *StreamJoin) Run(ctx execution.Context, produce execution.ProduceFunc) error {
@@ -23,7 +25,7 @@ func (s *StreamJoin) Run(ctx execution.Context, produce execution.ProduceFunc) e
leftRecordChannel <- record
return nil
}); err != nil {
panic("implement me")
panic(fmt.Errorf("implement error handling: %w", err))
}
}()

@@ -33,7 +35,7 @@ func (s *StreamJoin) Run(ctx execution.Context, produce execution.ProduceFunc) e
rightRecordChannel <- record
return nil
}); err != nil {
panic("implement me")
panic(fmt.Errorf("implement error handling: %w", err))
}
}()

@@ -56,26 +58,39 @@ receiveLoop:
}
}
var tableRecords, joinedRecords []execution.Record
-var tableKeyIndices, joinedKeyIndices []int
+var tableKeyExprs, joinedKeyExprs []execution.Expression
var joinedRecordChannel chan execution.Record
var tableIsLeft bool
if leftClosed {
tableRecords = leftRecords
joinedRecords = rightRecords
-tableKeyIndices = s.LeftKeyIndices
-joinedKeyIndices = s.RightKeyIndices
+tableKeyExprs = s.LeftKey
+joinedKeyExprs = s.RightKey
joinedRecordChannel = rightRecordChannel
tableIsLeft = true
} else {
tableRecords = rightRecords
joinedRecords = leftRecords
-tableKeyIndices = s.RightKeyIndices
-joinedKeyIndices = s.LeftKeyIndices
+tableKeyExprs = s.RightKey
+joinedKeyExprs = s.LeftKey
joinedRecordChannel = leftRecordChannel
tableIsLeft = false
}

-table := hashtable.BuildJoinTable(tableRecords, tableKeyIndices, joinedKeyIndices, tableIsLeft)
+tableKeyColumns := make([][]arrow.Array, len(tableRecords))
+// TODO: Parallelize.
+for recordIndex, record := range tableRecords {
+tableKeyColumns[recordIndex] = make([]arrow.Array, len(tableKeyExprs))
+for i, expr := range tableKeyExprs {
+exprValue, err := expr.Evaluate(ctx, record)
+if err != nil {
+return fmt.Errorf("couldn't evaluate key expression: %w", err)
+}
+tableKeyColumns[recordIndex][i] = exprValue
+}
+}
+
+table := hashtable.BuildJoinTable(tableRecords, tableKeyColumns, tableIsLeft)

outputRecordChannelChan := make(chan (<-chan execution.Record), runtime.GOMAXPROCS(0))
go func() {
@@ -85,8 +100,16 @@ receiveLoop:
outputRecordChannelChan <- outputRecordChannel

go func() {
+keys := make([]arrow.Array, len(joinedKeyExprs))
+for i, expr := range joinedKeyExprs {
+exprValue, err := expr.Evaluate(ctx, joinedRecord)
+if err != nil {
+panic("implement me")
+}
+keys[i] = exprValue
+}
defer close(outputRecordChannel)
-table.JoinWithRecord(joinedRecord, func(record execution.Record) {
+table.JoinWithRecord(joinedRecord, keys, func(record execution.Record) {
outputRecordChannel <- record
})
}()
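
The receive loop itself is collapsed above; the visible branches only make sense given its contract: buffer records from both inputs until one side finishes, then make the finished side the hash-table (build) side. A minimal sketch of that contract, under the assumption that this is what receiveLoop does (types simplified to keep it self-contained):

package main

import "fmt"

// receiveUntilOneCloses drains both channels until one of them closes.
// The closed side's buffered records become the build side; the other
// side keeps streaming and probes the table.
func receiveUntilOneCloses(left, right <-chan int) (leftRecs, rightRecs []int, leftClosed bool) {
	for {
		select {
		case rec, ok := <-left:
			if !ok {
				return leftRecs, rightRecs, true
			}
			leftRecs = append(leftRecs, rec)
		case rec, ok := <-right:
			if !ok {
				return leftRecs, rightRecs, false
			}
			rightRecs = append(rightRecs, rec)
		}
	}
}

func main() {
	left := make(chan int, 3)
	right := make(chan int, 3)
	left <- 1
	left <- 2
	close(left) // left finishes first, so it becomes the build side
	right <- 10

	l, _, leftClosed := receiveUntilOneCloses(left, right)
	fmt.Println(l, leftClosed) // [1 2] true
}
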
34 changes: 34 additions & 0 deletions physical/nodes.go
@@ -419,6 +419,40 @@ func (node *Node) Materialize(ctx context.Context, env Environment) (*execution.
},
Schema: OctoSQLToArrowSchema(node.Schema),
}, nil
+case NodeTypeStreamJoin:
+left, err := node.StreamJoin.Left.Materialize(ctx, env)
+if err != nil {
+return nil, fmt.Errorf("couldn't materialize stream join left source: %w", err)
+}
+right, err := node.StreamJoin.Right.Materialize(ctx, env)
+if err != nil {
+return nil, fmt.Errorf("couldn't materialize stream join right source: %w", err)
+}
+leftKey := make([]execution.Expression, len(node.StreamJoin.LeftKey))
+for i := range node.StreamJoin.LeftKey {
+expr, err := node.StreamJoin.LeftKey[i].Materialize(ctx, env, OctoSQLToArrowSchema(node.StreamJoin.Left.Schema))
+if err != nil {
+return nil, fmt.Errorf("couldn't materialize stream join left key expression with index %d: %w", i, err)
+}
+leftKey[i] = expr
+}
+rightKey := make([]execution.Expression, len(node.StreamJoin.RightKey))
+for i := range node.StreamJoin.RightKey {
+expr, err := node.StreamJoin.RightKey[i].Materialize(ctx, env, OctoSQLToArrowSchema(node.StreamJoin.Right.Schema))
+if err != nil {
+return nil, fmt.Errorf("couldn't materialize stream join right key expression with index %d: %w", i, err)
+}
+rightKey[i] = expr
+}
+return &execution.NodeWithMeta{
+Node: &nodes.StreamJoin{
+Left: left,
+Right: right,
+LeftKey: leftKey,
+RightKey: rightKey,
+},
+Schema: OctoSQLToArrowSchema(node.Schema),
+}, nil
default:
panic(fmt.Sprintf("invalid node type: %s", node.NodeType))
}
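
Both join.go and this materialization revolve around key expressions that evaluate to whole Arrow columns. A sketch of the contract these call sites assume (the real interface lives in arrowexec/execution and may differ in detail, e.g. it takes an execution.Context and execution.Record rather than the stand-ins here):

package sketch

import (
	"context"

	"github.com/apache/arrow/go/v13/arrow"
)

// Expression is a hypothetical stand-in for execution.Expression as the
// call sites above use it: evaluate over a whole record, get one column.
type Expression interface {
	Evaluate(ctx context.Context, record arrow.Record) (arrow.Array, error)
}

// ColumnRef is the simplest possible key expression: it returns an
// existing column unchanged. Real expressions may compute derived keys.
type ColumnRef struct{ Index int }

func (c ColumnRef) Evaluate(ctx context.Context, record arrow.Record) (arrow.Array, error) {
	return record.Column(c.Index), nil
}

var _ Expression = ColumnRef{}
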
