diff --git a/expression/builtin_other.go b/expression/builtin_other.go index b820519a0e8ee..ee30bb0c6ace3 100644 --- a/expression/builtin_other.go +++ b/expression/builtin_other.go @@ -613,18 +613,29 @@ func (b *builtinValuesStringSig) evalString(_ types.Row) (string, bool, error) { if !b.ctx.GetSessionVars().StmtCtx.InInsertStmt { return "", true, nil } + values := b.ctx.GetSessionVars().CurrInsertValues if values == nil { return "", true, errors.New("Session current insert values is nil") } + row := values.(types.Row) - if b.offset < row.Len() { - if row.IsNull(b.offset) { - return "", true, nil - } - return row.GetString(b.offset), false, nil + if b.offset >= row.Len() { + return "", true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", row.Len(), b.offset) } - return "", true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", row.Len(), b.offset) + + if row.IsNull(b.offset) { + return "", true, nil + } + + // Specially handle the ENUM/SET/BIT input value. + if retType := b.getRetTp(); retType.Hybrid() { + val := row.GetDatum(b.offset, retType) + res, err := val.ToString() + return res, err != nil, err + } + + return row.GetString(b.offset), false, nil } type builtinValuesTimeSig struct { diff --git a/expression/column.go b/expression/column.go index 9e2ae29745e0c..543de44b58232 100644 --- a/expression/column.go +++ b/expression/column.go @@ -225,18 +225,14 @@ func (col *Column) EvalString(ctx sessionctx.Context, row types.Row) (string, bo if row.IsNull(col.Index) { return "", true, nil } + + // Specially handle the ENUM/SET/BIT input value. if col.GetType().Hybrid() { val := row.GetDatum(col.Index, col.RetType) - if val.IsNull() { - return "", true, nil - } res, err := val.ToString() - resLen := len([]rune(res)) - if ctx.GetSessionVars().StmtCtx.PadCharToFullLength && col.GetType().Tp == mysql.TypeString && resLen < col.RetType.Flen { - res = res + strings.Repeat(" ", col.RetType.Flen-resLen) - } - return res, err != nil, errors.Trace(err) + return res, err != nil, err } + val := row.GetString(col.Index) if ctx.GetSessionVars().StmtCtx.PadCharToFullLength && col.GetType().Tp == mysql.TypeString { valLen := len([]rune(val)) diff --git a/expression/integration_test.go b/expression/integration_test.go index e6805593f84d3..563cfa4871b6b 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -3523,3 +3523,14 @@ func (s *testIntegrationSuite) TestValuesFloat32(c *C) { tk.MustExec(`insert into t values (1, 0.02) on duplicate key update j = values (j);`) tk.MustQuery(`select * from t;`).Check(testkit.Rows(`1 0.02`)) } + +func (s *testIntegrationSuite) TestValuesEnum(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t (a bigint primary key, b enum('a','b','c'));`) + tk.MustExec(`insert into t values (1, "a");`) + tk.MustQuery(`select * from t;`).Check(testkit.Rows(`1 a`)) + tk.MustExec(`insert into t values (1, "b") on duplicate key update b = values(b);`) + tk.MustQuery(`select * from t;`).Check(testkit.Rows(`1 b`)) +}