Skip to content

Commit

Permalink
cherry pick of 16174 (#5378)
Browse files Browse the repository at this point in the history
  • Loading branch information
planetscale-actions-bot authored Jun 14, 2024
1 parent 8aace46 commit 2fffd03
Show file tree
Hide file tree
Showing 17 changed files with 172 additions and 21 deletions.
14 changes: 14 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,20 @@ func TestColumnAliases(t *testing.T) {
mcmp.ExecWithColumnCompare(`select a as k from (select count(*) as a from t1) t`)
}

func TestHandleNullableColumn(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 21, "vtgate")
require.NoError(t,
utils.WaitForAuthoritative(t, keyspaceName, "tbl", clusterInstance.VtgateProcess.ReadVSchema))
mcmp, closer := start(t)
defer closer()

mcmp.Exec("insert into t1(id1, id2) values (0,0), (1,1), (2,2)")
mcmp.Exec("insert into tbl(id, unq_col, nonunq_col) values (0,0,0), (1,1,6)")
// This query tests that we handle nullable columns correctly
// tbl.nonunq_col is not nullable according to the schema, but because of the left join, it can be NULL
mcmp.ExecWithColumnCompare(`select * from t1 left join tbl on t1.id2 = tbl.id where t1.id1 = 6 or tbl.nonunq_col = 6`)
}

