diff --git a/DEPS.bzl b/DEPS.bzl index e826ae2f62c06..2bfb0c788ff82 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -6409,6 +6409,19 @@ def go_deps(): "https://storage.googleapis.com/pingcapmirror/gomod/github.com/segmentio/asm/com_github_segmentio_asm-v1.2.0.zip", ], ) + go_repository( + name = "com_github_segmentio_fasthash", + build_file_proto_mode = "disable_global", + importpath = "github.com/segmentio/fasthash", + sha256 = "fe6b87a2841eac3670539d105692d39f67155955202145dc78f3a29c866b8cb6", + strip_prefix = "github.com/segmentio/fasthash@v1.0.3", + urls = [ + "http://bazel-cache.pingcap.net:8080/gomod/github.com/segmentio/fasthash/com_github_segmentio_fasthash-v1.0.3.zip", + "http://ats.apps.svc/gomod/github.com/segmentio/fasthash/com_github_segmentio_fasthash-v1.0.3.zip", + "https://cache.hawkingrei.com/gomod/github.com/segmentio/fasthash/com_github_segmentio_fasthash-v1.0.3.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/github.com/segmentio/fasthash/com_github_segmentio_fasthash-v1.0.3.zip", + ], + ) go_repository( name = "com_github_sergi_go_diff", build_file_proto_mode = "disable_global", @@ -7475,6 +7488,19 @@ def go_deps(): "https://storage.googleapis.com/pingcapmirror/gomod/github.com/zeebo/xxh3/com_github_zeebo_xxh3-v1.0.2.zip", ], ) + go_repository( + name = "com_github_zyedidia_generic", + build_file_proto_mode = "disable_global", + importpath = "github.com/zyedidia/generic", + sha256 = "21f980420a46e0f6ed564dd9658ddab9991cef8eca32804a956fb65f7f9d4c31", + strip_prefix = "github.com/zyedidia/generic@v1.2.1", + urls = [ + "http://bazel-cache.pingcap.net:8080/gomod/github.com/zyedidia/generic/com_github_zyedidia_generic-v1.2.1.zip", + "http://ats.apps.svc/gomod/github.com/zyedidia/generic/com_github_zyedidia_generic-v1.2.1.zip", + "https://cache.hawkingrei.com/gomod/github.com/zyedidia/generic/com_github_zyedidia_generic-v1.2.1.zip", + "https://storage.googleapis.com/pingcapmirror/gomod/github.com/zyedidia/generic/com_github_zyedidia_generic-v1.2.1.zip", + ], + ) go_repository( name = "com_gitlab_bosi_decorder", build_file_proto_mode = "disable_global", diff --git a/go.mod b/go.mod index 8a129b2d22b7a..b0a435d89088e 100644 --- a/go.mod +++ b/go.mod @@ -118,6 +118,7 @@ require ( github.com/wangjohn/quickselect v0.0.0-20161129230411-ed8402a42d5f github.com/xitongsys/parquet-go v1.6.3-0.20240520233950-75e935fc3e17 github.com/xitongsys/parquet-go-source v0.0.0-20200817004010-026bad9b25d0 + github.com/zyedidia/generic v1.2.1 go.etcd.io/etcd/api/v3 v3.5.12 go.etcd.io/etcd/client/pkg/v3 v3.5.12 go.etcd.io/etcd/client/v3 v3.5.12 @@ -168,6 +169,7 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pierrec/lz4/v4 v4.1.15 // indirect github.com/qri-io/jsonpointer v0.1.1 // indirect + github.com/segmentio/fasthash v1.0.3 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect ) diff --git a/go.sum b/go.sum index bc69ac7f3fcd1..74e7cc5c10227 100644 --- a/go.sum +++ b/go.sum @@ -743,6 +743,8 @@ github.com/sasha-s/go-deadlock v0.3.5 h1:tNCOEEDG6tBqrNDOX35j/7hL5FcFViG6awUGROb github.com/sasha-s/go-deadlock v0.3.5/go.mod h1:bugP6EGbdGYObIlx7pUZtWqlvo8k9H6vCBBsiChJQ5U= github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= +github.com/segmentio/fasthash v1.0.3 h1:EI9+KE1EwvMLBWwjpRDc+fEM+prwxDYbslddQGtrmhM= +github.com/segmentio/fasthash v1.0.3/go.mod h1:waKX8l2N8yckOgmSsXJi7x1ZfdKZ4x7KRMzBtS3oedY= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= @@ -872,6 +874,8 @@ github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +github.com/zyedidia/generic v1.2.1 h1:Zv5KS/N2m0XZZiuLS82qheRG4X1o5gsWreGb0hR7XDc= +github.com/zyedidia/generic v1.2.1/go.mod h1:ly2RBz4mnz1yeuVbQA/VFwGjK3mnHGRj1JuoG336Bis= go.einride.tech/aip v0.66.0 h1:XfV+NQX6L7EOYK11yoHHFtndeaWh3KbD9/cN/6iWEt8= go.einride.tech/aip v0.66.0/go.mod h1:qAhMsfT7plxBX+Oy7Huol6YUvZ0ZzdUz26yZsQwfl1M= go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= diff --git a/pkg/planner/cascades/memo/BUILD.bazel b/pkg/planner/cascades/memo/BUILD.bazel index 8fae94b19bb29..62fdf0ec0f6b2 100644 --- a/pkg/planner/cascades/memo/BUILD.bazel +++ b/pkg/planner/cascades/memo/BUILD.bazel @@ -20,6 +20,7 @@ go_library( "//pkg/util/intest", "@com_github_bits_and_blooms_bitset//:bitset", "@com_github_pingcap_failpoint//:failpoint", + "@com_github_zyedidia_generic//hashmap", ], ) @@ -34,7 +35,7 @@ go_test( ], embed = [":memo"], flaky = True, - shard_count = 5, + shard_count = 7, deps = [ "//pkg/expression", "//pkg/planner/cascades/base", @@ -43,6 +44,7 @@ go_test( "//pkg/util/mock", "@com_github_pingcap_failpoint//:failpoint", "@com_github_stretchr_testify//require", + "@com_github_zyedidia_generic//hashmap", "@org_uber_go_goleak//:goleak", ], ) diff --git a/pkg/planner/cascades/memo/group.go b/pkg/planner/cascades/memo/group.go index c9a7f4a3286d1..0640d935be6e7 100644 --- a/pkg/planner/cascades/memo/group.go +++ b/pkg/planner/cascades/memo/group.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/pkg/planner/cascades/util" "github.com/pingcap/tidb/pkg/planner/property" "github.com/pingcap/tidb/pkg/util/intest" + "github.com/zyedidia/generic/hashmap" ) var _ base.HashEquals = &Group{} @@ -41,7 +42,7 @@ type Group struct { Operand2FirstExpr map[pattern.Operand]*list.Element // hash2GroupExpr is used to de-duplication in the list. - hash2GroupExpr map[uint64]*list.Element + hash2GroupExpr *hashmap.Map[*GroupExpression, *list.Element] // logicalProp indicates the logical property. logicalProp *property.LogicalProperty @@ -74,22 +75,13 @@ func (g *Group) Equals(other any) bool { // ******************************************* end of HashEqual methods ******************************************* -// Exists checks whether a Group expression existed in a Group. -func (g *Group) Exists(e *GroupExpression) bool { - one, ok := g.hash2GroupExpr[e.GetHash64()] - if !ok { - return false - } - return one.Value.(*GroupExpression).Equals(e) -} - // Insert adds a GroupExpression to the Group. func (g *Group) Insert(e *GroupExpression) bool { if e == nil { return false } // GroupExpressions hash should be initialized within Init(xxx) method. - if g.Exists(e) { + if _, ok := g.hash2GroupExpr.Get(e); ok { return false } operand := pattern.GetOperand(e.LogicalPlan) @@ -103,7 +95,7 @@ func (g *Group) Insert(e *GroupExpression) bool { newEquiv = g.logicalExpressions.PushBack(e) g.Operand2FirstExpr[operand] = newEquiv } - g.hash2GroupExpr[e.GetHash64()] = newEquiv + g.hash2GroupExpr.Put(e, newEquiv) e.group = g return true } @@ -174,9 +166,17 @@ func (g *Group) ForEachGE(f func(ge *GroupExpression) bool) { func NewGroup(prop *property.LogicalProperty) *Group { g := &Group{ logicalExpressions: list.New(), - hash2GroupExpr: make(map[uint64]*list.Element), Operand2FirstExpr: make(map[pattern.Operand]*list.Element), logicalProp: prop, + hash2GroupExpr: hashmap.New[*GroupExpression, *list.Element]( + 4, + func(a, b *GroupExpression) bool { + return a.Equals(b) + }, + func(t *GroupExpression) uint64 { + return t.GetHash64() + }, + ), } return g } diff --git a/pkg/planner/cascades/memo/group_and_expr_test.go b/pkg/planner/cascades/memo/group_and_expr_test.go index 1f3d6b1d07e1e..ba288b4397b31 100644 --- a/pkg/planner/cascades/memo/group_and_expr_test.go +++ b/pkg/planner/cascades/memo/group_and_expr_test.go @@ -21,8 +21,75 @@ import ( "github.com/pingcap/tidb/pkg/planner/cascades/base" "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" "github.com/stretchr/testify/require" + "github.com/zyedidia/generic/hashmap" ) +func TestRawHashMap(t *testing.T) { + type A struct { + a uint64 + s string + } + hash2GroupExpr := hashmap.New[*A, *A]( + 4, + func(a, b *A) bool { + return a.a == b.a && a.s == b.s + }, + func(t *A) uint64 { + return t.a + }) + a1 := &A{1, "1"} + hash2GroupExpr.Put(a1, a1) + res, ok := hash2GroupExpr.Get(a1) + require.True(t, ok) + require.Equal(t, res.a, uint64(1)) + require.Equal(t, res.s, "1") + + a2 := &A{1, "2"} + hash2GroupExpr.Put(a2, a2) + require.Equal(t, hash2GroupExpr.Size(), 2) + + res, ok = hash2GroupExpr.Get(a2) + require.True(t, ok) + require.Equal(t, res.a, uint64(1)) + require.Equal(t, res.s, "2") +} + +func TestGroupExpressionHashCollision(t *testing.T) { + child1 := &Group{groupID: 1} + child2 := &Group{groupID: 2} + a := &GroupExpression{ + Inputs: []*Group{child1, child2}, + LogicalPlan: &logicalop.LogicalProjection{Exprs: []expression.Expression{expression.NewOne()}}, + } + b := &GroupExpression{ + // root group should change the hash. + Inputs: []*Group{child2, child1}, + LogicalPlan: &logicalop.LogicalProjection{Exprs: []expression.Expression{expression.NewOne()}}, + } + // manually set this two group expression's hash64 to be the same to mock hash collision while equals is not. + a.hash64 = 1 + b.hash64 = 1 + root := NewGroup(nil) + root.groupID = 5 + require.True(t, root.Insert(a)) + require.True(t, root.Insert(b)) + require.Equal(t, root.logicalExpressions.Len(), 2) + + res, ok := root.hash2GroupExpr.Get(a) + require.True(t, ok) + require.Equal(t, res.Value.(*GroupExpression).hash64, uint64(1)) + require.Equal(t, res.Value.(*GroupExpression).group.groupID, GroupID(5)) + require.Equal(t, res.Value.(*GroupExpression).Inputs[0].groupID, GroupID(1)) + require.Equal(t, res.Value.(*GroupExpression).Inputs[1].groupID, GroupID(2)) + + res, ok = root.hash2GroupExpr.Get(b) + require.True(t, ok) + require.Equal(t, res.Value.(*GroupExpression).hash64, uint64(1)) + require.Equal(t, res.Value.(*GroupExpression).group.groupID, GroupID(5)) + require.Equal(t, res.Value.(*GroupExpression).Inputs[0].groupID, GroupID(2)) + require.Equal(t, res.Value.(*GroupExpression).Inputs[1].groupID, GroupID(1)) +} + func TestGroupHashEquals(t *testing.T) { hasher1 := base.NewHashEqualer() hasher2 := base.NewHashEqualer()