diff --git a/arrowexec/helpers/equality_checker.go b/arrowexec/helpers/equality_checker.go
index 3af14644..7147f9d6 100644
--- a/arrowexec/helpers/equality_checker.go
+++ b/arrowexec/helpers/equality_checker.go
@@ -1,37 +1,38 @@
 package helpers
 
 import (
+	"fmt"
+
 	"github.com/apache/arrow/go/v13/arrow"
 	"github.com/apache/arrow/go/v13/arrow/array"
-	"github.com/cube2222/octosql/arrowexec/execution"
 )
 
-func MakeKeyEqualityChecker(leftRecord, rightRecord execution.Record, leftKeyIndices, rightKeyIndices []int) func(leftRowIndex, rightRowIndex int) bool {
-	if len(leftKeyIndices) != len(rightKeyIndices) {
-		panic("key column count mismatch in equality checker")
+func MakeRowEqualityChecker(leftKeys, rightKeys []arrow.Array) func(leftRowIndex, rightRowIndex int) bool {
+	if len(leftKeys) != len(rightKeys) {
+		panic(fmt.Errorf("key column count mismatch in equality checker: %d != %d", len(leftKeys), len(rightKeys)))
 	}
-	keyColumnCount := len(leftKeyIndices)
+	keyColumnCount := len(leftKeys)
 
 	columnEqualityCheckers := make([]func(leftRowIndex, rightRowIndex int) bool, keyColumnCount)
 	for i := 0; i < keyColumnCount; i++ {
-		switch leftRecord.Column(leftKeyIndices[i]).DataType().ID() {
+		switch leftKeys[i].DataType().ID() {
 		case arrow.INT64:
 			// TODO: Handle nulls.
-			leftTypedArr := leftRecord.Column(leftKeyIndices[i]).(*array.Int64).Int64Values()
-			rightTypedArr := rightRecord.Column(rightKeyIndices[i]).(*array.Int64).Int64Values()
+			leftTypedArr := leftKeys[i].(*array.Int64).Int64Values()
+			rightTypedArr := rightKeys[i].(*array.Int64).Int64Values()
 			columnEqualityCheckers[i] = func(leftRowIndex, rightRowIndex int) bool {
 				return leftTypedArr[leftRowIndex] == rightTypedArr[rightRowIndex]
 			}
 		case arrow.FLOAT64:
-			leftTypedArr := leftRecord.Column(leftKeyIndices[i]).(*array.Float64).Float64Values()
-			rightTypedArr := rightRecord.Column(rightKeyIndices[i]).(*array.Float64).Float64Values()
+			leftTypedArr := leftKeys[i].(*array.Float64).Float64Values()
+			rightTypedArr := rightKeys[i].(*array.Float64).Float64Values()
			columnEqualityCheckers[i] = func(leftRowIndex, rightRowIndex int) bool {
 				return leftTypedArr[leftRowIndex] == rightTypedArr[rightRowIndex]
 			}
 		case arrow.STRING:
 			// TODO: Move to large string array.
-			leftTypedArr := leftRecord.Column(leftKeyIndices[i]).(*array.String)
-			rightTypedArr := rightRecord.Column(rightKeyIndices[i]).(*array.String)
+			leftTypedArr := leftKeys[i].(*array.String)
+			rightTypedArr := rightKeys[i].(*array.String)
 			columnEqualityCheckers[i] = func(leftRowIndex, rightRowIndex int) bool {
 				return leftTypedArr.Value(leftRowIndex) == rightTypedArr.Value(rightRowIndex)
 			}
diff --git a/arrowexec/helpers/rewriter.go b/arrowexec/helpers/rewriter.go
index 3c439f11..4663ee17 100644
--- a/arrowexec/helpers/rewriter.go
+++ b/arrowexec/helpers/rewriter.go
@@ -11,19 +11,19 @@ func MakeColumnRewriter(builder array.Builder, arr arrow.Array) func(rowIndex int) {
 	// TODO: Should this operate on row ranges instead of single rows? Would make low-selectivity workloads faster, as well as nested types.
 	switch builder.Type().ID() {
 	case arrow.INT64:
-		typedBuilder := builder.(*array.Int64Builder)
-		typedArr := arr.(*array.Int64)
-		return func(rowIndex int) {
-			typedBuilder.Append(typedArr.Value(rowIndex))
-		}
+		return rewriterForType[int64](builder.(*array.Int64Builder), arr.(*array.Int64))
 	case arrow.FLOAT64:
-		typedBuilder := builder.(*array.Float64Builder)
-		typedArr := arr.(*array.Float64)
-		return func(rowIndex int) {
-			typedBuilder.Append(typedArr.Value(rowIndex))
-		}
+		return rewriterForType[float64](builder.(*array.Float64Builder), arr.(*array.Float64))
+	case arrow.STRING:
+		return rewriterForType[string](builder.(*array.StringBuilder), arr.(*array.String))
 	// TODO: Add more types.
 	default:
-		panic(fmt.Errorf("unsupported type for filtering: %v", builder.Type().ID()))
+		panic(fmt.Errorf("unsupported type for rewriting: %v", builder.Type().ID()))
+	}
+}
+
+func rewriterForType[T any, BuilderType interface{ Append(v T) }, ArrayType interface{ Value(i int) T }](builder BuilderType, arr ArrayType) func(rowIndex int) {
+	return func(rowIndex int) {
+		builder.Append(arr.Value(rowIndex))
 	}
 }
diff --git a/arrowexec/nodes/hashtable/join_hashtable.go b/arrowexec/nodes/hashtable/join_hashtable.go
index 27112161..b85fbccf 100644
--- a/arrowexec/nodes/hashtable/join_hashtable.go
+++ b/arrowexec/nodes/hashtable/join_hashtable.go
@@ -9,7 +9,7 @@ import (
 	"github.com/apache/arrow/go/v13/arrow/memory"
 	"github.com/brentp/intintmap"
 	"github.com/cube2222/octosql/arrowexec/execution"
-	helpers2 "github.com/cube2222/octosql/arrowexec/helpers"
+	"github.com/cube2222/octosql/arrowexec/helpers"
 	"github.com/twotwotwo/sorts"
 	"golang.org/x/sync/errgroup"
 )
@@ -17,37 +17,31 @@ import (
 
 type JoinTable struct {
 	partitions []JoinTablePartition
 
-	keyIndices, joinedKeyIndices []int
-	tableIsLeftSide              bool
+	tableIsLeftSide bool
 }
 
 type JoinTablePartition struct {
 	hashStartIndices *intintmap.Map
 	hashes           *array.Uint64
+	keys             []arrow.Array
 	values           execution.Record
 }
 
-func BuildJoinTable(records []execution.Record, keyIndices, joinedKeyIndices []int, tableIsLeftSide bool) *JoinTable {
-	if len(keyIndices) != len(joinedKeyIndices) {
-		panic("table key and joined key indices don't have the same length")
-	}
-
-	partitions := buildJoinTablePartitions(records, keyIndices)
+func BuildJoinTable(records []execution.Record, tableKeyColumns [][]arrow.Array, tableIsLeftSide bool) *JoinTable {
+	partitions := buildJoinTablePartitions(records, tableKeyColumns)
 
 	return &JoinTable{
-		partitions:       partitions,
-		keyIndices:       keyIndices,
-		joinedKeyIndices: joinedKeyIndices,
-		tableIsLeftSide:  tableIsLeftSide,
+		partitions:      partitions,
+		tableIsLeftSide: tableIsLeftSide,
 	}
 }
 
-func buildJoinTablePartitions(records []execution.Record, keyIndices []int) []JoinTablePartition {
+func buildJoinTablePartitions(records []execution.Record, keyColumns [][]arrow.Array) []JoinTablePartition {
 	partitions := 7 // TODO: Make it the first prime number larger than the core count.
 	// TODO: Handle case where there are 0 records.
 	keyHashers := make([]func(rowIndex uint) uint64, len(records))
-	for i, record := range records {
-		keyHashers[i] = helpers2.MakeRecordKeyHasher(record, keyIndices)
+	for i := range records {
+		keyHashers[i] = helpers.MakeRowHasher(keyColumns[i])
 	}
 
 	var overallRowCount int
@@ -85,11 +79,12 @@ func buildJoinTablePartitions(records []execution.Record, keyIndices []int) []JoinTablePartition {
 			hashIndex := buildHashIndex(hashPositionsOrderedPartition)
 			hashesArray := buildHashesArray(hashPositionsOrderedPartition)
-			record := buildRecords(records, hashPositionsOrderedPartition)
+			record, keys := buildRecords(records, keyColumns, hashPositionsOrderedPartition)
 
 			joinTablePartitions[part] = JoinTablePartition{
 				hashStartIndices: hashIndex,
 				hashes:           hashesArray,
+				keys:             keys,
 				values:           execution.Record{Record: record},
 			}
 			wg.Done()
@@ -129,10 +124,16 @@ func buildHashesArray(hashPositionsOrdered []hashRowPosition) *array.Uint64 {
 	return hashesBuilder.NewUint64Array()
 }
 
-func buildRecords(records []execution.Record, hashPositionsOrdered []hashRowPosition) arrow.Record {
+func buildRecords(records []execution.Record, keyColumns [][]arrow.Array, hashPositionsOrdered []hashRowPosition) (arrow.Record, []arrow.Array) {
 	// TODO: Get allocator from argument.
+	// TODO: Fix when 0 records given.
 	recordBuilder := array.NewRecordBuilder(memory.NewGoAllocator(), records[0].Schema())
 	recordBuilder.Reserve(len(hashPositionsOrdered))
+	keyColumnsBuilder := make([]array.Builder, len(keyColumns[0]))
+	for columnIndex := range keyColumnsBuilder {
+		keyColumnsBuilder[columnIndex] = array.NewBuilder(memory.NewGoAllocator(), keyColumns[0][columnIndex].DataType())
+		keyColumnsBuilder[columnIndex].Reserve(len(hashPositionsOrdered))
+	}
 
 	var g errgroup.Group
 	g.SetLimit(runtime.GOMAXPROCS(0))
@@ -141,7 +142,21 @@ func buildRecords(records []execution.Record, hashPositionsOrdered []hashRowPosition) arrow.Record {
 	for columnIndex := 0; columnIndex < columnCount; columnIndex++ {
 		columnRewriters := make([]func(rowIndex int), len(records))
 		for recordIndex, record := range records {
-			columnRewriters[recordIndex] = helpers2.MakeColumnRewriter(recordBuilder.Field(columnIndex), record.Column(columnIndex))
+			columnRewriters[recordIndex] = helpers.MakeColumnRewriter(recordBuilder.Field(columnIndex), record.Column(columnIndex))
+		}
+
+		g.Go(func() error {
+			for _, hashPosition := range hashPositionsOrdered {
+				columnRewriters[hashPosition.recordIndex](hashPosition.rowIndex)
+			}
+			return nil
+		})
+	}
+	// TODO: Fix when 0 records given.
+	for columnIndex := 0; columnIndex < len(keyColumns[0]); columnIndex++ {
+		columnRewriters := make([]func(rowIndex int), len(records))
+		for recordIndex := range records {
+			columnRewriters[recordIndex] = helpers.MakeColumnRewriter(keyColumnsBuilder[columnIndex], keyColumns[recordIndex][columnIndex])
 		}
 
 		g.Go(func() error {
@@ -152,8 +167,14 @@ func buildRecords(records []execution.Record, hashPositionsOrdered []hashRowPosition) arrow.Record {
 		})
 	}
 	g.Wait()
+
 	record := recordBuilder.NewRecord()
-	return record
+	keyColumnArrays := make([]arrow.Array, len(keyColumnsBuilder))
+	for i := range keyColumnArrays {
+		keyColumnArrays[i] = keyColumnsBuilder[i].NewArray()
+	}
+
+	return record, keyColumnArrays
 }
 
 type SortHashPosition []hashRowPosition
 
@@ -174,7 +195,7 @@ func (h SortHashPosition) Key(i int) uint64 {
 	return h[i].hash
 }
 
-func (t *JoinTable) JoinWithRecord(record execution.Record, produce func(execution.Record)) {
+func (t *JoinTable) JoinWithRecord(record execution.Record, keys []arrow.Array, produce func(execution.Record)) {
 	var outFields []arrow.Field
 	if t.tableIsLeftSide {
 		outFields = append(outFields, t.partitions[0].values.Schema().Fields()...)
@@ -185,11 +206,11 @@ func (t *JoinTable) JoinWithRecord(record execution.Record, produce func(execution.Record)) {
 	}
 	outSchema := arrow.NewSchema(outFields, nil)
 
-	recordKeyHasher := helpers2.MakeRecordKeyHasher(record, t.joinedKeyIndices)
+	recordKeyHasher := helpers.MakeRowHasher(keys)
 
 	partitionKeyEqualityCheckers := make([]func(joinedRowIndex int, tableRowIndex int) bool, len(t.partitions))
 	for partitionIndex := range t.partitions {
-		partitionKeyEqualityCheckers[partitionIndex] = helpers2.MakeKeyEqualityChecker(record, t.partitions[partitionIndex].values, t.keyIndices, t.joinedKeyIndices)
+		partitionKeyEqualityCheckers[partitionIndex] = helpers.MakeRowEqualityChecker(keys, t.partitions[partitionIndex].keys)
 	}
 
 	recordBuilder := array.NewRecordBuilder(memory.NewGoAllocator(), outSchema)
@@ -247,12 +268,12 @@ func (t *JoinTable) makeRecordRewriterForPartition(joinedRecord execution.Record
 
 	joinedRecordColumnRewriters := make([]func(rowIndex int), len(joinedRecord.Columns()))
 	for columnIndex := range joinedRecord.Columns() {
-		joinedRecordColumnRewriters[columnIndex] = helpers2.MakeColumnRewriter(recordBuilder.Field(joinedRecordColumnOffset+columnIndex), joinedRecord.Column(columnIndex))
+		joinedRecordColumnRewriters[columnIndex] = helpers.MakeColumnRewriter(recordBuilder.Field(joinedRecordColumnOffset+columnIndex), joinedRecord.Column(columnIndex))
 	}
 
 	tableColumnRewriters := make([]func(rowIndex int), len(partition.values.Columns()))
 	for columnIndex := range partition.values.Columns() {
-		tableColumnRewriters[columnIndex] = helpers2.MakeColumnRewriter(recordBuilder.Field(tableColumnOffset+columnIndex), partition.values.Column(columnIndex))
+		tableColumnRewriters[columnIndex] = helpers.MakeColumnRewriter(recordBuilder.Field(tableColumnOffset+columnIndex), partition.values.Column(columnIndex))
 	}
 
 	return func(joinedRowIndex int, tableRowIndex int) {
diff --git a/arrowexec/nodes/join.go b/arrowexec/nodes/join.go
index bce88e12..9dcada05 100644
--- a/arrowexec/nodes/join.go
+++ b/arrowexec/nodes/join.go
@@ -1,15 +1,17 @@
 package nodes
 
 import (
+	"fmt"
 	"runtime"
 
+	"github.com/apache/arrow/go/v13/arrow"
 	"github.com/cube2222/octosql/arrowexec/execution"
 	"github.com/cube2222/octosql/arrowexec/nodes/hashtable"
 )
 
 type StreamJoin struct {
-	Left, Right                     *execution.NodeWithMeta
-	LeftKeyIndices, RightKeyIndices []int
+	Left, Right       *execution.NodeWithMeta
+	LeftKey, RightKey []execution.Expression
 }
 
 func (s *StreamJoin) Run(ctx execution.Context, produce execution.ProduceFunc) error {
@@ -23,7 +25,7 @@ func (s *StreamJoin) Run(ctx execution.Context, produce execution.ProduceFunc) error {
 			leftRecordChannel <- record
 			return nil
 		}); err != nil {
-			panic("implement me")
+			panic(fmt.Errorf("implement error handling: %w", err))
 		}
 	}()
 
@@ -33,7 +35,7 @@ func (s *StreamJoin) Run(ctx execution.Context, produce execution.ProduceFunc) error {
 			rightRecordChannel <- record
 			return nil
 		}); err != nil {
-			panic("implement me")
+			panic(fmt.Errorf("implement error handling: %w", err))
 		}
 	}()
 
@@ -56,26 +58,39 @@ receiveLoop:
 		}
 	}
 
 	var tableRecords, joinedRecords []execution.Record
-	var tableKeyIndices, joinedKeyIndices []int
+	var tableKeyExprs, joinedKeyExprs []execution.Expression
 	var joinedRecordChannel chan execution.Record
 	var tableIsLeft bool
 	if leftClosed {
 		tableRecords = leftRecords
 		joinedRecords = rightRecords
-		tableKeyIndices = s.LeftKeyIndices
-		joinedKeyIndices = s.RightKeyIndices
+		tableKeyExprs = s.LeftKey
+		joinedKeyExprs = s.RightKey
 		joinedRecordChannel = rightRecordChannel
 		tableIsLeft = true
 	} else {
 		tableRecords = rightRecords
 		joinedRecords = leftRecords
-		tableKeyIndices = s.RightKeyIndices
-		joinedKeyIndices = s.LeftKeyIndices
+		tableKeyExprs = s.RightKey
+		joinedKeyExprs = s.LeftKey
 		joinedRecordChannel = leftRecordChannel
 		tableIsLeft = false
 	}
 
-	table := hashtable.BuildJoinTable(tableRecords, tableKeyIndices, joinedKeyIndices, tableIsLeft)
+	tableKeyColumns := make([][]arrow.Array, len(tableRecords))
+	// TODO: Parallelize.
+	for recordIndex, record := range tableRecords {
+		tableKeyColumns[recordIndex] = make([]arrow.Array, len(tableKeyExprs))
+		for i, expr := range tableKeyExprs {
+			exprValue, err := expr.Evaluate(ctx, record)
+			if err != nil {
+				return fmt.Errorf("couldn't evaluate key expression: %w", err)
+			}
+			tableKeyColumns[recordIndex][i] = exprValue
+		}
+	}
+
+	table := hashtable.BuildJoinTable(tableRecords, tableKeyColumns, tableIsLeft)
 
 	outputRecordChannelChan := make(chan (<-chan execution.Record), runtime.GOMAXPROCS(0))
 	go func() {
@@ -85,8 +100,16 @@ receiveLoop:
 			outputRecordChannelChan <- outputRecordChannel
 
 			go func() {
+				keys := make([]arrow.Array, len(joinedKeyExprs))
+				for i, expr := range joinedKeyExprs {
+					exprValue, err := expr.Evaluate(ctx, joinedRecord)
+					if err != nil {
+						panic(fmt.Errorf("implement error handling: %w", err))
+					}
+					keys[i] = exprValue
+				}
 				defer close(outputRecordChannel)
-				table.JoinWithRecord(joinedRecord, func(record execution.Record) {
+				table.JoinWithRecord(joinedRecord, keys, func(record execution.Record) {
 					outputRecordChannel <- record
 				})
 			}()
diff --git a/physical/nodes.go b/physical/nodes.go
index 4557ca39..4ba74410 100644
--- a/physical/nodes.go
+++ b/physical/nodes.go
@@ -419,6 +419,40 @@ func (node *Node) Materialize(ctx context.Context, env Environment) (*execution.NodeWithMeta, error) {
 			},
 			Schema: OctoSQLToArrowSchema(node.Schema),
 		}, nil
+	case NodeTypeStreamJoin:
+		left, err := node.StreamJoin.Left.Materialize(ctx, env)
+		if err != nil {
+			return nil, fmt.Errorf("couldn't materialize stream join left source: %w", err)
+		}
+		right, err := node.StreamJoin.Right.Materialize(ctx, env)
+		if err != nil {
+			return nil, fmt.Errorf("couldn't materialize stream join right source: %w", err)
+		}
+		leftKey := make([]execution.Expression, len(node.StreamJoin.LeftKey))
+		for i := range node.StreamJoin.LeftKey {
+			expr, err := node.StreamJoin.LeftKey[i].Materialize(ctx, env, OctoSQLToArrowSchema(node.StreamJoin.Left.Schema))
+			if err != nil {
+				return nil, fmt.Errorf("couldn't materialize stream join left key expression with index %d: %w", i, err)
+			}
+			leftKey[i] = expr
+		}
+		rightKey := make([]execution.Expression, len(node.StreamJoin.RightKey))
+		for i := range node.StreamJoin.RightKey {
+			expr, err := node.StreamJoin.RightKey[i].Materialize(ctx, env, OctoSQLToArrowSchema(node.StreamJoin.Right.Schema))
+			if err != nil {
+				return nil, fmt.Errorf("couldn't materialize stream join right key expression with index %d: %w", i, err)
+			}
+			rightKey[i] = expr
+		}
+		return &execution.NodeWithMeta{
+			Node: &nodes.StreamJoin{
+				Left:     left,
+				Right:    right,
+				LeftKey:  leftKey,
+				RightKey: rightKey,
+			},
+			Schema: OctoSQLToArrowSchema(node.Schema),
+		}, nil
 	default:
 		panic(fmt.Sprintf("invalid node type: %s", node.NodeType))
 	}
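
Usage sketch for the new key-columns API (reviewer note, not part of the diff): MakeRowEqualityChecker and its panic-on-mismatch behavior come from the change above; everything else is standard Arrow v13. A single-column int64 key is assumed for brevity; in the join itself these arrays come from evaluating the LeftKey/RightKey expressions against record batches.

package main

import (
	"fmt"

	"github.com/apache/arrow/go/v13/arrow"
	"github.com/apache/arrow/go/v13/arrow/array"
	"github.com/apache/arrow/go/v13/arrow/memory"

	"github.com/cube2222/octosql/arrowexec/helpers"
)

func main() {
	alloc := memory.NewGoAllocator()

	// Build two single-column key sets of the same type.
	leftBuilder := array.NewInt64Builder(alloc)
	leftBuilder.AppendValues([]int64{1, 2, 3}, nil)
	left := leftBuilder.NewInt64Array()

	rightBuilder := array.NewInt64Builder(alloc)
	rightBuilder.AppendValues([]int64{3, 2, 1}, nil)
	right := rightBuilder.NewInt64Array()

	// The checker captures the typed value slices once up front, so each
	// per-row call is just a slice comparison.
	equal := helpers.MakeRowEqualityChecker([]arrow.Array{left}, []arrow.Array{right})
	fmt.Println(equal(0, 2)) // true: left row 0 (1) == right row 2 (1)
	fmt.Println(equal(0, 0)) // false: 1 != 3
}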
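
And a sketch of the rewriter path (again illustrative only): MakeColumnRewriter dispatches on the builder's type ID and, with this change, routes through the generic rewriterForType helper, including the newly supported STRING case. arrow.BinaryTypes.String and the builder calls are standard Arrow v13 APIs.

package main

import (
	"fmt"

	"github.com/apache/arrow/go/v13/arrow"
	"github.com/apache/arrow/go/v13/arrow/array"
	"github.com/apache/arrow/go/v13/arrow/memory"

	"github.com/cube2222/octosql/arrowexec/helpers"
)

func main() {
	alloc := memory.NewGoAllocator()

	// Source column to copy rows out of.
	srcBuilder := array.NewStringBuilder(alloc)
	srcBuilder.AppendValues([]string{"a", "b", "c", "d"}, nil)
	src := srcBuilder.NewStringArray()

	// Target builder of the same type; the returned closure appends one
	// source row per call, which is how the join materializes output rows.
	dstBuilder := array.NewBuilder(alloc, arrow.BinaryTypes.String)
	rewrite := helpers.MakeColumnRewriter(dstBuilder, src)

	// Copy only rows 1 and 3, as a filter (or join match) would.
	rewrite(1)
	rewrite(3)

	out := dstBuilder.NewArray().(*array.String)
	fmt.Println(out.Value(0), out.Value(1)) // b d
}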