Skip to content

Commit

Permalink
change next to return holder directly
Browse files Browse the repository at this point in the history
Signed-off-by: arenatlx <314806019@qq.com>
  • Loading branch information
AilinKid committed Dec 2, 2024
1 parent b736bee commit 4d96d1b
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 67 deletions.
148 changes: 83 additions & 65 deletions pkg/planner/cascades/rule/binder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ func TestBinderSuccess(t *testing.T) {
// bind the pattern to the memo.
rootGE := mm.GetRootGroup().GetLogicalExpressions().Back().Value.(*memo.GroupExpression)
binder := NewBinder(pa, rootGE)
require.True(t, binder.Next())
require.True(t, binder.holder.(*memo.GroupExpression).LogicalPlan == join)
require.True(t, binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan == t1)
require.True(t, binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan == t2)
holder := binder.Next()
require.NotNil(t, holder)
require.True(t, holder.(*memo.GroupExpression).LogicalPlan == join)
require.True(t, holder.Children()[0].(*memo.GroupExpression).LogicalPlan == t1)
require.True(t, holder.Children()[1].(*memo.GroupExpression).LogicalPlan == t2)
}

func TestBinderFail(t *testing.T) {
Expand All @@ -82,7 +83,8 @@ func TestBinderFail(t *testing.T) {
b := bytes.Buffer{}
buf := util.NewStrBuffer(&b)
binder.bsw = buf
require.False(t, binder.Next())
holder := binder.Next()
require.Nil(t, holder)
buf.Flush()
require.Equal(t, b.String(), "GE:DataSource_1{}\n")

Expand All @@ -99,7 +101,8 @@ func TestBinderFail(t *testing.T) {
b.Reset()
buf = util.NewStrBuffer(&b)
binder.bsw = buf
require.False(t, binder.Next())
holder = binder.Next()
require.Nil(t, holder)
buf.Flush()
require.Equal(t, b.String(), "")

Expand All @@ -111,7 +114,8 @@ func TestBinderFail(t *testing.T) {
b.Reset()
buf = util.NewStrBuffer(&b)
binder.bsw = buf
require.False(t, binder.Next())
holder = binder.Next()
require.Nil(t, holder)
buf.Flush()
require.Equal(t, b.String(), "GE:Limit_4{inputs:1}\n")
}
Expand All @@ -131,7 +135,8 @@ func TestBinderTopNode(t *testing.T) {
// single level pattern, no children.
pa := pattern.NewPattern(pattern.OperandJoin, pattern.EngineAll)
binder := NewBinder(pa, mm.GetRootGroup().GetLogicalExpressions().Back().Value.(*memo.GroupExpression))
require.True(t, binder.Next())
holder := binder.Next()
require.NotNil(t, holder)
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(binder.holder.(*memo.GroupExpression).LogicalPlan))
}

Expand All @@ -146,8 +151,9 @@ func TestBinderOneNode(t *testing.T) {

pa := pattern.NewPattern(pattern.OperandJoin, pattern.EngineAll)
binder := NewBinder(pa, mm.GetRootGroup().GetLogicalExpressions().Back().Value.(*memo.GroupExpression))
require.True(t, binder.Next())
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(binder.holder.(*memo.GroupExpression).LogicalPlan))
holder := binder.Next()
require.NotNil(t, holder)
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(holder.(*memo.GroupExpression).LogicalPlan))
}

func TestBinderSubTreeMatch(t *testing.T) {
Expand Down Expand Up @@ -176,18 +182,21 @@ func TestBinderSubTreeMatch(t *testing.T) {
// bind the pattern to the memo.
rootGE := mm.GetRootGroup().GetLogicalExpressions().Back().Value.(*memo.GroupExpression)
binder := NewBinder(pa, rootGE)
require.True(t, binder.Next())
require.True(t, binder.holder.(*memo.GroupExpression).LogicalPlan == join3)
require.True(t, binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan == join1)
require.True(t, binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan == join2)
require.False(t, binder.Next())
holder := binder.Next()
require.NotNil(t, holder)
require.True(t, holder.(*memo.GroupExpression).LogicalPlan == join3)
require.True(t, holder.Children()[0].(*memo.GroupExpression).LogicalPlan == join1)
require.True(t, holder.Children()[1].(*memo.GroupExpression).LogicalPlan == join2)
holder = binder.Next()
require.Nil(t, holder)

pa2 := pattern.NewPattern(pattern.OperandJoin, pattern.EngineAll)
pa2.SetChildren(pattern.NewPattern(pattern.OperandDataSource, pattern.EngineAll), pattern.NewPattern(pattern.OperandDataSource, pattern.EngineAll))
binder = NewBinder(pa2, rootGE)
// we couldn't bind the pattern to the subtree of join3, because the root group expression is pinned.
// the top-down iteration across all the tree nodes is the responsibility of the caller.
require.False(t, binder.Next())
holder = binder.Next()
require.Nil(t, holder)
}

func TestBinderMultiNext(t *testing.T) {
Expand Down Expand Up @@ -222,49 +231,53 @@ func TestBinderMultiNext(t *testing.T) {
buf := util.NewStrBuffer(&b)
binder.bsw = buf

require.True(t, binder.Next())
holder := binder.Next()
require.NotNil(t, holder)
// G1
// / \
// G2{t1,t3} G3{t2,t4}
// ▴ ▴
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(binder.holder.(*memo.GroupExpression).LogicalPlan))
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t1", binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t2", binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)

require.True(t, binder.Next())
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(holder.(*memo.GroupExpression).LogicalPlan))
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(holder.Children()[0].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t1", holder.Children()[0].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(holder.Children()[1].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t2", holder.Children()[1].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)

