Skip to content

Commit

Permalink
executor: fix LOAD DATA can't use uppercase user var (#41603)
Browse files Browse the repository at this point in the history
close #41596, close #41611
  • Loading branch information
lance6716 authored Feb 21, 2023
1 parent e3896bd commit 65e524a
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 32 deletions.
3 changes: 3 additions & 0 deletions executor/load_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,9 @@ func (e *LoadDataInfo) colsToRow(ctx context.Context, cols []types.Datum) []type
row := make([]types.Datum, 0, len(e.insertColumns))
sessionVars := e.Ctx.GetSessionVars()
setVar := func(name string, col *types.Datum) {
// User variable names are not case-sensitive
// https://dev.mysql.com/doc/refman/8.0/en/user-variables.html
name = strings.ToLower(name)
if col == nil || col.IsNull() {
sessionVars.UnsetUserVar(name)
} else {
Expand Down
1 change: 1 addition & 0 deletions executor/loadremotetest/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ func TestMain(m *testing.M) {
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"),
goleak.IgnoreTopFunction("net.(*netFD).connect.func2"),
goleak.IgnoreTopFunction("net/http.(*persistConn).writeLoop"),
}
goleak.VerifyTestMain(m, opts...)
}
20 changes: 20 additions & 0 deletions executor/writetest/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2372,6 +2372,26 @@ func TestLoadDataOverflowBigintUnsigned(t *testing.T) {
checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL)
}

func TestLoadDataWithUppercaseUserVars(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test; drop table if exists load_data_test;")
tk.MustExec("CREATE TABLE load_data_test (a int, b int);")
tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test (@V1)" +
" set a = @V1, b = @V1*100")
ctx := tk.Session().(sessionctx.Context)
ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataInfo)
require.True(t, ok)
defer ctx.SetValue(executor.LoadDataVarKey, nil)
require.NotNil(t, ld)
tests := []testCase{
{[]byte("1\n2\n"), []string{"1|100", "2|200"}, "Records: 2 Deleted: 0 Skipped: 0 Warnings: 0"},
}
deleteSQL := "delete from load_data_test"
selectSQL := "select * from load_data_test;"
checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL)
}

func TestLoadDataIntoPartitionedTable(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
Expand Down
63 changes: 31 additions & 32 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1556,39 +1556,38 @@ func (cli *testServerClient) runTestLoadData(t *testing.T, server *Server) {
dbt.MustExec("drop table if exists pn")
})

// TODO: disabled
// Test with upper case variables.
//cli.runTestsOnNewDB(t, func(config *mysql.Config) {
// config.AllowAllFiles = true
// config.Params["sql_mode"] = "''"
//}, "LoadData", func(dbt *testkit.DBTestKit) {
// dbt.MustExec("drop table if exists pn")
// dbt.MustExec("create table pn (c1 int, c2 int, c3 int)")
// dbt.MustExec("set @@tidb_dml_batch_size = 1")
// _, err1 := dbt.GetDB().Exec(fmt.Sprintf(`load data local infile %q into table pn FIELDS TERMINATED BY ',' (c1, @VAL1, @VAL2) SET c3 = @VAL2 * 100, c2 = CAST(@VAL1 AS UNSIGNED)`, path))
// require.NoError(t, err1)
// var (
// a int
// b int
// c int
// )
// rows := dbt.MustQuery("select * from pn")
// require.Truef(t, rows.Next(), "unexpected data")
// err = rows.Scan(&a, &b, &c)
// require.NoError(t, err)
// require.Equal(t, 1, a)
// require.Equal(t, 2, b)
// require.Equal(t, 300, c)
// require.Truef(t, rows.Next(), "unexpected data")
// err = rows.Scan(&a, &b, &c)
// require.NoError(t, err)
// require.Equal(t, 4, a)
// require.Equal(t, 5, b)
// require.Equal(t, 600, c)
// require.Falsef(t, rows.Next(), "unexpected data")
// require.NoError(t, rows.Close())
// dbt.MustExec("drop table if exists pn")
//})
cli.runTestsOnNewDB(t, func(config *mysql.Config) {
config.AllowAllFiles = true
config.Params["sql_mode"] = "''"
}, "LoadData", func(dbt *testkit.DBTestKit) {
dbt.MustExec("drop table if exists pn")
dbt.MustExec("create table pn (c1 int, c2 int, c3 int)")
dbt.MustExec("set @@tidb_dml_batch_size = 1")
_, err1 := dbt.GetDB().Exec(fmt.Sprintf(`load data local infile %q into table pn FIELDS TERMINATED BY ',' (c1, @VAL1, @VAL2) SET c3 = @VAL2 * 100, c2 = CAST(@VAL1 AS UNSIGNED)`, path))
require.NoError(t, err1)
var (
a int
b int
c int
)
rows := dbt.MustQuery("select * from pn")
require.Truef(t, rows.Next(), "unexpected data")
err = rows.Scan(&a, &b, &c)
require.NoError(t, err)
require.Equal(t, 1, a)
require.Equal(t, 2, b)
require.Equal(t, 300, c)
require.Truef(t, rows.Next(), "unexpected data")
err = rows.Scan(&a, &b, &c)
require.NoError(t, err)
require.Equal(t, 4, a)
require.Equal(t, 5, b)
require.Equal(t, 600, c)
require.Falsef(t, rows.Next(), "unexpected data")
require.NoError(t, rows.Close())
dbt.MustExec("drop table if exists pn")
})
}

func (cli *testServerClient) runTestConcurrentUpdate(t *testing.T) {
Expand Down

0 comments on commit 65e524a

Please sign in to comment.