Skip to content

Commit

Permalink
expression: pass const bool to all calls for expression.ConstItem (#4…
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao authored Dec 25, 2023
1 parent b719406 commit 07e9ece
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 23 deletions.
2 changes: 1 addition & 1 deletion pkg/expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (a *baseFuncDesc) typeInfer4ApproxPercentile(ctx sessionctx.Context) error
return errors.New("APPROX_PERCENTILE should take 2 arguments")
}

if !a.Args[1].ConstItem(ctx.GetSessionVars().StmtCtx.UseCache) {
if !a.Args[1].ConstItem(false) {
return errors.New("APPROX_PERCENTILE should take a constant expression as percentage argument")
}
percent, isNull, err := a.Args[1].EvalInt(ctx, chunk.Row{})
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2233,7 +2233,7 @@ func WrapWithCastAsDecimal(ctx sessionctx.Context, expr Expression) Expression {
tp.AddFlag(expr.GetType().GetFlag() & (mysql.UnsignedFlag | mysql.NotNullFlag))
castExpr := BuildCastFunction(ctx, expr, tp)
// For const item, we can use find-grained precision and scale by the result.
if castExpr.ConstItem(ctx.GetSessionVars().StmtCtx.UseCache) {
if castExpr.ConstItem(true) {
val, isnull, err := castExpr.EvalDecimal(ctx, chunk.Row{})
if !isnull && err == nil {
precision, frac := val.PrecisionAndFrac()
Expand Down
8 changes: 4 additions & 4 deletions pkg/expression/builtin_encryption_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (b *builtinAesDecryptSig) vecEvalString(ctx EvalContext, input *chunk.Chunk
}

isWarning := !b.ivRequired && len(b.args) == 3
isConstKey := b.args[1].ConstItem(ctx.GetSessionVars().StmtCtx.UseCache)
isConstKey := b.args[1].ConstItem(false)

var key []byte
if isConstKey {
Expand Down Expand Up @@ -158,7 +158,7 @@ func (b *builtinAesEncryptIVSig) vecEvalString(ctx EvalContext, input *chunk.Chu
return errors.Errorf("unsupported block encryption mode - %v", b.modeName)
}

isConst := b.args[1].ConstItem(ctx.GetSessionVars().StmtCtx.UseCache)
isConst := b.args[1].ConstItem(false)
var key []byte
if isConst {
key = encrypt.DeriveKeyMySQL(keyBuf.GetBytes(0), b.keySize)
Expand Down Expand Up @@ -331,7 +331,7 @@ func (b *builtinAesDecryptIVSig) vecEvalString(ctx EvalContext, input *chunk.Chu
return errors.Errorf("unsupported block encryption mode - %v", b.modeName)
}

isConst := b.args[1].ConstItem(ctx.GetSessionVars().StmtCtx.UseCache)
isConst := b.args[1].ConstItem(false)
var key []byte
if isConst {
key = encrypt.DeriveKeyMySQL(keyBuf.GetBytes(0), b.keySize)
Expand Down Expand Up @@ -672,7 +672,7 @@ func (b *builtinAesEncryptSig) vecEvalString(ctx EvalContext, input *chunk.Chunk
}

isWarning := !b.ivRequired && len(b.args) == 3
isConst := b.args[1].ConstItem(ctx.GetSessionVars().StmtCtx.UseCache)
isConst := b.args[1].ConstItem(false)
var key []byte
if isConst {
key = encrypt.DeriveKeyMySQL(keyBuf.GetBytes(0), b.keySize)
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/builtin_ilike_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (b *builtinIlikeSig) getEscape(ctx EvalContext, input *chunk.Chunk, result
rowNum := input.NumRows()
escape := int64('\\')

if !b.args[2].ConstItem(ctx.GetSessionVars().StmtCtx.UseCache) {
if !b.args[2].ConstItem(true) {
return escape, true, errors.Errorf("escape should be const")
}

Expand Down
12 changes: 6 additions & 6 deletions pkg/expression/builtin_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (b *builtinInIntSig) buildHashMapForConstArgs(ctx sessionctx.Context) error
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = make(map[int64]bool, len(b.args)-1)
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(ctx.GetSessionVars().StmtCtx.UseCache) {
if b.args[i].ConstItem(true) {
val, isNull, err := b.args[i].EvalInt(ctx, chunk.Row{})
if err != nil {
return err
Expand Down Expand Up @@ -290,7 +290,7 @@ func (b *builtinInStringSig) buildHashMapForConstArgs(ctx sessionctx.Context) er
b.hashSet = set.NewStringSet()
collator := collate.GetCollator(b.collation)
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(ctx.GetSessionVars().StmtCtx.UseCache) {
if b.args[i].ConstItem(true) {
val, isNull, err := b.args[i].EvalString(ctx, chunk.Row{})
if err != nil {
return err
Expand Down Expand Up @@ -363,7 +363,7 @@ func (b *builtinInRealSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = set.NewFloat64Set()
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(ctx.GetSessionVars().StmtCtx.UseCache) {
if b.args[i].ConstItem(true) {
val, isNull, err := b.args[i].EvalReal(ctx, chunk.Row{})
if err != nil {
return err
Expand Down Expand Up @@ -434,7 +434,7 @@ func (b *builtinInDecimalSig) buildHashMapForConstArgs(ctx sessionctx.Context) e
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = set.NewStringSet()
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(ctx.GetSessionVars().StmtCtx.UseCache) {
if b.args[i].ConstItem(true) {
val, isNull, err := b.args[i].EvalDecimal(ctx, chunk.Row{})
if err != nil {
return err
Expand Down Expand Up @@ -514,7 +514,7 @@ func (b *builtinInTimeSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = make(map[types.CoreTime]struct{}, len(b.args)-1)
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(ctx.GetSessionVars().StmtCtx.UseCache) {
if b.args[i].ConstItem(true) {
val, isNull, err := b.args[i].EvalTime(ctx, chunk.Row{})
if err != nil {
return err
Expand Down Expand Up @@ -585,7 +585,7 @@ func (b *builtinInDurationSig) buildHashMapForConstArgs(ctx sessionctx.Context)
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = make(map[time.Duration]struct{}, len(b.args)-1)
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(ctx.GetSessionVars().StmtCtx.UseCache) {
if b.args[i].ConstItem(true) {
val, isNull, err := b.args[i].EvalDuration(ctx, chunk.Row{})
if err != nil {
return err
Expand Down
33 changes: 24 additions & 9 deletions pkg/expression/expression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,30 @@ func TestIsBinaryLiteral(t *testing.T) {
}

func TestConstItem(t *testing.T) {
ctx := createContext(t)
sf := newFunctionWithMockCtx(ast.Rand)
require.False(t, sf.ConstItem(ctx.GetSessionVars().StmtCtx.UseCache))
sf = newFunctionWithMockCtx(ast.UUID)
require.False(t, sf.ConstItem(ctx.GetSessionVars().StmtCtx.UseCache))
sf = newFunctionWithMockCtx(ast.GetParam, NewOne())
require.False(t, sf.ConstItem(ctx.GetSessionVars().StmtCtx.UseCache))
sf = newFunctionWithMockCtx(ast.Abs, NewOne())
require.True(t, sf.ConstItem(ctx.GetSessionVars().StmtCtx.UseCache))
const noConst int = 0
const constInCtx int = 1
const constStrict int = 2

ctxConst := NewZero()
ctxConst.DeferredExpr = newFunctionWithMockCtx(ast.UnixTimestamp)
for _, c := range []struct {
exp Expression
constItem int
}{
{newFunctionWithMockCtx(ast.Rand), noConst},
{newFunctionWithMockCtx(ast.UUID), noConst},
{newFunctionWithMockCtx(ast.GetParam, NewOne()), noConst},
{newFunctionWithMockCtx(ast.Abs, NewOne()), constStrict},
{newFunctionWithMockCtx(ast.Abs, newColumn(1)), noConst},
{newFunctionWithMockCtx(ast.Plus, NewOne(), NewOne()), constStrict},
{newFunctionWithMockCtx(ast.Plus, newColumn(1), NewOne()), noConst},
{newFunctionWithMockCtx(ast.Plus, NewOne(), newColumn(1)), noConst},
{newFunctionWithMockCtx(ast.Plus, NewOne(), newColumn(1)), noConst},
{newFunctionWithMockCtx(ast.Plus, NewOne(), ctxConst), constInCtx},
} {
require.Equal(t, c.constItem >= constInCtx, c.exp.ConstItem(false), c.exp.String())
require.Equal(t, c.constItem >= constStrict, c.exp.ConstItem(true), c.exp.String())
}
}

func TestVectorizable(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/scalar_function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func TestScalarFunction(t *testing.T) {
require.NoError(t, err)
require.EqualValues(t, []byte{0x22, 0x6c, 0x74, 0x28, 0x43, 0x6f, 0x6c, 0x75, 0x6d, 0x6e, 0x23, 0x31, 0x2c, 0x20, 0x31, 0x29, 0x22}, res)
require.False(t, sf.IsCorrelated())
require.False(t, sf.ConstItem(ctx.GetSessionVars().StmtCtx.UseCache))
require.False(t, sf.ConstItem(false))
require.True(t, sf.Decorrelate(nil).Equal(ctx, sf))
require.EqualValues(t, []byte{0x3, 0x4, 0x6c, 0x74, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x5, 0xbf, 0xf0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, sf.HashCode())

Expand Down
37 changes: 37 additions & 0 deletions tests/integrationtest/r/expression/plan_cache.result
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,43 @@ admin reload expr_pushdown_blacklist;
set tidb_enable_prepared_plan_cache=default;
set tidb_enable_vectorized_expression=default;
set tidb_enable_prepared_plan_cache=ON;
drop table if exists t1;
create table t1 (a varchar(40));
insert into t1 values ('a'),('b');
insert into mysql.expr_pushdown_blacklist values('aes_encrypt', 'tikv,tiflash,tidb', 'for test');
admin reload expr_pushdown_blacklist;
set tidb_enable_vectorized_expression=ON;
prepare stmt1 from 'select a, hex(aes_encrypt(a, ?)) from t1 order by a asc';
set @a='xx';
execute stmt1 using @a;
a hex(aes_encrypt(a, ?))
a DA767CA0BE9CE9A1A979F6169A84B712
b 56F19741AA9177000269D07B6C4C6D7D
set @a='yy';
execute stmt1 using @a;
a hex(aes_encrypt(a, ?))
a 1318DA9E3BFC5FBEF34E5ACAFA944B09
b 1670ED5A2E8650BBCB09D7DF67B29FFC
execute stmt2 using @a;
a
set tidb_enable_vectorized_expression=OFF;
set @a='xx';
execute stmt1 using @a;
a hex(aes_encrypt(a, ?))
a DA767CA0BE9CE9A1A979F6169A84B712
b 56F19741AA9177000269D07B6C4C6D7D
set @a='yy';
execute stmt1 using @a;
a hex(aes_encrypt(a, ?))
a 1318DA9E3BFC5FBEF34E5ACAFA944B09
b 1670ED5A2E8650BBCB09D7DF67B29FFC
execute stmt2 using @a;
a
delete from mysql.expr_pushdown_blacklist where name like 'aes_%' and store_type = 'tikv,tiflash,tidb' and reason = 'for test';
admin reload expr_pushdown_blacklist;
set tidb_enable_prepared_plan_cache=default;
set tidb_enable_vectorized_expression=default;
set tidb_enable_prepared_plan_cache=ON;
drop table if exists t;
create table t(col_int int);
insert into t values(null);
Expand Down
25 changes: 25 additions & 0 deletions tests/integrationtest/t/expression/plan_cache.test
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,31 @@ admin reload expr_pushdown_blacklist;
set tidb_enable_prepared_plan_cache=default;
set tidb_enable_vectorized_expression=default;

# TestCacheAes
set tidb_enable_prepared_plan_cache=ON;
drop table if exists t1;
create table t1 (a varchar(40));
insert into t1 values ('a'),('b');
insert into mysql.expr_pushdown_blacklist values('aes_encrypt', 'tikv,tiflash,tidb', 'for test');
admin reload expr_pushdown_blacklist;
set tidb_enable_vectorized_expression=ON;
prepare stmt1 from 'select a, hex(aes_encrypt(a, ?)) from t1 order by a asc';
set @a='xx';
execute stmt1 using @a;
set @a='yy';
execute stmt1 using @a;
execute stmt2 using @a;
set tidb_enable_vectorized_expression=OFF;
set @a='xx';
execute stmt1 using @a;
set @a='yy';
execute stmt1 using @a;
execute stmt2 using @a;
delete from mysql.expr_pushdown_blacklist where name like 'aes_%' and store_type = 'tikv,tiflash,tidb' and reason = 'for test';
admin reload expr_pushdown_blacklist;
set tidb_enable_prepared_plan_cache=default;
set tidb_enable_vectorized_expression=default;

# TestCacheRefineArgs
set tidb_enable_prepared_plan_cache=ON;
drop table if exists t;
Expand Down

0 comments on commit 07e9ece

Please sign in to comment.