diff --git a/aggregates/distinct.go b/aggregates/distinct.go index c0cc4498..47a25e84 100644 --- a/aggregates/distinct.go +++ b/aggregates/distinct.go @@ -1,8 +1,11 @@ package aggregates import ( - "github.com/tidwall/btree" + "hash/fnv" + "github.com/zyedidia/generic/hashmap" + + "github.com/cube2222/octosql/execution" "github.com/cube2222/octosql/execution/nodes" "github.com/cube2222/octosql/octosql" "github.com/cube2222/octosql/physical" @@ -22,33 +25,36 @@ func DistinctAggregateOverloads(overloads []physical.AggregateDescriptor) []phys } type Distinct struct { - items *btree.Generic[*distinctKey] + items *hashmap.Map[octosql.Value, *distinctKey] wrapped nodes.Aggregate } func NewDistinctPrototype(wrapped func() nodes.Aggregate) func() nodes.Aggregate { return func() nodes.Aggregate { return &Distinct{ - items: btree.NewGenericOptions(func(key, than *distinctKey) bool { - return key.value.Compare(than.value) == -1 - }, btree.Options{NoLocks: true}), + items: hashmap.New[octosql.Value, *distinctKey]( + execution.BTreeDefaultDegree, + func(a, b octosql.Value) bool { + return a.Compare(b) == 0 + }, func(v octosql.Value) uint64 { + hash := fnv.New64() + v.Hash(hash) + return hash.Sum64() + }), wrapped: wrapped(), } } } type distinctKey struct { - value octosql.Value count int } func (c *Distinct) Add(retraction bool, value octosql.Value) bool { - var hint btree.PathHint - - item, ok := c.items.GetHint(&distinctKey{value: value}, &hint) + item, ok := c.items.Get(value) if !ok { - item = &distinctKey{value: value, count: 0} - c.items.SetHint(item, &hint) + item = &distinctKey{count: 0} + c.items.Put(value, item) } if !retraction { item.count++ @@ -58,10 +64,10 @@ func (c *Distinct) Add(retraction bool, value octosql.Value) bool { if item.count == 1 && !retraction { c.wrapped.Add(false, value) } else if item.count == 0 { - c.items.DeleteHint(item, &hint) + c.items.Remove(value) c.wrapped.Add(true, value) } - return c.items.Len() == 0 + return c.items.Size() == 0 } func (c *Distinct) Trigger() octosql.Value { diff --git a/execution/nodes/distinct.go b/execution/nodes/distinct.go index 2bfedf1f..bde6ebad 100644 --- a/execution/nodes/distinct.go +++ b/execution/nodes/distinct.go @@ -2,8 +2,9 @@ package nodes import ( "fmt" + "hash/fnv" - "github.com/tidwall/btree" + "github.com/zyedidia/generic/hashmap" . "github.com/cube2222/octosql/execution" "github.com/cube2222/octosql/octosql" @@ -20,30 +21,33 @@ func NewDistinct(source Node) *Distinct { } type distinctItem struct { - Values []octosql.Value - Count int + Count int } func (o *Distinct) Run(execCtx ExecutionContext, produce ProduceFn, metaSend MetaSendFn) error { - recordCounts := btree.NewGenericOptions(func(item, than *distinctItem) bool { - for i := 0; i < len(item.Values); i++ { - if comp := item.Values[i].Compare(than.Values[i]); comp != 0 { - return comp == -1 + recordCounts := hashmap.New[[]octosql.Value, *distinctItem]( + BTreeDefaultDegree, + func(a, b []octosql.Value) bool { + for i := range a { + if a[i].Compare(b[i]) != 0 { + return false + } } - } - - return false - }, btree.Options{ - NoLocks: true, - }) + return true + }, func(k []octosql.Value) uint64 { + hash := fnv.New64() + for _, v := range k { + v.Hash(hash) + } + return hash.Sum64() + }) o.source.Run( execCtx, func(ctx ProduceContext, record Record) error { - item, ok := recordCounts.Get(&distinctItem{Values: record.Values}) + item, ok := recordCounts.Get(record.Values) if !ok { item = &distinctItem{ - Values: record.Values, - Count: 0, + Count: 0, } } if !record.Retraction { @@ -52,18 +56,18 @@ func (o *Distinct) Run(execCtx ExecutionContext, produce ProduceFn, metaSend Met item.Count-- } if item.Count > 0 { - // New record. if !record.Retraction && item.Count == 1 { + // New record. if err := produce(ctx, record); err != nil { return fmt.Errorf("couldn't produce new record: %w", err) } - recordCounts.Set(item) + recordCounts.Put(record.Values, item) } } else { if err := produce(ctx, record); err != nil { return fmt.Errorf("couldn't retract record record: %w", err) } - recordCounts.Delete(item) + recordCounts.Remove(record.Values) } return nil }, diff --git a/execution/nodes/simple_group_by.go b/execution/nodes/simple_group_by.go index ef6cf8f4..24328736 100644 --- a/execution/nodes/simple_group_by.go +++ b/execution/nodes/simple_group_by.go @@ -2,9 +2,10 @@ package nodes import ( "fmt" + "hash/fnv" "time" - "github.com/google/btree" + "github.com/zyedidia/generic/hashmap" . "github.com/cube2222/octosql/execution" "github.com/cube2222/octosql/octosql" @@ -32,9 +33,30 @@ func NewSimpleGroupBy( } } +type hashmapAggregatesItem struct { + Aggregates []Aggregate + + // AggregatedSetSize omits NULL inputs. + AggregatedSetSize []int + + // OverallRecordCount counts all records minus retractions. + OverallRecordCount int +} + func (g *SimpleGroupBy) Run(ctx ExecutionContext, produce ProduceFn, metaSend MetaSendFn) error { - aggregates := btree.NewG[*aggregatesItem](BTreeDefaultDegree, func(a, b *aggregatesItem) bool { - return CompareValueSlices(a.GroupKey, b.GroupKey) + aggregates := hashmap.New[GroupKey, *hashmapAggregatesItem](BTreeDefaultDegree, func(a, b GroupKey) bool { + for i := range a { + if a[i].Compare(b[i]) != 0 { + return false + } + } + return true + }, func(k GroupKey) uint64 { + hash := fnv.New64() + for _, v := range k { + v.Hash(hash) + } + return hash.Sum64() }) if err := g.source.Run(ctx, func(produceCtx ProduceContext, record Record) error { @@ -50,7 +72,7 @@ func (g *SimpleGroupBy) Run(ctx ExecutionContext, produce ProduceFn, metaSend Me } { - itemTyped, ok := aggregates.Get(&aggregatesItem{GroupKey: key}) + itemTyped, ok := aggregates.Get(key) if !ok { newAggregates := make([]Aggregate, len(g.aggregatePrototypes)) @@ -58,8 +80,8 @@ func (g *SimpleGroupBy) Run(ctx ExecutionContext, produce ProduceFn, metaSend Me newAggregates[i] = g.aggregatePrototypes[i]() } - itemTyped = &aggregatesItem{GroupKey: key, Aggregates: newAggregates, AggregatedSetSize: make([]int, len(g.aggregatePrototypes))} - aggregates.ReplaceOrInsert(itemTyped) + itemTyped = &hashmapAggregatesItem{Aggregates: newAggregates, AggregatedSetSize: make([]int, len(g.aggregatePrototypes))} + aggregates.Put(key, itemTyped) } if !record.Retraction { @@ -84,7 +106,7 @@ func (g *SimpleGroupBy) Run(ctx ExecutionContext, produce ProduceFn, metaSend Me } if itemTyped.OverallRecordCount == 0 { - aggregates.Delete(itemTyped) + aggregates.Remove(key) } } @@ -96,26 +118,35 @@ func (g *SimpleGroupBy) Run(ctx ExecutionContext, produce ProduceFn, metaSend Me } var err error - aggregates.Ascend(func(itemTyped *aggregatesItem) bool { - key := itemTyped.GroupKey - - outputValues := make([]octosql.Value, len(key)+len(g.aggregateExprs)) - copy(outputValues, key) - - for i := range itemTyped.Aggregates { - if itemTyped.AggregatedSetSize[i] > 0 { - outputValues[len(key)+i] = itemTyped.Aggregates[i].Trigger() - } else { - outputValues[len(key)+i] = octosql.NewNull() + func() { + type stopEach struct{} + defer func() { + msg := recover() + if msg == nil { + return + } + if _, ok := msg.(stopEach); ok { + return + } + panic(msg) + }() + aggregates.Each(func(key GroupKey, itemTyped *hashmapAggregatesItem) { + outputValues := make([]octosql.Value, len(key)+len(g.aggregateExprs)) + copy(outputValues, key) + + for i := range itemTyped.Aggregates { + if itemTyped.AggregatedSetSize[i] > 0 { + outputValues[len(key)+i] = itemTyped.Aggregates[i].Trigger() + } else { + outputValues[len(key)+i] = octosql.NewNull() + } } - } - - if err = produce(ProduceFromExecutionContext(ctx), NewRecord(outputValues, false, time.Time{})); err != nil { - return false - } - return true - }) + if err = produce(ProduceFromExecutionContext(ctx), NewRecord(outputValues, false, time.Time{})); err != nil { + panic(stopEach{}) + } + }) + }() return err } diff --git a/go.mod b/go.mod index 15b34a62..638b1214 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( github.com/stretchr/testify v1.7.0 github.com/tidwall/btree v1.3.1 github.com/valyala/fastjson v1.6.3 + github.com/zyedidia/generic v1.1.0 golang.org/x/exp v0.0.0-20220414153411-bcd21879b8fd google.golang.org/grpc v1.42.0 google.golang.org/protobuf v1.27.1 @@ -58,6 +59,7 @@ require ( github.com/pkg/term v1.2.0-beta.2 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/segmentio/encoding v0.3.5 // indirect + github.com/segmentio/fasthash v1.0.3 // indirect github.com/shopspring/decimal v1.2.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/ulikunitz/xz v0.5.10 // indirect diff --git a/go.sum b/go.sum index 4f8660d2..12ec6683 100644 --- a/go.sum +++ b/go.sum @@ -159,6 +159,8 @@ github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= github.com/segmentio/encoding v0.3.5 h1:UZEiaZ55nlXGDL92scoVuw00RmiRCazIEmvPSbSvt8Y= github.com/segmentio/encoding v0.3.5/go.mod h1:n0JeuIqEQrQoPDGsjo8UNd1iA0U8d8+oHAA4E3G3OxM= +github.com/segmentio/fasthash v1.0.3 h1:EI9+KE1EwvMLBWwjpRDc+fEM+prwxDYbslddQGtrmhM= +github.com/segmentio/fasthash v1.0.3/go.mod h1:waKX8l2N8yckOgmSsXJi7x1ZfdKZ4x7KRMzBtS3oedY= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= @@ -183,6 +185,8 @@ github.com/valyala/fastjson v1.6.3 h1:tAKFnnwmeMGPbwJ7IwxcTPCNr3uIzoIj3/Fh90ra4x github.com/valyala/fastjson v1.6.3/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= +github.com/zyedidia/generic v1.1.0 h1:G9kbhNFCZhf2d9SC53RkHQdmMoPwImguLOGx9DW2ADM= +github.com/zyedidia/generic v1.1.0/go.mod h1:ly2RBz4mnz1yeuVbQA/VFwGjK3mnHGRj1JuoG336Bis= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/octosql/values.go b/octosql/values.go index 79e7d281..31700165 100644 --- a/octosql/values.go +++ b/octosql/values.go @@ -1,7 +1,10 @@ package octosql import ( + "encoding/binary" "fmt" + "hash" + "math" "strings" "time" ) @@ -236,6 +239,63 @@ func (value Value) Compare(other Value) int { } } +func (value Value) Hash(hash hash.Hash64) { + switch value.TypeID { + case TypeIDNull: + hash.Write([]byte{0}) + + case TypeIDInt: + var data [8]byte + binary.BigEndian.PutUint64(data[:], uint64(value.Int)) + hash.Write(data[:]) + + case TypeIDFloat: + var data [8]byte + binary.BigEndian.PutUint64(data[:], math.Float64bits(value.Float)) + hash.Write(data[:]) + + case TypeIDBoolean: + if value.Boolean { + hash.Write([]byte{1}) + } else { + hash.Write([]byte{0}) + } + + case TypeIDString: + hash.Write([]byte(value.Str)) + + case TypeIDTime: + var data [8]byte + binary.BigEndian.PutUint64(data[:], uint64(value.Time.UnixNano())) + hash.Write(data[:]) + + case TypeIDDuration: + var data [8]byte + binary.BigEndian.PutUint64(data[:], uint64(value.Duration)) + hash.Write(data[:]) + + case TypeIDList: + for i := range value.List { + value.List[i].Hash(hash) + } + + case TypeIDStruct: + for i := range value.List { + value.Struct[i].Hash(hash) + } + + case TypeIDTuple: + for i := range value.List { + value.Tuple[i].Hash(hash) + } + + case TypeIDUnion: + panic("can't have union type as concrete value instance") + default: + panic("impossible, type switch bug") + } +} + func (value Value) Equal(other Value) bool { if value.TypeID == TypeIDNull && other.TypeID == TypeIDNull { return false