diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 1d7ab540c12..79d9709c479 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -718,6 +718,15 @@ func NewComparisonExpr(operator ComparisonExprOperator, left, right, escape Expr } } +// NewCaseExpr makes a new CaseExpr +func NewCaseExpr(expr Expr, whens []*When, elseExpr Expr) *CaseExpr { + return &CaseExpr{ + Expr: expr, + Whens: whens, + Else: elseExpr, + } +} + // NewLimit makes a new Limit func NewLimit(offset, rowCount int) *Limit { return &Limit{ diff --git a/go/vt/vtgate/simplifier/expression_simplifier.go b/go/vt/vtgate/simplifier/expression_simplifier.go index 279cb1ac7dd..194e5422c04 100644 --- a/go/vt/vtgate/simplifier/expression_simplifier.go +++ b/go/vt/vtgate/simplifier/expression_simplifier.go @@ -142,6 +142,10 @@ func (s *shrinker) Next() sqlparser.Expr { func (s *shrinker) fillQueue() bool { before := len(s.queue) switch e := s.orig.(type) { + case *sqlparser.AndExpr: + s.queue = append(s.queue, e.Left, e.Right) + case *sqlparser.OrExpr: + s.queue = append(s.queue, e.Left, e.Right) case *sqlparser.ComparisonExpr: s.queue = append(s.queue, e.Left, e.Right) case *sqlparser.BinaryExpr: @@ -231,6 +235,28 @@ func (s *shrinker) fillQueue() bool { case *sqlparser.ColName: // we can try to replace the column with a literal value s.queue = []sqlparser.Expr{sqlparser.NewIntLiteral("0")} + case *sqlparser.CaseExpr: + s.queue = append(s.queue, e.Expr, e.Else) + for _, when := range e.Whens { + s.queue = append(s.queue, when.Cond, when.Val) + } + + if len(e.Whens) > 1 { + for i := range e.Whens { + whensCopy := sqlparser.CloneSliceOfRefOfWhen(e.Whens) + // replace ith element with last element, then truncate last element + whensCopy[i] = whensCopy[len(whensCopy)-1] + whensCopy = whensCopy[:len(whensCopy)-1] + s.queue = append(s.queue, sqlparser.NewCaseExpr(e.Expr, whensCopy, e.Else)) + } + } + + if e.Else != nil { + s.queue = append(s.queue, sqlparser.NewCaseExpr(e.Expr, e.Whens, nil)) + } + if e.Expr != nil { + s.queue = append(s.queue, sqlparser.NewCaseExpr(nil, e.Whens, e.Else)) + } default: return false } diff --git a/go/vt/vtgate/simplifier/simplifier.go b/go/vt/vtgate/simplifier/simplifier.go index ef7be4e30e5..19d8e92d56b 100644 --- a/go/vt/vtgate/simplifier/simplifier.go +++ b/go/vt/vtgate/simplifier/simplifier.go @@ -70,27 +70,28 @@ func trySimplifyExpressions(in sqlparser.SelectStatement, test func(sqlparser.Se if test(in) { log.Errorf("removed expression: %s", sqlparser.String(cursor.expr)) simplified = true - return false + // initially return false, but that made the rewriter prematurely abort sometimes + return true } cursor.restore() } // ok, we seem to need this expression. let's see if we can find a simpler version - s := &shrinker{orig: cursor.expr} - newExpr := s.Next() - for newExpr != nil { - cursor.replace(newExpr) + newExpr := SimplifyExpr(cursor.expr, func(expr sqlparser.Expr) bool { + cursor.replace(expr) if test(in) { - log.Errorf("simplified expression: %s -> %s", sqlparser.String(cursor.expr), sqlparser.String(newExpr)) + log.Errorf("simplified expression: %s -> %s", sqlparser.String(cursor.expr), sqlparser.String(expr)) + cursor.restore() simplified = true - return false + return true } - newExpr = s.Next() - } - // if we get here, we failed to simplify this expression, - // so we put back in the original expression - cursor.restore() + cursor.restore() + return false + }) + + cursor.replace(newExpr) + // initially return false, but that made the rewriter prematurely abort sometimes return true })