Skip to content

Commit

Permalink
expression: use the correct type when eval decimal and float session …
Browse files Browse the repository at this point in the history
…var (#51395)

close #43527
  • Loading branch information
Rustin170506 authored Feb 29, 2024
1 parent 42c8d2d commit 38ab23b
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 5 deletions.
12 changes: 10 additions & 2 deletions pkg/expression/builtin_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -1052,7 +1052,11 @@ func (b *builtinGetRealVarSig) evalReal(ctx EvalContext, row chunk.Row) (float64
}
varName = strings.ToLower(varName)
if v, ok := sessionVars.GetUserVarVal(varName); ok {
return v.GetFloat64(), false, nil
d, err := v.ToFloat64(typeCtx(ctx))
if err != nil {
return 0, false, err
}
return d, false, nil
}
return 0, true, nil
}
Expand Down Expand Up @@ -1092,7 +1096,11 @@ func (b *builtinGetDecimalVarSig) evalDecimal(ctx EvalContext, row chunk.Row) (*
}
varName = strings.ToLower(varName)
if v, ok := sessionVars.GetUserVarVal(varName); ok {
return v.GetMysqlDecimal(), false, nil
d, err := v.ToDecimal(typeCtx(ctx))
if err != nil {
return nil, false, err
}
return d, false, nil
}
return nil, true, nil
}
Expand Down
27 changes: 27 additions & 0 deletions pkg/expression/builtin_other_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,33 @@ func TestGetVar(t *testing.T) {
}
}

func TestTypeConversion(t *testing.T) {
ctx := createContext(t)
// Set value as int64
key := "a"
val := int64(3)
ctx.GetSessionVars().SetUserVarVal(key, types.NewDatum(val))
tp := types.NewFieldType(mysql.TypeLonglong)
ctx.GetSessionVars().SetUserVarType(key, tp)

args := []any{"a"}
// To Decimal.
tp = types.NewFieldType(mysql.TypeNewDecimal)
fn, err := BuildGetVarFunction(ctx, datumsToConstants(types.MakeDatums(args...))[0], tp)
require.NoError(t, err)
d, err := fn.Eval(ctx, chunk.Row{})
require.NoError(t, err)
des := types.NewDecFromInt(3)
require.Equal(t, des, d.GetValue())
// To Float.
tp = types.NewFieldType(mysql.TypeDouble)
fn, err = BuildGetVarFunction(ctx, datumsToConstants(types.MakeDatums(args...))[0], tp)
require.NoError(t, err)
d, err = fn.Eval(ctx, chunk.Row{})
require.NoError(t, err)
require.Equal(t, float64(3), d.GetValue())
}

func TestValues(t *testing.T) {
ctx := createContext(t)
fc := &valuesFunctionClass{baseFunctionClass{ast.Values, 0, 0}, 1, types.NewFieldType(mysql.TypeVarchar)}
Expand Down
16 changes: 14 additions & 2 deletions pkg/expression/builtin_other_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,8 @@ func (b *builtinGetRealVarSig) vectorized() bool {
return true
}

// NOTE: get/set variable vectorized eval was disabled. See more in
// https://github.com/pingcap/tidb/pull/8412
func (b *builtinGetRealVarSig) vecEvalReal(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
buf0, err := b.bufAllocator.get()
Expand All @@ -406,7 +408,11 @@ func (b *builtinGetRealVarSig) vecEvalReal(ctx EvalContext, input *chunk.Chunk,
}
varName := strings.ToLower(buf0.GetString(i))
if v, ok := sessionVars.GetUserVarVal(varName); ok {
f64s[i] = v.GetFloat64()
d, err := v.ToFloat64(typeCtx(ctx))
if err != nil {
return err
}
f64s[i] = d
continue
}
result.SetNull(i, true)
Expand All @@ -418,6 +424,8 @@ func (b *builtinGetDecimalVarSig) vectorized() bool {
return true
}

// NOTE: get/set variable vectorized eval was disabled. See more in
// https://github.com/pingcap/tidb/pull/8412
func (b *builtinGetDecimalVarSig) vecEvalDecimal(ctx EvalContext, input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
buf0, err := b.bufAllocator.get()
Expand All @@ -438,7 +446,11 @@ func (b *builtinGetDecimalVarSig) vecEvalDecimal(ctx EvalContext, input *chunk.C
}
varName := strings.ToLower(buf0.GetString(i))
if v, ok := sessionVars.GetUserVarVal(varName); ok {
decs[i] = *v.GetMysqlDecimal()
d, err := v.ToDecimal(typeCtx(ctx))
if err != nil {
return err
}
decs[i] = *d
continue
}
result.SetNull(i, true)
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/integration_test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ go_test(
"main_test.go",
],
flaky = True,
shard_count = 24,
shard_count = 25,
deps = [
"//pkg/config",
"//pkg/domain",
Expand Down
24 changes: 24 additions & 0 deletions pkg/expression/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2976,3 +2976,27 @@ func TestTiDBRowChecksumBuiltin(t *testing.T) {
tk.MustGetDBError("select tidb_row_checksum() from t", expression.ErrNotSupportedYet)
tk.MustGetDBError("select tidb_row_checksum() from t where id > 0", expression.ErrNotSupportedYet)
}

func TestIssue43527(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
tk.MustExec("create table test (a datetime, b bigint, c decimal(10, 2), d float)")
tk.MustExec("insert into test values('2010-10-10 10:10:10', 100, 100.01, 100)")
// Decimal.
tk.MustQuery(
"SELECT @total := @total + c FROM (SELECT c FROM test) AS temp, (SELECT @total := 200) AS T1",
).Check(testkit.Rows("300.01"))
// Float.
tk.MustQuery(
"SELECT @total := @total + d FROM (SELECT d FROM test) AS temp, (SELECT @total := 200) AS T1",
).Check(testkit.Rows("300"))
tk.MustExec("insert into test values('2010-10-10 10:10:10', 100, 100.01, 100)")
// Vectorized.
// NOTE: Because https://github.com/pingcap/tidb/pull/8412 disabled the vectorized execution of get or set variable,
// the following test case will not be executed in vectorized mode.
// It will be executed in the normal mode.
tk.MustQuery(
"SELECT @total := @total + d FROM (SELECT d FROM test) AS temp, (SELECT @total := b FROM test) AS T1 where @total >= 100",
).Check(testkit.Rows("200", "300", "400", "500"))
}

0 comments on commit 38ab23b

Please sign in to comment.