diff --git a/arrowexec/nodes/join.go b/arrowexec/nodes/join.go
new file mode 100644
index 00000000..ebeebe84
--- /dev/null
+++ b/arrowexec/nodes/join.go
@@ -0,0 +1,188 @@
+package nodes
+
+import (
+	"runtime"
+	"sync"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/brentp/intintmap"
+	"github.com/cube2222/octosql/arrowexec/execution"
+	"github.com/twotwotwo/sorts"
+	"golang.org/x/sync/errgroup"
+)
+
+type StreamJoin struct {
+	Left, Right                     execution.NodeWithMeta
+	LeftKeyIndices, RightKeyIndices []int
+}
+
+func (s *StreamJoin) Run(ctx execution.Context, produce execution.ProduceFunc) error {
+	// TODO: The channel should also be able to pass errors.
+	leftRecordChannel := make(chan execution.Record, 8)
+	rightRecordChannel := make(chan execution.Record, 8)
+
+	go func() {
+		defer close(leftRecordChannel)
+		if err := s.Left.Node.Run(ctx, func(produceCtx execution.ProduceContext, record execution.Record) error {
+			leftRecordChannel <- record
+			return nil
+		}); err != nil {
+			panic("implement me") // TODO: Propagate the error instead of panicking.
+		}
+	}()
+
+	go func() {
+		defer close(rightRecordChannel)
+		if err := s.Right.Node.Run(ctx, func(produceCtx execution.ProduceContext, record execution.Record) error {
+			rightRecordChannel <- record
+			return nil
+		}); err != nil {
+			panic("implement me") // TODO: Propagate the error instead of panicking.
+		}
+	}()
+
+	// Buffer both sides, draining each channel until it's closed. A drained
+	// channel is set to nil so the select stops picking it.
+	var leftRecords, rightRecords []execution.Record
+	for leftRecordChannel != nil || rightRecordChannel != nil {
+		select {
+		case record, ok := <-leftRecordChannel:
+			if !ok {
+				leftRecordChannel = nil
+				continue
+			}
+			leftRecords = append(leftRecords, record)
+		case record, ok := <-rightRecordChannel:
+			if !ok {
+				rightRecordChannel = nil
+				continue
+			}
+			rightRecords = append(rightRecords, record)
+		}
+	}
+
+	// TODO: Partition both sides and produce the joined records.
+	return nil
+}
+
+// PartitionByHash groups the rows of the given records by their key hash. It
+// returns a map from each key hash to the index of its first row, the per-row
+// hashes, and a single record with all rows reordered so that rows with equal
+// hashes are contiguous.
+func PartitionByHash(records []execution.Record, indices []int) (*intintmap.Map, *array.Uint64, execution.Record) {
+	// TODO: Handle case where there are 0 records.
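+	// A minimal sketch of that guard, assuming an empty execution.Record and an
+	// empty hash map are acceptable results for callers (unverified assumption):
+	//
+	//	if len(records) == 0 {
+	//		emptyHashes := array.NewUint64Builder(memory.NewGoAllocator()).NewUint64Array()
+	//		return intintmap.New(1024, 0.6), emptyHashes, execution.Record{}
+	//	}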
+	keyHashers := make([]func(rowIndex uint) uint64, len(records))
+	for i, record := range records {
+		keyHashers[i] = MakeKeyHasher(record, indices)
+	}
+
+	var overallRowCount int
+	for _, record := range records {
+		overallRowCount += int(record.NumRows())
+	}
+
+	// Alternative counting-sort-style implementation, kept for reference:
+	// hashCounts := getHashCounts(records, keyHashers)
+	// hashNextEntryPosition, hashIndex := calculateNextEntryPositionAndHashIndex(hashCounts)
+	// hashPositionsOrdered := buildHashPositionsOrdered(records, overallRowCount, keyHashers, hashNextEntryPosition)
+
+	// Compute the key hash of every row, remembering which record and row it came from.
+	hashPositionsOrdered := make([]hashRowPosition, overallRowCount)
+	i := 0
+	for recordIndex, record := range records {
+		numRows := int(record.NumRows())
+		for rowIndex := 0; rowIndex < numRows; rowIndex++ {
+			hash := keyHashers[recordIndex](uint(rowIndex))
+			hashPositionsOrdered[i] = hashRowPosition{
+				hash:        hash,
+				recordIndex: recordIndex,
+				rowIndex:    rowIndex,
+			}
+			i++
+		}
+	}
+
+	// Sort the row positions by hash, so that rows with equal hashes become contiguous.
+	sorts.ByUint64(SortHashPosition(hashPositionsOrdered))
+
+	// Build the hash index and the hashes array concurrently with the record itself.
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	var hashIndex *intintmap.Map
+	go func() {
+		hashIndex = buildHashIndex(hashPositionsOrdered)
+		wg.Done()
+	}()
+
+	var hashesArray *array.Uint64
+	go func() {
+		hashesArray = buildHashesArray(overallRowCount, hashPositionsOrdered)
+		wg.Done()
+	}()
+
+	record := buildRecords(records, overallRowCount, hashPositionsOrdered)
+
+	wg.Wait()
+
+	return hashIndex, hashesArray, execution.Record{Record: record}
+}
+
+// buildHashIndex maps each distinct hash to the index of its first row in the
+// hash-ordered row sequence.
+func buildHashIndex(hashPositionsOrdered []hashRowPosition) *intintmap.Map {
+	hashIndex := intintmap.New(1024, 0.6)
+	hashIndex.Put(int64(hashPositionsOrdered[0].hash), 0)
+	for i := 1; i < len(hashPositionsOrdered); i++ {
+		if hashPositionsOrdered[i].hash != hashPositionsOrdered[i-1].hash {
+			hashIndex.Put(int64(hashPositionsOrdered[i].hash), int64(i))
+		}
+	}
+	return hashIndex
+}
+
+// hashRowPosition identifies a single row by the record it lives in and its
+// row index there, together with its key hash.
+type hashRowPosition struct {
+	hash        uint64
+	recordIndex int
+	rowIndex    int
+}
+
+func buildHashesArray(overallRowCount int, hashPositionsOrdered []hashRowPosition) *array.Uint64 {
+	hashesBuilder := array.NewUint64Builder(memory.NewGoAllocator()) // TODO: Get allocator from argument.
+	hashesBuilder.Reserve(overallRowCount)
+	for _, hashPosition := range hashPositionsOrdered {
+		hashesBuilder.UnsafeAppend(hashPosition.hash)
+	}
+	return hashesBuilder.NewUint64Array()
+}
+
+// buildRecords gathers the rows of all input records into a single record in
+// hash order, rewriting each column in a separate goroutine.
+func buildRecords(records []execution.Record, overallRowCount int, hashPositionsOrdered []hashRowPosition) arrow.Record {
+	// TODO: Get allocator from argument.
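+	// Conceptually, a rewriter produced by MakeColumnRewriter (defined elsewhere
+	// in this package) appends one source row to a column builder. For an Int64
+	// column it would look roughly like the following sketch (hypothetical
+	// helper; the real implementation may differ):
+	//
+	//	func makeInt64Rewriter(builder *array.Int64Builder, column *array.Int64) func(rowIndex int) {
+	//		return func(rowIndex int) {
+	//			builder.UnsafeAppend(column.Value(rowIndex))
+	//		}
+	//	}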
+	recordBuilder := array.NewRecordBuilder(memory.NewGoAllocator(), records[0].Schema())
+	recordBuilder.Reserve(overallRowCount)
+
+	var g errgroup.Group
+	g.SetLimit(runtime.GOMAXPROCS(0))
+
+	columnCount := len(recordBuilder.Fields())
+	for columnIndex := 0; columnIndex < columnCount; columnIndex++ {
+		columnRewriters := make([]func(rowIndex int), len(records))
+		for recordIndex, record := range records {
+			columnRewriters[recordIndex] = MakeColumnRewriter(recordBuilder.Field(columnIndex), record.Column(columnIndex))
+		}
+
+		g.Go(func() error {
+			for _, hashPosition := range hashPositionsOrdered {
+				columnRewriters[hashPosition.recordIndex](hashPosition.rowIndex)
+			}
+			return nil
+		})
+	}
+	_ = g.Wait() // The rewriter goroutines never return errors.
+
+	return recordBuilder.NewRecord()
+}
+
+// SortHashPosition implements sorts.Uint64Interface so that hashRowPositions
+// can be sorted by hash.
+type SortHashPosition []hashRowPosition
+
+func (h SortHashPosition) Len() int {
+	return len(h)
+}
+
+func (h SortHashPosition) Less(i, j int) bool {
+	return h[i].hash < h[j].hash
+}
+
+func (h SortHashPosition) Swap(i, j int) {
+	h[i], h[j] = h[j], h[i]
+}
+
+func (h SortHashPosition) Key(i int) uint64 {
+	return h[i].hash
+}
diff --git a/arrowexec/nodes/join_test.go b/arrowexec/nodes/join_test.go
new file mode 100644
index 00000000..028f60e8
--- /dev/null
+++ b/arrowexec/nodes/join_test.go
@@ -0,0 +1,74 @@
+package nodes
+
+import (
+	"log"
+	"math/rand"
+	"testing"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/cube2222/octosql/arrowexec/execution"
+)
+
+// TestPartition is a smoke test for now: it prints the partitioning output for
+// manual inspection. TODO: Assert on the output instead.
+func TestPartition(t *testing.T) {
+	schema := arrow.NewSchema([]arrow.Field{
+		{Name: "a", Type: arrow.PrimitiveTypes.Int64},
+		{Name: "b", Type: arrow.PrimitiveTypes.Int64},
+		{Name: "c", Type: arrow.PrimitiveTypes.Int64},
+	}, nil)
+
+	var records []execution.Record
+	recordBuilder := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+	for i := 0; i < 5; i++ {
+		for j := 0; j < 10; j++ {
+			recordBuilder.Field(0).(*array.Int64Builder).Append(int64(rand.Intn(5)))
+			recordBuilder.Field(1).(*array.Int64Builder).Append(int64(rand.Intn(2)))
+			recordBuilder.Field(2).(*array.Int64Builder).Append(int64(rand.Intn(7)))
+		}
+
+		records = append(records, execution.Record{
+			Record: recordBuilder.NewRecord(),
+		})
+	}
+
+	hashIndex, hashes, partitioned := PartitionByHash(records, []int{0, 1})
+
+	log.Println("hashIndex:")
+	for pair := range hashIndex.Items() {
+		log.Println(pair)
+	}
+	log.Println("hashes:", hashes)
+	log.Println("partitioned:", partitioned)
+}
+
+func BenchmarkPartitionIntegers(b *testing.B) {
+	schema := arrow.NewSchema([]arrow.Field{
+		{Name: "a", Type: arrow.PrimitiveTypes.Int64},
+		{Name: "b", Type: arrow.PrimitiveTypes.Int64},
+		{Name: "c", Type: arrow.PrimitiveTypes.Int64},
+	}, nil)
+
+	var records []execution.Record
+	recordBuilder := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+	for i := 0; i < 128; i++ {
+		for j := 0; j < execution.IdealBatchSize; j++ {
+			recordBuilder.Field(0).(*array.Int64Builder).Append(int64(rand.Intn(1024 * 8)))
+			recordBuilder.Field(1).(*array.Int64Builder).Append(int64(rand.Intn(4)))
+			recordBuilder.Field(2).(*array.Int64Builder).Append(int64(rand.Intn(7)))
+		}
+
+		records = append(records, execution.Record{
+			Record: recordBuilder.NewRecord(),
+		})
+	}
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		hashIndex, hashes, partitioned := PartitionByHash(records, []int{0, 1})
+		_, _, _ = hashIndex, hashes, partitioned // Keep the results without triggering unused-variable errors.
+	}
+}
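+
+// Downstream, the partitioned output is meant to be probed roughly like the
+// following sketch (an illustration, not part of this change; it assumes
+// intintmap's Get returns (value, ok) and that hash collisions are resolved
+// later by comparing the actual key columns):
+//
+//	leftIndex, _, leftPartitioned := PartitionByHash(leftRecords, leftKeyIndices)
+//	_, rightHashes, rightPartitioned := PartitionByHash(rightRecords, rightKeyIndices)
+//	for row := 0; row < int(rightPartitioned.NumRows()); row++ {
+//		if start, ok := leftIndex.Get(int64(rightHashes.Value(row))); ok {
+//			// Rows of leftPartitioned from `start` onwards, while their hash
+//			// still equals rightHashes.Value(row), are join candidates.
+//			_ = start
+//		}
+//	}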