holder = binder.Next()
require.NotNil(t, holder)
// G1
// / \
// G2{t1,t3} G3{t2,t4}
// ▴ ▴
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(binder.holder.(*memo.GroupExpression).LogicalPlan))
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t1", binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t4", binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)

require.True(t, binder.Next())
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(holder.(*memo.GroupExpression).LogicalPlan))
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(holder.Children()[0].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t1", holder.Children()[0].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(holder.Children()[1].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t4", holder.Children()[1].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)

holder = binder.Next()
require.NotNil(t, holder)
// G1
// / \
// G2{t1,t3} G3{t2,t4}
// ▴ ▴
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(binder.holder.(*memo.GroupExpression).LogicalPlan))
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t3", binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t2", binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)

require.True(t, binder.Next())
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(holder.(*memo.GroupExpression).LogicalPlan))
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(holder.Children()[0].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t3", holder.Children()[0].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(holder.Children()[1].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t2", holder.Children()[1].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)

holder = binder.Next()
require.NotNil(t, holder)
// G1
// / \
// G2{t1,t3} G3{t2,t4}
// ▴ ▴
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(binder.holder.(*memo.GroupExpression).LogicalPlan))
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t3", binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t4", binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(holder.(*memo.GroupExpression).LogicalPlan))
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(holder.Children()[0].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t3", holder.Children()[0].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(holder.Children()[1].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t4", holder.Children()[1].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)

buf.Flush()
// every time when next call done, the save stack info should be next iteration starting point.
Expand Down Expand Up @@ -315,29 +328,32 @@ func TestBinderAny(t *testing.T) {
buf := util.NewStrBuffer(&b)
binder.bsw = buf

require.True(t, binder.Next())
holder := binder.Next()
require.NotNil(t, holder)
// G1
// / \
// G2{t1,t3} G3{t2,t4}
// ▴ ▴
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(binder.holder.(*memo.GroupExpression).LogicalPlan))
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t1", binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t2", binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)

require.True(t, binder.Next())
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(holder.(*memo.GroupExpression).LogicalPlan))
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(holder.Children()[0].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t1", holder.Children()[0].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(holder.Children()[1].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t2", holder.Children()[1].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)

holder = binder.Next()
require.NotNil(t, holder)
// G1
// / \
// G2{t1,t3} G3{t2,t4}
// ▴ ▴
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(binder.holder.(*memo.GroupExpression).LogicalPlan))
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t3", binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t2", binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(holder.(*memo.GroupExpression).LogicalPlan))
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(holder.Children()[0].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t3", holder.Children()[0].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(holder.Children()[1].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t2", holder.Children()[1].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)

require.False(t, binder.Next())
holder = binder.Next()
require.Nil(t, holder)

buf.Flush()
// every time when next call done, the save stack info should be next iteration starting point.
Expand Down Expand Up @@ -396,18 +412,20 @@ func TestBinderMultiAny(t *testing.T) {
buf := util.NewStrBuffer(&b)
binder.bsw = buf

require.True(t, binder.Next())
holder := binder.Next()
require.NotNil(t, holder)
// G1
// / \
// G2{t1,t3} G3{t2,t4}
// ▴ ▴
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(binder.holder.(*memo.GroupExpression).LogicalPlan))
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t1", binder.holder.Children()[0].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t2", binder.holder.Children()[1].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)

require.False(t, binder.Next())
require.Equal(t, pattern.OperandJoin, pattern.GetOperand(holder.(*memo.GroupExpression).LogicalPlan))
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(holder.Children()[0].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t1", holder.Children()[0].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)
require.Equal(t, pattern.OperandDataSource, pattern.GetOperand(holder.Children()[1].(*memo.GroupExpression).LogicalPlan))
require.Equal(t, "t2", holder.Children()[1].(*memo.GroupExpression).LogicalPlan.(*logicalop.DataSource).TableAsName.L)

holder = binder.Next()
require.Nil(t, holder)

buf.Flush()
// state1:
Expand Down
4 changes: 2 additions & 2 deletions pkg/planner/cascades/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type Rule interface {
ID() uint

// String implements the fmt.Stringer interface, used for rule tracing process.
String(writer util.IBufStrWriter)
String(writer util.StrBufferWriter)

// Pattern return the initialized pattern of a specific rule when it created.
Pattern() *pattern.Pattern
Expand Down Expand Up @@ -67,7 +67,7 @@ func (*BaseRule) ID() uint {
}

// String implements Rule interface
func (r *BaseRule) String(writer util.IBufStrWriter) {
func (r *BaseRule) String(writer util.StrBufferWriter) {
writer.WriteString(r.tp.String())
}

Expand Down

0 comments on commit 4d96d1b

Please sign in to comment.