From 3dd3baf1d6436928061e04cf9dadd35536f6690c Mon Sep 17 00:00:00 2001 From: Jakub Martin Date: Tue, 8 Aug 2023 17:57:20 +0200 Subject: [PATCH] Split jointable into partitions. --- arrowexec/nodes/hashtable/join_hashtable.go | 70 +++++++++++-------- .../nodes/hashtable/join_hashtable_test.go | 4 +- 2 files changed, 42 insertions(+), 32 deletions(-) diff --git a/arrowexec/nodes/hashtable/join_hashtable.go b/arrowexec/nodes/hashtable/join_hashtable.go index cd13d601..dc94c2ff 100644 --- a/arrowexec/nodes/hashtable/join_hashtable.go +++ b/arrowexec/nodes/hashtable/join_hashtable.go @@ -14,14 +14,20 @@ import ( "golang.org/x/sync/errgroup" ) -type JoinHashTable struct { +type JoinTable struct { + partitions []JoinTablePartition +} + +type JoinTablePartition struct { hashStartIndices *intintmap.Map hashes *array.Uint64 values execution.Record } -// BuildHashTable groups the records by their key hashes, and returns them with an index of partition starts. -func BuildHashTable(records []execution.Record, indices []int) *JoinHashTable { +// BuildJoinTable groups the records by their key hashes, and returns them with an index of partition starts. +func BuildJoinTable(records []execution.Record, indices []int) *JoinTable { + partitions := 7 // TODO: Make it the first prime number larger than the core count. + // TODO: Handle case where there are 0 records. keyHashers := make([]func(rowIndex uint) uint64, len(records)) for i, record := range records { @@ -33,46 +39,50 @@ func BuildHashTable(records []execution.Record, indices []int) *JoinHashTable { overallRowCount += int(record.NumRows()) } - hashPositionsOrdered := make([]hashRowPosition, overallRowCount) - i := 0 + hashPositionsOrdered := make([][]hashRowPosition, partitions) + for i := range hashPositionsOrdered { + hashPositionsOrdered[i] = make([]hashRowPosition, 0, overallRowCount/partitions) + } + for recordIndex, record := range records { numRows := int(record.NumRows()) for rowIndex := 0; rowIndex < numRows; rowIndex++ { hash := keyHashers[recordIndex](uint(rowIndex)) - hashPositionsOrdered[i] = hashRowPosition{ + partition := int(hash % uint64(partitions)) + hashPositionsOrdered[partition] = append(hashPositionsOrdered[partition], hashRowPosition{ hash: hash, recordIndex: recordIndex, rowIndex: rowIndex, - } - i++ + }) } } - sorts.ByUint64(SortHashPosition(hashPositionsOrdered)) - var wg sync.WaitGroup - wg.Add(2) - - var hashIndex *intintmap.Map - go func() { - hashIndex = buildHashIndex(hashPositionsOrdered) - wg.Done() - }() - - var hashesArray *array.Uint64 - go func() { - hashesArray = buildHashesArray(overallRowCount, hashPositionsOrdered) - wg.Done() - }() - - record := buildRecords(records, overallRowCount, hashPositionsOrdered) - + wg.Add(partitions) + joinTablePartitions := make([]JoinTablePartition, partitions) + for part := 0; part < partitions; part++ { + part := part + + go func() { + hashPositionsOrderedPartition := hashPositionsOrdered[part] + sorts.ByUint64(SortHashPosition(hashPositionsOrderedPartition)) + + hashIndex := buildHashIndex(hashPositionsOrderedPartition) + hashesArray := buildHashesArray(overallRowCount, hashPositionsOrderedPartition) + record := buildRecords(records, overallRowCount, hashPositionsOrderedPartition) + + joinTablePartitions[part] = JoinTablePartition{ + hashStartIndices: hashIndex, + hashes: hashesArray, + values: execution.Record{Record: record}, + } + wg.Done() + }() + } wg.Wait() - return &JoinHashTable{ - hashStartIndices: hashIndex, - hashes: hashesArray, - values: execution.Record{Record: record}, + return &JoinTable{ + partitions: joinTablePartitions, } } diff --git a/arrowexec/nodes/hashtable/join_hashtable_test.go b/arrowexec/nodes/hashtable/join_hashtable_test.go index ecd9f4e7..4de24e79 100644 --- a/arrowexec/nodes/hashtable/join_hashtable_test.go +++ b/arrowexec/nodes/hashtable/join_hashtable_test.go @@ -32,7 +32,7 @@ func TestPartition(t *testing.T) { }) } - table := BuildHashTable(records, []int{0, 1}) + table := BuildJoinTable(records, []int{0, 1}) log.Println("hashes:", table) } @@ -61,7 +61,7 @@ func BenchmarkPartitionIntegers(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - table := BuildHashTable(records, []int{0, 1}) + table := BuildJoinTable(records, []int{0, 1}) table = table } }