func TestEnumSetVals(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate")

Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/vtgate/queries/misc/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ create table tbl
(
id bigint,
unq_col bigint,
nonunq_col bigint,
nonunq_col bigint not null,
primary key (id),
unique (unq_col)
) Engine = InnoDB;
Expand Down
4 changes: 4 additions & 0 deletions go/vt/vtgate/evalengine/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ func (t *Type) Nullable() bool {
return true // nullable by default for unknown types
}

func (t *Type) SetNullability(n bool) {
t.nullable = n
}

func (t *Type) Values() *EnumSetValues {
return t.values
}
Expand Down
14 changes: 7 additions & 7 deletions go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ func transformAggregator(ctx *plancontext.PlanningContext, op *operators.Aggrega
}

for _, groupBy := range op.Grouping {
typ, _ := ctx.SemTable.TypeForExpr(groupBy.Inner)
typ, _ := ctx.TypeForExpr(groupBy.Inner)
groupByKeys = append(groupByKeys, &engine.GroupByParams{
KeyCol: groupBy.ColOffset,
WeightStringCol: groupBy.WSOffset,
Expand Down Expand Up @@ -372,7 +372,7 @@ func createMemorySort(ctx *plancontext.PlanningContext, src engine.Primitive, or
}

for idx, order := range ordering.Order {
typ, _ := ctx.SemTable.TypeForExpr(order.SimplifiedExpr)
typ, _ := ctx.TypeForExpr(order.SimplifiedExpr)
prim.OrderBy = append(prim.OrderBy, evalengine.OrderByParams{
Col: ordering.Offset[idx],
WeightStringCol: ordering.WOffset[idx],
Expand Down Expand Up @@ -438,7 +438,7 @@ func getEvalEngineExpr(ctx *plancontext.PlanningContext, pe *operators.ProjExpr)
case *operators.EvalEngine:
return e.EExpr, nil
case operators.Offset:
typ, _ := ctx.SemTable.TypeForExpr(pe.EvalExpr)
typ, _ := ctx.TypeForExpr(pe.EvalExpr)
return evalengine.NewColumn(int(e), typ, pe.EvalExpr), nil
default:
return nil, vterrors.VT13001("project not planned for: %s", pe.String())
Expand Down Expand Up @@ -590,7 +590,7 @@ func buildRoutePrimitive(ctx *plancontext.PlanningContext, op *operators.Route,
}

for _, order := range op.Ordering {
typ, _ := ctx.SemTable.TypeForExpr(order.AST)
typ, _ := ctx.TypeForExpr(order.AST)
eroute.OrderBy = append(eroute.OrderBy, evalengine.OrderByParams{
Col: order.Offset,
WeightStringCol: order.WOffset,
Expand Down Expand Up @@ -907,11 +907,11 @@ func transformHashJoin(ctx *plancontext.PlanningContext, op *operators.HashJoin)

var missingTypes []string

ltyp, found := ctx.SemTable.TypeForExpr(op.JoinComparisons[0].LHS)
ltyp, found := ctx.TypeForExpr(op.JoinComparisons[0].LHS)
if !found {
missingTypes = append(missingTypes, sqlparser.String(op.JoinComparisons[0].LHS))
}
rtyp, found := ctx.SemTable.TypeForExpr(op.JoinComparisons[0].RHS)
rtyp, found := ctx.TypeForExpr(op.JoinComparisons[0].RHS)
if !found {
missingTypes = append(missingTypes, sqlparser.String(op.JoinComparisons[0].RHS))
}
Expand Down Expand Up @@ -949,7 +949,7 @@ func transformVindexPlan(ctx *plancontext.PlanningContext, op *operators.Vindex)

expr, err := evalengine.Translate(op.Value, &evalengine.Config{
Collation: ctx.SemTable.Collation,
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Environment: ctx.VSchema.Environment(),
})
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (d *Distinct) planOffsets(ctx *plancontext.PlanningContext) Operator {
offset := d.Source.AddWSColumn(ctx, idx, false)
wsCol = &offset
}
typ, _ := ctx.SemTable.TypeForExpr(e)
typ, _ := ctx.TypeForExpr(e)
d.Columns = append(d.Columns, engine.CheckCol{
Col: idx,
WsCol: wsCol,
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (f *Filter) Compact(*plancontext.PlanningContext) (Operator, *ApplyResult)

func (f *Filter) planOffsets(ctx *plancontext.PlanningContext) Operator {
cfg := &evalengine.Config{
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Collation: ctx.SemTable.Collation,
Environment: ctx.VSchema.Environment(),
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/hash_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ func (hj *HashJoin) addColumn(ctx *plancontext.PlanningContext, in sqlparser.Exp

rewrittenExpr := sqlparser.CopyOnRewrite(in, pre, r.post, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr)
cfg := &evalengine.Config{
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Collation: ctx.SemTable.Collation,
Environment: ctx.VSchema.Environment(),
}
Expand Down Expand Up @@ -458,7 +458,7 @@ func (hj *HashJoin) addSingleSidedColumn(

rewrittenExpr := sqlparser.CopyOnRewrite(in, pre, r.post, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr)
cfg := &evalengine.Config{
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Collation: ctx.SemTable.Collation,
Environment: ctx.VSchema.Environment(),
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ func insertRowsPlan(ctx *plancontext.PlanningContext, insOp *Insert, ins *sqlpar
colNum, _ := findOrAddColumn(ins, col)
for rowNum, row := range rows {
innerpv, err := evalengine.Translate(row[colNum], &evalengine.Config{
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Collation: ctx.SemTable.Collation,
Environment: ctx.VSchema.Environment(),
})
Expand Down Expand Up @@ -637,7 +637,7 @@ func modifyForAutoinc(ctx *plancontext.PlanningContext, ins *sqlparser.Insert, v
}
var err error
gen.Values, err = evalengine.Translate(autoIncValues, &evalengine.Config{
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Collation: ctx.SemTable.Collation,
Environment: ctx.VSchema.Environment(),
})
Expand Down
3 changes: 3 additions & 0 deletions go/vt/vtgate/planbuilder/operators/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ func createLeftOuterJoin(ctx *plancontext.PlanningContext, join *sqlparser.JoinT

joinOp := &Join{LHS: lhs, RHS: rhs, JoinType: join.Join}

// mark the RHS as outer tables so we know which columns are nullable
ctx.OuterTables = ctx.OuterTables.Merge(TableID(rhs))

// for outer joins we have to be careful with the predicates we use
var op Operator
subq, _ := getSubQuery(join.Condition.On)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ func (p *Projection) planOffsets(ctx *plancontext.PlanningContext) Operator {

// for everything else, we'll turn to the evalengine
eexpr, err := evalengine.Translate(rewritten, &evalengine.Config{
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Collation: ctx.SemTable.Collation,
Environment: ctx.VSchema.Environment(),
})
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/queryprojection.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (aggr Aggr) GetTypeCollation(ctx *plancontext.PlanningContext) evalengine.T
}
switch aggr.OpCode {
case opcode.AggregateMin, opcode.AggregateMax, opcode.AggregateSumDistinct, opcode.AggregateCountDistinct:
typ, _ := ctx.SemTable.TypeForExpr(aggr.Func.GetArg())
typ, _ := ctx.TypeForExpr(aggr.Func.GetArg())
return typ

}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/sharded_routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ func (tr *ShardedRouting) planCompositeInOpArg(
Key: right.String(),
Index: idx,
}
if typ, found := ctx.SemTable.TypeForExpr(col); found {
if typ, found := ctx.TypeForExpr(col); found {
value.Type = typ.Type()
value.Collation = typ.Collation()
}
Expand Down Expand Up @@ -687,7 +687,7 @@ func makeEvalEngineExpr(ctx *plancontext.PlanningContext, n sqlparser.Expr) eval
for _, expr := range ctx.SemTable.GetExprAndEqualities(n) {
ee, _ := evalengine.Translate(expr, &evalengine.Config{
Collation: ctx.SemTable.Collation,
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Environment: ctx.VSchema.Environment(),
})
if ee != nil {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/union_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ func createMergedUnion(
continue
}
deps = deps.Merge(ctx.SemTable.RecursiveDeps(rae.Expr))
rt, foundR := ctx.SemTable.TypeForExpr(rae.Expr)
lt, foundL := ctx.SemTable.TypeForExpr(lae.Expr)
rt, foundR := ctx.TypeForExpr(rae.Expr)
lt, foundL := ctx.TypeForExpr(lae.Expr)
if foundR && foundL {
collations := ctx.VSchema.Environment().CollationEnv()
var typer evalengine.TypeAggregator
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,7 @@ func createAssignmentExpressions(
}
found = true
pv, err := evalengine.Translate(assignment.Expr.EvalExpr, &evalengine.Config{
ResolveType: ctx.SemTable.TypeForExpr,
ResolveType: ctx.TypeForExpr,
Collation: ctx.SemTable.Collation,
Environment: ctx.VSchema.Environment(),
})
Expand Down
21 changes: 21 additions & 0 deletions go/vt/vtgate/planbuilder/plancontext/planning_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/evalengine"
"vitess.io/vitess/go/vt/vtgate/semantics"
)

Expand Down Expand Up @@ -57,6 +58,10 @@ type PlanningContext struct {

// Statement contains the originally parsed statement
Statement sqlparser.Statement

// OuterTables contains the tables that are outer to the current query
// Used to set the nullable flag on the columns
OuterTables semantics.TableSet
}

// CreatePlanningContext initializes a new PlanningContext with the given parameters.
Expand Down Expand Up @@ -201,3 +206,19 @@ func (ctx *PlanningContext) RewriteDerivedTableExpression(expr sqlparser.Expr, t
}
return modifiedExpr
}

// TypeForExpr returns the type of the given expression, with nullable set if the expression is from an outer table.
func (ctx *PlanningContext) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) {
t, found := ctx.SemTable.TypeForExpr(e)
if !found {
return t, found
}
deps := ctx.SemTable.RecursiveDeps(e)
// If the expression is from an outer table, it should be nullable
// There are some exceptions to this, where an expression depending on the outer side
// will never return NULL, but it's better to be conservative here.
if deps.IsOverlapping(ctx.OuterTables) {
t.SetNullability(true)
}
return t, true
}
108 changes: 108 additions & 0 deletions go/vt/vtgate/planbuilder/plancontext/planning_context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
Copyright 2024 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package plancontext

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/vtgate/evalengine"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/semantics"
)

func TestOuterTableNullability(t *testing.T) {
// Tests that columns from outer tables are nullable,
// even though the semantic state says that they are not nullable.
// This is because the outer table may not have a matching row.
// All columns are marked as NOT NULL in the schema.
query := "select * from t1 left join t2 on t1.a = t2.a where t1.a+t2.a/abs(t2.boing)"
ctx, columns := prepareContextAndFindColumns(t, query)

// Check if the columns are correctly marked as nullable.
for _, col := range columns {
colName := "column: " + sqlparser.String(col)
t.Run(colName, func(t *testing.T) {
// Extract the column type from the context and the semantic state.
// The context should mark the column as nullable.
ctxType, found := ctx.TypeForExpr(col)
require.True(t, found, colName)
stType, found := ctx.SemTable.TypeForExpr(col)
require.True(t, found, colName)
ctxNullable := ctxType.Nullable()
stNullable := stType.Nullable()

switch col.Qualifier.Name.String() {
case "t1":
assert.False(t, ctxNullable, colName)
assert.False(t, stNullable, colName)
case "t2":
assert.True(t, ctxNullable, colName)

// The semantic state says that the column is not nullable. Don't trust it.
assert.False(t, stNullable, colName)
}
})
}
}

func prepareContextAndFindColumns(t *testing.T, query string) (ctx *PlanningContext, columns []*sqlparser.ColName) {
parser := sqlparser.NewTestParser()
ast, err := parser.Parse(query)
require.NoError(t, err)
semTable := semantics.EmptySemTable()
t1 := semTable.NewTableId()
t2 := semTable.NewTableId()
stmt := ast.(*sqlparser.Select)
expr := stmt.Where.Expr

// Instead of using the semantic analysis, we manually set the types for the columns.
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
col, ok := node.(*sqlparser.ColName)
if !ok {
return true, nil
}

switch col.Qualifier.Name.String() {
case "t1":
semTable.Recursive[col] = t1
case "t2":
semTable.Recursive[col] = t2
}

intNotNull := evalengine.NewType(sqltypes.Int64, collations.Unknown)
intNotNull.SetNullability(false)
semTable.ExprTypes[col] = intNotNull
columns = append(columns, col)
return false, nil
}, nil, expr)

ctx = &PlanningContext{
SemTable: semTable,
joinPredicates: map[sqlparser.Expr][]sqlparser.Expr{},
skipPredicates: map[sqlparser.Expr]any{},
ReservedArguments: map[sqlparser.Expr]string{},
Statement: stmt,
OuterTables: t2, // t2 is the outer table.
}
return
}
1 change: 1 addition & 0 deletions go/vt/vtgate/semantics/semantic_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,7 @@ func (st *SemTable) AddExprs(tbl *sqlparser.AliasedTableExpr, cols sqlparser.Sel
}

// TypeForExpr returns the type of expressions in the query
// Note that PlanningContext has the same method, and you should use that if you have a PlanningContext
func (st *SemTable) TypeForExpr(e sqlparser.Expr) (evalengine.Type, bool) {
if typ, found := st.ExprTypes[e]; found {
return typ, true
Expand Down

0 comments on commit 2fffd03

Please sign in to comment.