Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: Add column nullability checking before "refine args" #20044

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmd/explaintest/r/explain_easy.result
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ Projection_10 5.00 root Column#11
└─Selection_38 2.40 cop[tikv] eq(3, test.t.a)
└─IndexRangeScan_37 3.00 cop[tikv] table:t1, index:idx(b) range:[3,3], keep order:true
drop table if exists t;
create table t(a int unsigned);
create table t(a int unsigned not null);
explain select t.a = '123455' from t;
id estRows task access object operator info
Projection_3 10000.00 root eq(test.t.a, 123455)->Column#3
Expand Down Expand Up @@ -508,7 +508,7 @@ StreamAgg_22 1.00 root funcs:count(1)->Column#22
│ │ └─TableDual_39 0.00 root rows:0
│ └─Projection_40 0.01 root test.test01.stat_date, test.test01.show_date, test.test01.region_id
│ └─TableReader_43 0.01 root data:Selection_42
│ └─Selection_42 0.01 cop[tikv] eq(test.test01.period, 1), ge(test.test01.stat_date, 20191202), ge(test.test01.stat_date, 20191202), gt(cast(test.test01.registration_num), 0), le(test.test01.stat_date, 20191202), le(test.test01.stat_date, 20191202)
│ └─Selection_42 0.01 cop[tikv] eq(test.test01.period, 1), ge(cast(test.test01.stat_date), 2.0191202e+07), ge(test.test01.stat_date, 20191202), gt(cast(test.test01.registration_num), 0), le(cast(test.test01.stat_date), 2.0191202e+07), le(test.test01.stat_date, 20191202)
│ └─TableFullScan_41 10000.00 cop[tikv] table:test01 keep order:false, stats:pseudo
└─TableReader_29(Probe) 1.00 root data:TableRangeScan_28
└─TableRangeScan_28 1.00 cop[tikv] table:b range: decided by [Column#16], keep order:true, stats:pseudo
Expand Down
24 changes: 12 additions & 12 deletions cmd/explaintest/r/partition_pruning.result
Original file line number Diff line number Diff line change
Expand Up @@ -2261,33 +2261,33 @@ TableReader_7 10.00 root partition:p0 data:Selection_6
└─TableFullScan_5 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
explain select * from t1 where a=0xFE;
id estRows task access object operator info
TableReader_7 10.00 root partition:p2 data:Selection_6
└─Selection_6 10.00 cop[tikv] eq(test.t1.a, 254)
TableReader_7 8000.00 root partition:all data:Selection_6
└─Selection_6 8000.00 cop[tikv] eq(cast(test.t1.a), 254)
└─TableFullScan_5 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
explain select * from t2 where a=0xFE;
id estRows task access object operator info
TableReader_7 10.00 root partition:p2 data:Selection_6
└─Selection_6 10.00 cop[tikv] eq(test.t2.a, 254)
TableReader_7 8000.00 root partition:all data:Selection_6
└─Selection_6 8000.00 cop[tikv] eq(cast(test.t2.a), 254)
└─TableFullScan_5 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
explain select * from t1 where a > 0xFE AND a <= 0xFF;
id estRows task access object operator info
TableReader_7 250.00 root partition:dual data:Selection_6
└─Selection_6 250.00 cop[tikv] gt(test.t1.a, 254), le(test.t1.a, 255)
TableReader_7 8000.00 root partition:all data:Selection_6
└─Selection_6 8000.00 cop[tikv] gt(cast(test.t1.a), 254), le(cast(test.t1.a), 255)
└─TableFullScan_5 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
explain select * from t2 where a > 0xFE AND a <= 0xFF;
id estRows task access object operator info
TableReader_7 250.00 root partition:all data:Selection_6
└─Selection_6 250.00 cop[tikv] gt(test.t2.a, 254), le(test.t2.a, 255)
TableReader_7 8000.00 root partition:all data:Selection_6
└─Selection_6 8000.00 cop[tikv] gt(cast(test.t2.a), 254), le(cast(test.t2.a), 255)
└─TableFullScan_5 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
explain select * from t1 where a >= 0xFE AND a <= 0xFF;
id estRows task access object operator info
TableReader_7 250.00 root partition:p2 data:Selection_6
└─Selection_6 250.00 cop[tikv] ge(test.t1.a, 254), le(test.t1.a, 255)
TableReader_7 8000.00 root partition:all data:Selection_6
└─Selection_6 8000.00 cop[tikv] ge(cast(test.t1.a), 254), le(cast(test.t1.a), 255)
└─TableFullScan_5 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo
explain select * from t2 where a >= 0xFE AND a <= 0xFF;
id estRows task access object operator info
TableReader_7 250.00 root partition:all data:Selection_6
└─Selection_6 250.00 cop[tikv] ge(test.t2.a, 254), le(test.t2.a, 255)
TableReader_7 8000.00 root partition:all data:Selection_6
└─Selection_6 8000.00 cop[tikv] ge(cast(test.t2.a), 254), le(cast(test.t2.a), 255)
└─TableFullScan_5 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo
explain select * from t1 where a < 64 AND a >= 63;
id estRows task access object operator info
Expand Down
2 changes: 1 addition & 1 deletion cmd/explaintest/t/explain_easy.test
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ explain select t.c in (select count(*) from t s left join t t1 on s.a = t1.a whe
explain select t.c in (select count(*) from t s right join t t1 on s.a = t1.a where 3 = t.a and t1.b = 3) from t;

drop table if exists t;
create table t(a int unsigned);
create table t(a int unsigned not null);
explain select t.a = '123455' from t;
explain select t.a > '123455' from t;
explain select t.a != '123455' from t;
Expand Down
48 changes: 27 additions & 21 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -1296,32 +1296,38 @@ func (c *compareFunctionClass) refineArgs(ctx sessionctx.Context, args []Express
isPositiveInfinite, isNegativeInfinite := false, false
// int non-constant [cmp] non-int constant
if arg0IsInt && !arg0IsCon && !arg1IsInt && arg1IsCon {
arg1, isExceptional = RefineComparedConstant(ctx, *arg0Type, arg1, c.op)
finalArg1 = arg1
if isExceptional && arg1.GetType().EvalType() == types.ETInt {
// Judge it is inf or -inf
// For int:
// inf: 01111111 & 1 == 1
// -inf: 10000000 & 1 == 0
// For uint:
// inf: 11111111 & 1 == 1
// -inf: 00000000 & 0 == 0
if arg1.Value.GetInt64()&1 == 1 {
isPositiveInfinite = true
} else {
isNegativeInfinite = true
arg0IsNotNull := mysql.HasNotNullFlag(args[0].GetType().Flag)
if arg0IsNotNull {
arg1, isExceptional = RefineComparedConstant(ctx, *arg0Type, arg1, c.op)
finalArg1 = arg1
if isExceptional && arg1.GetType().EvalType() == types.ETInt {
// Judge it is inf or -inf
// For int:
// inf: 01111111 & 1 == 1
// -inf: 10000000 & 1 == 0
// For uint:
// inf: 11111111 & 1 == 1
// -inf: 00000000 & 0 == 0
if arg1.Value.GetInt64()&1 == 1 {
isPositiveInfinite = true
} else {
isNegativeInfinite = true
}
}
}
}
// non-int constant [cmp] int non-constant
if arg1IsInt && !arg1IsCon && !arg0IsInt && arg0IsCon {
arg0, isExceptional = RefineComparedConstant(ctx, *arg1Type, arg0, symmetricOp[c.op])
finalArg0 = arg0
if isExceptional && arg0.GetType().EvalType() == types.ETInt {
if arg0.Value.GetInt64()&1 == 1 {
isNegativeInfinite = true
} else {
isPositiveInfinite = true
arg1IsNotNull := mysql.HasNotNullFlag(args[1].GetType().Flag)
if arg1IsNotNull {
arg0, isExceptional = RefineComparedConstant(ctx, *arg1Type, arg0, symmetricOp[c.op])
finalArg0 = arg0
if isExceptional && arg0.GetType().EvalType() == types.ETInt {
if arg0.Value.GetInt64()&1 == 1 {
isNegativeInfinite = true
} else {
isPositiveInfinite = true
}
}
}
}
Expand Down
54 changes: 53 additions & 1 deletion expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
)

func (s *testEvaluatorSuite) TestCompareFunctionWithRefine(c *C) {
tblInfo := newTestTableBuilder("").add("a", mysql.TypeLong).build()
tblInfo := newTestTableBuilder("").add("a", mysql.TypeLong, mysql.NotNullFlag).build()
tests := []struct {
exprStr string
result string
Expand Down Expand Up @@ -78,6 +78,58 @@ func (s *testEvaluatorSuite) TestCompareFunctionWithRefine(c *C) {
}
}

func (s *testEvaluatorSuite) TestCompareFunctionWithoutRefine(c *C) {
tblInfo := newTestTableBuilder("").add("a", mysql.TypeLong, 0).build()
tests := []struct {
exprStr string
result string
}{
{"a < '1.0'", "lt(cast(a, double BINARY), 1)"},
{"a <= '1.0'", "le(cast(a, double BINARY), 1)"},
{"a > '1'", "gt(cast(a, double BINARY), 1)"},
{"a >= '1'", "ge(cast(a, double BINARY), 1)"},
{"a = '1'", "eq(cast(a, double BINARY), 1)"},
{"a <=> '1'", "nulleq(cast(a, double BINARY), 1)"},
{"a != '1'", "ne(cast(a, double BINARY), 1)"},
{"a < '1.1'", "lt(cast(a, double BINARY), 1.1)"},
{"a <= '1.1'", "le(cast(a, double BINARY), 1.1)"},
{"a > 1.1", "gt(cast(a, decimal(20,0) BINARY), 1.1)"},
{"a >= '1.1'", "ge(cast(a, double BINARY), 1.1)"},
{"a = '1.1'", "eq(cast(a, double BINARY), 1.1)"},
{"a <=> '1.1'", "nulleq(cast(a, double BINARY), 1.1)"},
{"a != '1.1'", "ne(cast(a, double BINARY), 1.1)"},
{"'1' < a", "lt(1, cast(a, double BINARY))"},
{"'1' <= a", "le(1, cast(a, double BINARY))"},
{"'1' > a", "gt(1, cast(a, double BINARY))"},
{"'1' >= a", "ge(1, cast(a, double BINARY))"},
{"'1' = a", "eq(1, cast(a, double BINARY))"},
{"'1' <=> a", "nulleq(1, cast(a, double BINARY))"},
{"'1' != a", "ne(1, cast(a, double BINARY))"},
{"'1.1' < a", "lt(1.1, cast(a, double BINARY))"},
{"'1.1' <= a", "le(1.1, cast(a, double BINARY))"},
{"'1.1' > a", "gt(1.1, cast(a, double BINARY))"},
{"'1.1' >= a", "ge(1.1, cast(a, double BINARY))"},
{"'1.1' = a", "eq(1.1, cast(a, double BINARY))"},
{"'1.1' <=> a", "nulleq(1.1, cast(a, double BINARY))"},
{"'1.1' != a", "ne(1.1, cast(a, double BINARY))"},
{"'123456789123456711111189' = a", "eq(1.234567891234567e+23, cast(a, double BINARY))"},
{"123456789123456789.12345 = a", "eq(123456789123456789.12345, cast(a, decimal(20,0) BINARY))"},
{"123456789123456789123456789.12345 > a", "gt(123456789123456789123456789.12345, cast(a, decimal(20,0) BINARY))"},
{"-123456789123456789123456789.12345 > a", "gt(-123456789123456789123456789.12345, cast(a, decimal(20,0) BINARY))"},
{"123456789123456789123456789.12345 < a", "lt(123456789123456789123456789.12345, cast(a, decimal(20,0) BINARY))"},
{"-123456789123456789123456789.12345 < a", "lt(-123456789123456789123456789.12345, cast(a, decimal(20,0) BINARY))"},
{"'aaaa'=a", "eq(0, cast(a, double BINARY))"},
}
cols, names, err := ColumnInfos2ColumnsAndNames(s.ctx, model.NewCIStr(""), tblInfo.Name, tblInfo.Cols(), tblInfo)
c.Assert(err, IsNil)
schema := NewSchema(cols...)
for _, t := range tests {
f, err := ParseSimpleExprsWithNames(s.ctx, t.exprStr, schema, names)
c.Assert(err, IsNil)
c.Assert(f[0].String(), Equals, t.result)
}
}

func (s *testEvaluatorSuite) TestCompare(c *C) {
intVal, uintVal, realVal, stringVal, decimalVal := 1, uint64(1), 1.1, "123", types.NewDecFromFloatForTest(123.123)
timeVal := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeDatetime, 6)
Expand Down
7 changes: 5 additions & 2 deletions expression/expression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (s *testEvaluatorSuite) TestNewValuesFunc(c *C) {
}

func (s *testEvaluatorSuite) TestEvaluateExprWithNull(c *C) {
tblInfo := newTestTableBuilder("").add("col0", mysql.TypeLonglong).add("col1", mysql.TypeLonglong).build()
tblInfo := newTestTableBuilder("").add("col0", mysql.TypeLonglong, 0).add("col1", mysql.TypeLonglong, 0).build()
schema := tableInfoToSchemaForTest(tblInfo)
col0 := schema.Columns[0]
col1 := schema.Columns[1]
Expand Down Expand Up @@ -142,15 +142,17 @@ type testTableBuilder struct {
tableName string
columnNames []string
tps []byte
flags []uint
}

func newTestTableBuilder(tableName string) *testTableBuilder {
return &testTableBuilder{tableName: tableName}
}

func (builder *testTableBuilder) add(name string, tp byte) *testTableBuilder {
func (builder *testTableBuilder) add(name string, tp byte, flag uint) *testTableBuilder {
builder.columnNames = append(builder.columnNames, name)
builder.tps = append(builder.tps, tp)
builder.flags = append(builder.flags, flag)
return builder
}

Expand All @@ -165,6 +167,7 @@ func (builder *testTableBuilder) build() *model.TableInfo {
fieldType := types.NewFieldType(tp)
fieldType.Flen, fieldType.Decimal = mysql.GetDefaultFieldLengthAndDecimal(tp)
fieldType.Charset, fieldType.Collate = types.DefaultCharsetForType(tp)
fieldType.Flag = builder.flags[i]
ti.Columns = append(ti.Columns, &model.ColumnInfo{
ID: int64(i + 1),
Name: model.NewCIStr(colName),
Expand Down
2 changes: 1 addition & 1 deletion planner/core/cbo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ func (s *testAnalyzeSuite) TestPreparedNullParam(c *C) {
testKit.MustExec("insert into t values (1), (2), (3)")

sql := "select * from t where id = ?"
best := "Dual"
best := "TableReader(Table(t)->Sel([eq(cast(test.t.id, double BINARY), <nil>)]))"

ctx := testKit.Se.(sessionctx.Context)
stmts, err := session.Parse(ctx, sql)
Expand Down
27 changes: 27 additions & 0 deletions planner/core/expression_rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,30 @@ func (s *testExpressionRewriterSuite) TestPatternLikeToExpression(c *C) {
tk.MustQuery("select 0 like '0';").Check(testkit.Rows("1"))
tk.MustQuery("select 0.00 like '0.00';").Check(testkit.Rows("1"))
}

// TestIssue16788 contains tests for https://github.com/pingcap/tidb/issues/16788.
func (s *testExpressionRewriterSuite) TestIssue16788(c *C) {
defer testleak.AfterTest(c)()
store, dom, err := newStoreWithBootstrap()
c.Assert(err, IsNil)
tk := testkit.NewTestKit(c, store)
defer func() {
dom.Close()
store.Close()
}()
tk.MustExec("use test;")
tk.MustExec("drop table if exists t;")
tk.MustExec("create table t(c int);")
tk.MustExec(`insert into t values(1), (NULL);`)
tk.MustQuery("select c, c = 0.5 from t;").Check(testkit.Rows(
"1 0",
"<nil> <nil>",
))
tk.MustQuery("select c, c = '0.5' from t;").Check(testkit.Rows(
"1 0",
"<nil> <nil>",
))
tk.MustQuery("select * from t where c is null;").Check(testkit.Rows(
"<nil>",
))
}
2 changes: 1 addition & 1 deletion planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ func (s *testIntegrationSuite) TestPartitionPruningForInExpr(c *C) {

tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int(11), b int) partition by range (a) (partition p0 values less than (4), partition p1 values less than(10), partition p2 values less than maxvalue);")
tk.MustExec("create table t(a int(11) not null, b int not null) partition by range (a) (partition p0 values less than (4), partition p1 values less than(10), partition p2 values less than maxvalue);")
tk.MustExec("insert into t values (1, 1),(10, 10),(11, 11)")

var input []string
Expand Down