From fa16c09c711bcc6d4fd4361a2ccabd10c330ac53 Mon Sep 17 00:00:00 2001 From: spongedc Date: Fri, 20 Jul 2018 16:34:18 +0800 Subject: [PATCH 01/11] sessionctx, executor: Add correctness check when set system variables --- executor/ddl_test.go | 13 +++--- session/session.go | 15 +++++-- sessionctx/variable/session.go | 3 +- sessionctx/variable/sysvar.go | 16 +++++--- sessionctx/variable/varsutil.go | 60 +++++++++++++++++++++++++++- sessionctx/variable/varsutil_test.go | 6 +-- 6 files changed, 92 insertions(+), 21 deletions(-) diff --git a/executor/ddl_test.go b/executor/ddl_test.go index b25211a8baa1e..f91bcf6d3be8e 100644 --- a/executor/ddl_test.go +++ b/executor/ddl_test.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/testkit" @@ -409,17 +410,15 @@ func (s *testSuite) TestSetDDLReorgWorkerCnt(c *C) { c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(1)) tk.MustExec("set tidb_ddl_reorg_worker_cnt = 100") c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(100)) - tk.MustExec("set tidb_ddl_reorg_worker_cnt = invalid_val") - c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(variable.DefTiDBDDLReorgWorkerCount)) + _, err := tk.Exec("set tidb_ddl_reorg_worker_cnt = invalid_val") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue) tk.MustExec("set tidb_ddl_reorg_worker_cnt = 100") c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(100)) - tk.MustExec("set tidb_ddl_reorg_worker_cnt = -1") - c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(variable.DefTiDBDDLReorgWorkerCount)) + _, err = tk.Exec("set tidb_ddl_reorg_worker_cnt = -1") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongValueForVar), IsTrue) - res := tk.MustQuery("select @@tidb_ddl_reorg_worker_cnt") - res.Check(testkit.Rows("-1")) tk.MustExec("set tidb_ddl_reorg_worker_cnt = 100") - res = tk.MustQuery("select @@tidb_ddl_reorg_worker_cnt") + res := tk.MustQuery("select @@tidb_ddl_reorg_worker_cnt") res.Check(testkit.Rows("100")) res = tk.MustQuery("select @@global.tidb_ddl_reorg_worker_cnt") diff --git a/session/session.go b/session/session.go index 066b9427004d5..ff3ec7be54e2d 100644 --- a/session/session.go +++ b/session/session.go @@ -689,16 +689,25 @@ func (s *session) GetGlobalSysVar(name string) (string, error) { } // SetGlobalSysVar implements GlobalVarAccessor.SetGlobalSysVar interface. -func (s *session) SetGlobalSysVar(name string, value string) error { +func (s *session) SetGlobalSysVar(name, value string) error { if name == variable.SQLModeVar { value = mysql.FormatSQLModeStr(value) if _, err := mysql.GetSQLMode(value); err != nil { return errors.Trace(err) } } + var sVal string + var err, warn error + sVal, warn, err = variable.ValidateSetSystemVar(name, value) + if err != nil { + return errors.Trace(err) + } + if warn != nil { + s.sessionVars.StmtCtx.AppendWarning(warn) + } sql := fmt.Sprintf(`REPLACE %s.%s VALUES ('%s', '%s');`, - mysql.SystemDB, mysql.GlobalVariablesTable, strings.ToLower(name), value) - _, _, err := s.ExecRestrictedSQL(s, sql) + mysql.SystemDB, mysql.GlobalVariablesTable, strings.ToLower(name), sVal) + _, _, err = s.ExecRestrictedSQL(s, sql) return errors.Trace(err) } diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 05bbd02b538bf..c792411e42b3e 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -545,8 +545,7 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { case TiDBEnableTablePartition: s.EnableTablePartition = TiDBOptOn(val) case TiDBDDLReorgWorkerCount: - workerCnt := tidbOptPositiveInt32(val, DefTiDBDDLReorgWorkerCount) - SetDDLReorgWorkerCounter(int32(workerCnt)) + SetDDLReorgWorkerCounter(int32(tidbOptPositiveInt32(val, DefTiDBDDLReorgWorkerCount))) } s.systems[name] = val return nil diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index c2112ccf811c3..ea605228702ef 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -63,15 +63,19 @@ const ( CodeIncorrectScope terror.ErrCode = mysql.ErrIncorrectGlobalLocalVar CodeUnknownTimeZone terror.ErrCode = mysql.ErrUnknownTimeZone CodeReadOnly terror.ErrCode = mysql.ErrVariableIsReadonly + CodeWrongValueForVar terror.ErrCode = mysql.ErrWrongValueForVar + CodeWrongTypeForVar terror.ErrCode = mysql.ErrWrongTypeForVar ) // Variable errors var ( - UnknownStatusVar = terror.ClassVariable.New(CodeUnknownStatusVar, "unknown status variable") - UnknownSystemVar = terror.ClassVariable.New(CodeUnknownSystemVar, mysql.MySQLErrName[mysql.ErrUnknownSystemVariable]) - ErrIncorrectScope = terror.ClassVariable.New(CodeIncorrectScope, mysql.MySQLErrName[mysql.ErrIncorrectGlobalLocalVar]) - ErrUnknownTimeZone = terror.ClassVariable.New(CodeUnknownTimeZone, mysql.MySQLErrName[mysql.ErrUnknownTimeZone]) - ErrReadOnly = terror.ClassVariable.New(CodeReadOnly, "variable is read only") + UnknownStatusVar = terror.ClassVariable.New(CodeUnknownStatusVar, "unknown status variable") + UnknownSystemVar = terror.ClassVariable.New(CodeUnknownSystemVar, mysql.MySQLErrName[mysql.ErrUnknownSystemVariable]) + ErrIncorrectScope = terror.ClassVariable.New(CodeIncorrectScope, mysql.MySQLErrName[mysql.ErrIncorrectGlobalLocalVar]) + ErrUnknownTimeZone = terror.ClassVariable.New(CodeUnknownTimeZone, mysql.MySQLErrName[mysql.ErrUnknownTimeZone]) + ErrReadOnly = terror.ClassVariable.New(CodeReadOnly, "variable is read only") + ErrWrongValueForVar = terror.ClassVariable.New(CodeWrongValueForVar, mysql.MySQLErrName[mysql.ErrWrongValueForVar]) + ErrWrongTypeForVar = terror.ClassVariable.New(CodeWrongTypeForVar, mysql.MySQLErrName[mysql.ErrWrongTypeForVar]) ) func init() { @@ -87,6 +91,8 @@ func init() { CodeIncorrectScope: mysql.ErrIncorrectGlobalLocalVar, CodeUnknownTimeZone: mysql.ErrUnknownTimeZone, CodeReadOnly: mysql.ErrVariableIsReadonly, + CodeWrongValueForVar: mysql.ErrWrongValueForVar, + CodeWrongTypeForVar: mysql.ErrWrongTypeForVar, } terror.ErrClassToMySQLCodes[terror.ClassVariable] = mySQLErrCodes } diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 8e93112bdac50..d79d265eff3af 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -130,10 +130,19 @@ func SetSessionSystemVar(vars *SessionVars, name string, value types.Datum) erro if value.IsNull() { return vars.deleteSystemVar(name) } - sVal, err := value.ToString() + var sVal string + var err, warn error + sVal, err = value.ToString() if err != nil { return errors.Trace(err) } + sVal, warn, err = ValidateSetSystemVar(name, sVal) + if err != nil { + return errors.Trace(err) + } + if warn != nil { + vars.StmtCtx.AppendWarning(warn) + } return vars.SetSystemVar(name, sVal) } @@ -156,6 +165,55 @@ func ValidateGetSystemVar(name string, isGlobal bool) error { return nil } +// ValidateSetSystemVar checks if system variable satisfies specific restriction. +func ValidateSetSystemVar(name string, value string) (string, error, error) { + switch name { + // opt + case AutocommitVar, TiDBImportingData, TiDBSkipUTF8Check, TiDBOptAggPushDown, + TiDBOptInSubqUnFolding, TiDBEnableTablePartition, + TiDBBatchInsert, TiDBDisableTxnAutoRetry, TiDBEnableStreaming, + TiDBBatchDelete: + if strings.EqualFold(value, "ON") || value == "1" || strings.EqualFold(value, "OFF") || value == "0" { + return value, nil, nil + } + return value, nil, ErrWrongValueForVar.GenByArgs(name, value) + case TiDBIndexLookupConcurrency, TiDBIndexLookupJoinConcurrency, TiDBIndexJoinBatchSize, + TiDBIndexLookupSize, + TiDBHashJoinConcurrency, + TiDBHashAggPartialConcurrency, + TiDBHashAggFinalConcurrency, + TiDBDistSQLScanConcurrency, + TiDBIndexSerialScanConcurrency, TiDBDDLReorgWorkerCount, + TiDBBackoffLockFast, TiDBMaxChunkSize, + TiDBDMLBatchSize, TiDBOptimizerSelectivityLevel, + TiDBGeneralLog: + v, err := strconv.Atoi(value) + if err != nil { + return value, nil, ErrWrongTypeForVar.GenByArgs(name) + } + if v <= 0 { + return value, nil, ErrWrongValueForVar.GenByArgs(name, value) + } + return value, nil, nil + case TiDBProjectionConcurrency, + TIDBMemQuotaQuery, + TIDBMemQuotaHashJoin, + TIDBMemQuotaMergeJoin, + TIDBMemQuotaSort, + TIDBMemQuotaTopn, + TIDBMemQuotaIndexLookupReader, + TIDBMemQuotaIndexLookupJoin, + TIDBMemQuotaNestedLoopApply, + TiDBRetryLimit: + _, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return value, nil, ErrWrongTypeForVar.GenByArgs(name) + } + return value, nil, nil + } + return value, nil, nil +} + // TiDBOptOn could be used for all tidb session variable options, we use "ON"/1 to turn on those options. func TiDBOptOn(opt string) bool { return strings.EqualFold(opt, "ON") || opt == "1" diff --git a/sessionctx/variable/varsutil_test.go b/sessionctx/variable/varsutil_test.go index 4ec98a2f26bb8..dce99c095bd6c 100644 --- a/sessionctx/variable/varsutil_test.go +++ b/sessionctx/variable/varsutil_test.go @@ -223,11 +223,11 @@ func (s *testVarsutilSuite) TestVarsutil(c *C) { SetSessionSystemVar(v, TiDBDDLReorgWorkerCount, types.NewIntDatum(1)) c.Assert(GetDDLReorgWorkerCounter(), Equals, int32(1)) - SetSessionSystemVar(v, TiDBDDLReorgWorkerCount, types.NewIntDatum(-1)) - c.Assert(GetDDLReorgWorkerCounter(), Equals, int32(DefTiDBDDLReorgWorkerCount)) + err = SetSessionSystemVar(v, TiDBDDLReorgWorkerCount, types.NewIntDatum(-1)) + c.Assert(terror.ErrorEqual(err, ErrWrongValueForVar), IsTrue) SetSessionSystemVar(v, TiDBDDLReorgWorkerCount, types.NewIntDatum(int64(maxDDLReorgWorkerCount)+1)) - c.Assert(GetDDLReorgWorkerCounter(), Equals, int32(maxDDLReorgWorkerCount)) + c.Assert(terror.ErrorEqual(err, ErrWrongValueForVar), IsTrue) err = SetSessionSystemVar(v, TiDBRetryLimit, types.NewStringDatum("3")) c.Assert(err, IsNil) From 394a7ad50707b2fcce1241e2d2f3e4deaf6deea7 Mon Sep 17 00:00:00 2001 From: spongedc Date: Fri, 20 Jul 2018 17:55:34 +0800 Subject: [PATCH 02/11] Add tests --- executor/set_test.go | 31 +++++++++++++++++++++++++++++++ sessionctx/variable/varsutil.go | 2 +- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/executor/set_test.go b/executor/set_test.go index 064057f22acbf..fa6f6d71feee5 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -17,6 +17,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/util/testkit" "golang.org/x/net/context" ) @@ -247,3 +248,33 @@ func (s *testSuite) TestSetCharset(c *C) { // Issue 1523 tk.MustExec(`SET NAMES binary`) } + +func (s *testSuite) TestValidateSetVar(c *C) { + tk := testkit.NewTestKit(c, s.store) + + _, err := tk.Exec("set global tidb_distsql_scan_concurrency='fff';") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue) + + _, err = tk.Exec("set global tidb_distsql_scan_concurrency=-1;") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongValueForVar), IsTrue) + + _, err = tk.Exec("set @@tidb_distsql_scan_concurrency='fff';") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue) + + _, err = tk.Exec("set @@tidb_distsql_scan_concurrency=-1;") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongValueForVar), IsTrue) + + _, err = tk.Exec("set @@tidb_batch_delete='ok';") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongValueForVar), IsTrue) + + tk.MustExec("set @@tidb_batch_delete='On';") + tk.MustExec("set @@tidb_batch_delete='oFf';") + tk.MustExec("set @@tidb_batch_delete=1;") + tk.MustExec("set @@tidb_batch_delete=0;") + + _, err = tk.Exec("set @@tidb_batch_delete=3;") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongValueForVar), IsTrue) + + _, err = tk.Exec("set @@tidb_mem_quota_mergejoin='tidb';") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongValueForVar), IsTrue) +} diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index d79d265eff3af..3d1c998822131 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -207,7 +207,7 @@ func ValidateSetSystemVar(name string, value string) (string, error, error) { TiDBRetryLimit: _, err := strconv.ParseInt(value, 10, 64) if err != nil { - return value, nil, ErrWrongTypeForVar.GenByArgs(name) + return value, nil, ErrWrongValueForVar.GenByArgs(name) } return value, nil, nil } From 88a630a5a1138d95c8fb6b060378ff5cf761ebd8 Mon Sep 17 00:00:00 2001 From: spongedc Date: Fri, 20 Jul 2018 23:39:27 +0800 Subject: [PATCH 03/11] Code format refine --- executor/set_test.go | 33 +++++++++++-- sessionctx/variable/sysvar.go | 83 +++++++++++++++++++++------------ sessionctx/variable/varsutil.go | 39 +++++++++++++++- 3 files changed, 119 insertions(+), 36 deletions(-) diff --git a/executor/set_test.go b/executor/set_test.go index fa6f6d71feee5..780d5a2d7dcba 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -19,6 +19,7 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testutil" "golang.org/x/net/context" ) @@ -213,15 +214,15 @@ func (s *testSuite) TestSetVar(c *C) { tk.MustQuery("select @@session.tx_isolation").Check(testkit.Rows("READ-COMMITTED")) tk.MustExec("set global avoid_temporal_upgrade = on") - tk.MustQuery(`select @@global.avoid_temporal_upgrade;`).Check(testkit.Rows("ON")) + tk.MustQuery(`select @@global.avoid_temporal_upgrade;`).Check(testkit.Rows("1")) tk.MustExec("set @@global.avoid_temporal_upgrade = off") - tk.MustQuery(`select @@global.avoid_temporal_upgrade;`).Check(testkit.Rows("off")) + tk.MustQuery(`select @@global.avoid_temporal_upgrade;`).Check(testkit.Rows("0")) tk.MustExec("set session sql_log_bin = on") - tk.MustQuery(`select @@session.sql_log_bin;`).Check(testkit.Rows("ON")) + tk.MustQuery(`select @@session.sql_log_bin;`).Check(testkit.Rows("1")) tk.MustExec("set sql_log_bin = off") - tk.MustQuery(`select @@session.sql_log_bin;`).Check(testkit.Rows("off")) + tk.MustQuery(`select @@session.sql_log_bin;`).Check(testkit.Rows("0")) tk.MustExec("set @@sql_log_bin = on") - tk.MustQuery(`select @@session.sql_log_bin;`).Check(testkit.Rows("ON")) + tk.MustQuery(`select @@session.sql_log_bin;`).Check(testkit.Rows("1")) } func (s *testSuite) TestSetCharset(c *C) { @@ -277,4 +278,26 @@ func (s *testSuite) TestValidateSetVar(c *C) { _, err = tk.Exec("set @@tidb_mem_quota_mergejoin='tidb';") c.Assert(terror.ErrorEqual(err, variable.ErrWrongValueForVar), IsTrue) + + tk.MustExec("set @@group_concat_max_len=1") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect group_concat_max_len value: '1'")) + result := tk.MustQuery("select @@group_concat_max_len;") + result.Check(testkit.Rows("4")) + + _, err = tk.Exec("set @@group_concat_max_len = 18446744073709551616") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue) + + // Test illegal type + _, err = tk.Exec("set @@group_concat_max_len='hello'") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue) + + tk.MustExec("set @@default_week_format=-1") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect default_week_format value: '-1'")) + result = tk.MustQuery("select @@default_week_format;") + result.Check(testkit.Rows("0")) + + tk.MustExec("set @@default_week_format=9") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect default_week_format value: '9'")) + result = tk.MustQuery("select @@default_week_format;") + result.Check(testkit.Rows("7")) } diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index ea605228702ef..d81275f2c73f6 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -58,24 +58,26 @@ func GetSysVar(name string) *SysVar { // Variable error codes. const ( - CodeUnknownStatusVar terror.ErrCode = 1 - CodeUnknownSystemVar terror.ErrCode = mysql.ErrUnknownSystemVariable - CodeIncorrectScope terror.ErrCode = mysql.ErrIncorrectGlobalLocalVar - CodeUnknownTimeZone terror.ErrCode = mysql.ErrUnknownTimeZone - CodeReadOnly terror.ErrCode = mysql.ErrVariableIsReadonly - CodeWrongValueForVar terror.ErrCode = mysql.ErrWrongValueForVar - CodeWrongTypeForVar terror.ErrCode = mysql.ErrWrongTypeForVar + CodeUnknownStatusVar terror.ErrCode = 1 + CodeUnknownSystemVar terror.ErrCode = mysql.ErrUnknownSystemVariable + CodeIncorrectScope terror.ErrCode = mysql.ErrIncorrectGlobalLocalVar + CodeUnknownTimeZone terror.ErrCode = mysql.ErrUnknownTimeZone + CodeReadOnly terror.ErrCode = mysql.ErrVariableIsReadonly + CodeWrongValueForVar terror.ErrCode = mysql.ErrWrongValueForVar + CodeWrongTypeForVar terror.ErrCode = mysql.ErrWrongTypeForVar + CodeTruncatedWrongValue terror.ErrCode = mysql.ErrTruncatedWrongValue ) // Variable errors var ( - UnknownStatusVar = terror.ClassVariable.New(CodeUnknownStatusVar, "unknown status variable") - UnknownSystemVar = terror.ClassVariable.New(CodeUnknownSystemVar, mysql.MySQLErrName[mysql.ErrUnknownSystemVariable]) - ErrIncorrectScope = terror.ClassVariable.New(CodeIncorrectScope, mysql.MySQLErrName[mysql.ErrIncorrectGlobalLocalVar]) - ErrUnknownTimeZone = terror.ClassVariable.New(CodeUnknownTimeZone, mysql.MySQLErrName[mysql.ErrUnknownTimeZone]) - ErrReadOnly = terror.ClassVariable.New(CodeReadOnly, "variable is read only") - ErrWrongValueForVar = terror.ClassVariable.New(CodeWrongValueForVar, mysql.MySQLErrName[mysql.ErrWrongValueForVar]) - ErrWrongTypeForVar = terror.ClassVariable.New(CodeWrongTypeForVar, mysql.MySQLErrName[mysql.ErrWrongTypeForVar]) + UnknownStatusVar = terror.ClassVariable.New(CodeUnknownStatusVar, "unknown status variable") + UnknownSystemVar = terror.ClassVariable.New(CodeUnknownSystemVar, mysql.MySQLErrName[mysql.ErrUnknownSystemVariable]) + ErrIncorrectScope = terror.ClassVariable.New(CodeIncorrectScope, mysql.MySQLErrName[mysql.ErrIncorrectGlobalLocalVar]) + ErrUnknownTimeZone = terror.ClassVariable.New(CodeUnknownTimeZone, mysql.MySQLErrName[mysql.ErrUnknownTimeZone]) + ErrReadOnly = terror.ClassVariable.New(CodeReadOnly, "variable is read only") + ErrWrongValueForVar = terror.ClassVariable.New(CodeWrongValueForVar, mysql.MySQLErrName[mysql.ErrWrongValueForVar]) + ErrWrongTypeForVar = terror.ClassVariable.New(CodeWrongTypeForVar, mysql.MySQLErrName[mysql.ErrWrongTypeForVar]) + ErrTruncatedWrongValue = terror.ClassVariable.New(CodeTruncatedWrongValue, mysql.MySQLErrName[mysql.ErrTruncatedWrongValue]) ) func init() { @@ -87,12 +89,13 @@ func init() { // Register terror to mysql error map. mySQLErrCodes := map[terror.ErrCode]uint16{ - CodeUnknownSystemVar: mysql.ErrUnknownSystemVariable, - CodeIncorrectScope: mysql.ErrIncorrectGlobalLocalVar, - CodeUnknownTimeZone: mysql.ErrUnknownTimeZone, - CodeReadOnly: mysql.ErrVariableIsReadonly, - CodeWrongValueForVar: mysql.ErrWrongValueForVar, - CodeWrongTypeForVar: mysql.ErrWrongTypeForVar, + CodeUnknownSystemVar: mysql.ErrUnknownSystemVariable, + CodeIncorrectScope: mysql.ErrIncorrectGlobalLocalVar, + CodeUnknownTimeZone: mysql.ErrUnknownTimeZone, + CodeReadOnly: mysql.ErrVariableIsReadonly, + CodeWrongValueForVar: mysql.ErrWrongValueForVar, + CodeWrongTypeForVar: mysql.ErrWrongTypeForVar, + CodeTruncatedWrongValue: mysql.ErrTruncatedWrongValue, } terror.ErrClassToMySQLCodes[terror.ClassVariable] = mySQLErrCodes } @@ -117,7 +120,7 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal | ScopeSession, "old_passwords", "0"}, {ScopeNone, "innodb_version", "5.6.25"}, {ScopeGlobal, "max_connections", "151"}, - {ScopeGlobal | ScopeSession, "big_tables", "OFF"}, + {ScopeGlobal | ScopeSession, BigTables, "0"}, {ScopeNone, "skip_external_locking", "ON"}, {ScopeGlobal, "slave_pending_jobs_size_max", "16777216"}, {ScopeNone, "innodb_sync_array_size", "1"}, @@ -127,7 +130,7 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal | ScopeSession, "sql_select_limit", "18446744073709551615"}, {ScopeGlobal, "ndb_show_foreign_key_mock_tables", ""}, {ScopeNone, "multi_range_count", "256"}, - {ScopeGlobal | ScopeSession, "default_week_format", "0"}, + {ScopeGlobal | ScopeSession, DefaultWeekFormat, "0"}, {ScopeGlobal | ScopeSession, "binlog_error_action", "IGNORE_ERROR"}, {ScopeGlobal, "slave_transaction_retries", "10"}, {ScopeGlobal | ScopeSession, "default_storage_engine", "InnoDB"}, @@ -154,7 +157,7 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal, "innodb_max_purge_lag", "0"}, {ScopeGlobal | ScopeSession, "preload_buffer_size", "32768"}, {ScopeGlobal, "slave_checkpoint_period", "300"}, - {ScopeGlobal, "check_proxy_users", ""}, + {ScopeGlobal, CheckProxyUsers, "0"}, {ScopeNone, "have_query_cache", "YES"}, {ScopeGlobal, "innodb_flush_log_at_timeout", "1"}, {ScopeGlobal, "innodb_max_undo_log_size", ""}, @@ -170,7 +173,7 @@ var defaultSysVars = []*SysVar{ {ScopeNone, "innodb_ft_sort_pll_degree", "2"}, {ScopeNone, "thread_stack", "262144"}, {ScopeGlobal, "relay_log_info_repository", "FILE"}, - {ScopeGlobal | ScopeSession, "sql_log_bin", "ON"}, + {ScopeGlobal | ScopeSession, SQLLogBin, "1"}, {ScopeGlobal, "super_read_only", "OFF"}, {ScopeGlobal | ScopeSession, "max_delayed_threads", "20"}, {ScopeNone, "protocol_version", "10"}, @@ -189,7 +192,7 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal, "innodb_log_write_ahead_size", ""}, {ScopeNone, "innodb_log_group_home_dir", "./"}, {ScopeNone, "performance_schema_events_statements_history_size", "10"}, - {ScopeGlobal, "general_log", "OFF"}, + {ScopeGlobal, GeneralLog, "0"}, {ScopeGlobal, "validate_password_dictionary_file", ""}, {ScopeGlobal, "binlog_order_commits", "ON"}, {ScopeGlobal, "master_verify_checksum", "OFF"}, @@ -219,7 +222,7 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal, "key_buffer_size", "8388608"}, {ScopeGlobal | ScopeSession, "foreign_key_checks", "ON"}, {ScopeGlobal, "host_cache_size", "279"}, - {ScopeGlobal, "delay_key_write", "ON"}, + {ScopeGlobal, DelayKeyWrite, "ON"}, {ScopeNone, "metadata_locks_cache_size", "1024"}, {ScopeNone, "innodb_force_recovery", "0"}, {ScopeGlobal, "innodb_file_format_max", "Antelope"}, @@ -322,7 +325,7 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal, "ndb_optimization_delay", ""}, {ScopeGlobal, "innodb_ft_num_word_optimize", "2000"}, {ScopeGlobal | ScopeSession, "max_join_size", "18446744073709551615"}, - {ScopeNone, "core_file", "OFF"}, + {ScopeNone, CoreFile, "0"}, {ScopeGlobal | ScopeSession, "max_seeks_for_key", "18446744073709551615"}, {ScopeNone, "innodb_log_buffer_size", "8388608"}, {ScopeGlobal, "delayed_insert_timeout", "300"}, @@ -564,7 +567,7 @@ var defaultSysVars = []*SysVar{ {ScopeNone, "back_log", "80"}, {ScopeNone, "lower_case_file_system", "ON"}, {ScopeGlobal, "rpl_semi_sync_master_wait_no_slave", ""}, - {ScopeGlobal | ScopeSession, "group_concat_max_len", "1024"}, + {ScopeGlobal | ScopeSession, GroupConcatMaxLen, "1024"}, {ScopeSession, "pseudo_thread_id", ""}, {ScopeNone, "socket", "/tmp/myssock"}, {ScopeNone, "have_dynamic_loading", "YES"}, @@ -604,8 +607,8 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal, "innodb_buffer_pool_dump_pct", ""}, {ScopeGlobal | ScopeSession, "lc_time_names", "en_US"}, {ScopeGlobal | ScopeSession, "max_statement_time", ""}, - {ScopeGlobal | ScopeSession, "end_markers_in_json", "OFF"}, - {ScopeGlobal, "avoid_temporal_upgrade", "OFF"}, + {ScopeGlobal | ScopeSession, EndMakersInJson, "0"}, + {ScopeGlobal, AvoidTemporalUpgrade, "0"}, {ScopeGlobal, "key_cache_age_threshold", "300"}, {ScopeGlobal, "innodb_status_output", "OFF"}, {ScopeSession, "identity", ""}, @@ -689,6 +692,26 @@ const ( CharsetDatabase = "character_set_database" // CollationDatabase is the name for collation_database system variable. CollationDatabase = "collation_database" + // GeneralLog is the name for 'general_log' system variable. + GeneralLog = "general_log" + // AvoidTemporalUpgrade is the name for 'avoid_temporal_upgrade' system variable. + AvoidTemporalUpgrade = "avoid_temporal_upgrade" + // BigTables is the name for 'big_tables' system variable. + BigTables = "big_tables" + // CheckProxyUsers is the name for 'check_proxy_users' system variable. + CheckProxyUsers = "check_proxy_users" + // CoreFile is the name for 'core_file' system variable. + CoreFile = "core_file" + // DefaultWeekFormat is the name for 'default_week_format' system variable. + DefaultWeekFormat = "default_week_format" + // GroupConcatMaxLen is the name for 'group_concat_max_len' system variable. + GroupConcatMaxLen = "group_concat_max_len" + // DelayKeyWrite is the name for 'delay_key_write' system variable. + DelayKeyWrite = "delay_key_write" + // EndMakersInJson is the name for 'end_markers_in_json' system variable. + EndMakersInJson = "end_markers_in_json" + // SQLLogBin is the name for 'sql_log_bin' system variable. + SQLLogBin = "sql_log_bin" ) // GlobalVarAccessor is the interface for accessing global scope system and status variables. diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 3d1c998822131..cf1d3f485a411 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -168,7 +168,44 @@ func ValidateGetSystemVar(name string, isGlobal bool) error { // ValidateSetSystemVar checks if system variable satisfies specific restriction. func ValidateSetSystemVar(name string, value string) (string, error, error) { switch name { - // opt + case DefaultWeekFormat: + val, err := strconv.Atoi(value) + if err != nil { + return value, nil, ErrWrongTypeForVar.GenByArgs(name) + } + if val < 0 { + return "0", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } + if val > 7 { + return "7", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } + case GroupConcatMaxLen: + val, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return value, nil, ErrWrongTypeForVar.GenByArgs(name) + } + if val < 4 { + return "4", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } + if val > 18446744073709551615 { + return "18446744073709551615", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } + case DelayKeyWrite: + if strings.EqualFold(value, "ON") || value == "1" { + return "ON", nil, nil + } else if strings.EqualFold(value, "OFF") || value == "0" { + return "OFF", nil, nil + } else if strings.EqualFold(value, "ALL") || value == "2" { + return "ALL", nil, nil + } + return value, nil, ErrWrongValueForVar.GenByArgs(name, value) + case GeneralLog, AvoidTemporalUpgrade, BigTables, CheckProxyUsers, CoreFile, EndMakersInJson, SQLLogBin: + if strings.EqualFold(value, "ON") || value == "1" { + return "1", nil, nil + } else if strings.EqualFold(value, "OFF") || value == "0" { + return "0", nil, nil + } + return value, nil, ErrWrongValueForVar.GenByArgs(name, value) case AutocommitVar, TiDBImportingData, TiDBSkipUTF8Check, TiDBOptAggPushDown, TiDBOptInSubqUnFolding, TiDBEnableTablePartition, TiDBBatchInsert, TiDBDisableTxnAutoRetry, TiDBEnableStreaming, From ba9834fa355d40c9557164c21ff7c159f6f16e30 Mon Sep 17 00:00:00 2001 From: spongedc Date: Sat, 21 Jul 2018 15:25:17 +0800 Subject: [PATCH 04/11] fix spell --- sessionctx/variable/sysvar.go | 6 +++--- sessionctx/variable/varsutil.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index d81275f2c73f6..f98ec3c2a6761 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -607,7 +607,7 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal, "innodb_buffer_pool_dump_pct", ""}, {ScopeGlobal | ScopeSession, "lc_time_names", "en_US"}, {ScopeGlobal | ScopeSession, "max_statement_time", ""}, - {ScopeGlobal | ScopeSession, EndMakersInJson, "0"}, + {ScopeGlobal | ScopeSession, EndMakersInJSON, "0"}, {ScopeGlobal, AvoidTemporalUpgrade, "0"}, {ScopeGlobal, "key_cache_age_threshold", "300"}, {ScopeGlobal, "innodb_status_output", "OFF"}, @@ -708,8 +708,8 @@ const ( GroupConcatMaxLen = "group_concat_max_len" // DelayKeyWrite is the name for 'delay_key_write' system variable. DelayKeyWrite = "delay_key_write" - // EndMakersInJson is the name for 'end_markers_in_json' system variable. - EndMakersInJson = "end_markers_in_json" + // EndMakersInJSON is the name for 'end_markers_in_json' system variable. + EndMakersInJSON = "end_markers_in_json" // SQLLogBin is the name for 'sql_log_bin' system variable. SQLLogBin = "sql_log_bin" ) diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index cf1d3f485a411..35588411295df 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -199,7 +199,7 @@ func ValidateSetSystemVar(name string, value string) (string, error, error) { return "ALL", nil, nil } return value, nil, ErrWrongValueForVar.GenByArgs(name, value) - case GeneralLog, AvoidTemporalUpgrade, BigTables, CheckProxyUsers, CoreFile, EndMakersInJson, SQLLogBin: + case GeneralLog, AvoidTemporalUpgrade, BigTables, CheckProxyUsers, CoreFile, EndMakersInJSON, SQLLogBin: if strings.EqualFold(value, "ON") || value == "1" { return "1", nil, nil } else if strings.EqualFold(value, "OFF") || value == "0" { From 6070034cf513d42f8e71f5aad458ab1d1735a801 Mon Sep 17 00:00:00 2001 From: spongedc Date: Sat, 21 Jul 2018 21:04:34 +0800 Subject: [PATCH 05/11] Take some more sysvars into check --- sessionctx/variable/sysvar.go | 18 +++++++++++--- sessionctx/variable/varsutil.go | 43 ++++++++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index f98ec3c2a6761..900ffdb8fccdd 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -228,7 +228,7 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal, "innodb_file_format_max", "Antelope"}, {ScopeGlobal | ScopeSession, "debug", ""}, {ScopeGlobal, "log_warnings", "1"}, - {ScopeGlobal, "offline_mode", ""}, + {ScopeGlobal, OfflineMode, "0"}, {ScopeGlobal | ScopeSession, "innodb_strict_mode", "OFF"}, {ScopeGlobal, "innodb_rollback_segments", "128"}, {ScopeGlobal | ScopeSession, "join_buffer_size", "262144"}, @@ -261,7 +261,7 @@ var defaultSysVars = []*SysVar{ {ScopeNone, "thread_concurrency", "10"}, {ScopeGlobal | ScopeSession, "query_prealloc_size", "8192"}, {ScopeNone, "relay_log_space_limit", "0"}, - {ScopeGlobal | ScopeSession, "max_user_connections", "0"}, + {ScopeGlobal | ScopeSession, MaxUserConnections, "0"}, {ScopeNone, "performance_schema_max_thread_classes", "50"}, {ScopeGlobal, "innodb_api_trx_level", "0"}, {ScopeNone, "disconnect_on_expired_password", "ON"}, @@ -330,7 +330,7 @@ var defaultSysVars = []*SysVar{ {ScopeNone, "innodb_log_buffer_size", "8388608"}, {ScopeGlobal, "delayed_insert_timeout", "300"}, {ScopeGlobal, "max_relay_log_size", "0"}, - {ScopeGlobal | ScopeSession, "max_sort_length", "1024"}, + {ScopeGlobal | ScopeSession, MaxSortLength, "1024"}, {ScopeNone, "metadata_locks_hash_instances", "8"}, {ScopeGlobal, "ndb_eventbuffer_free_percent", ""}, {ScopeNone, "large_files_support", "ON"}, @@ -470,7 +470,7 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal | ScopeSession, "lock_wait_timeout", "31536000"}, {ScopeGlobal | ScopeSession, "read_buffer_size", "131072"}, {ScopeNone, "innodb_read_io_threads", "4"}, - {ScopeGlobal | ScopeSession, "max_sp_recursion_depth", "0"}, + {ScopeGlobal | ScopeSession, MaxSpRecursionDepth, "0"}, {ScopeNone, "ignore_builtin_innodb", "OFF"}, {ScopeGlobal, "rpl_semi_sync_master_enabled", ""}, {ScopeGlobal, "slow_query_log_file", "/usr/local/mysql/data/localhost-slow.log"}, @@ -712,6 +712,16 @@ const ( EndMakersInJSON = "end_markers_in_json" // SQLLogBin is the name for 'sql_log_bin' system variable. SQLLogBin = "sql_log_bin" + // MaxSortLength is the name for 'max_sort_length' system variable. + MaxSortLength = "max_sort_length" + // MaxSortLength is the name for 'max_sp_recursion_depth' system variable. + MaxSpRecursionDepth = "max_sp_recursion_depth" + // MaxSortLength is the name for 'max_user_connections' system variable. + MaxUserConnections = "max_user_connections" + // OfflineMode is the name for 'offline_mode' system variable. + OfflineMode = "offline_mode" + // InteractiveTimeout is the name for 'interactive_timeout' system variable. + InteractiveTimeout = "interactive_timeout" ) // GlobalVarAccessor is the interface for accessing global scope system and status variables. diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 35588411295df..208c1dc76a195 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -190,6 +190,47 @@ func ValidateSetSystemVar(name string, value string) (string, error, error) { if val > 18446744073709551615 { return "18446744073709551615", ErrTruncatedWrongValue.GenByArgs(name, value), nil } + case MaxUserConnections: + val, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return value, nil, ErrWrongTypeForVar.GenByArgs(name) + } + if val < 0 { + return "0", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } + if val > 4294967295 { + return "4294967295", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } + case MaxSortLength: + val, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return value, nil, ErrWrongTypeForVar.GenByArgs(name) + } + if val < 4 { + return "4", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } + if val > 8388608 { + return "8388608", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } + case MaxSpRecursionDepth: + val, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return value, nil, ErrWrongTypeForVar.GenByArgs(name) + } + if val < 0 { + return "0", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } + if val > 255 { + return "255", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } + case InteractiveTimeout: + val, err := strconv.Atoi(value) + if err != nil { + return value, nil, ErrWrongTypeForVar.GenByArgs(name) + } + if val < 1 { + return "1", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } case DelayKeyWrite: if strings.EqualFold(value, "ON") || value == "1" { return "ON", nil, nil @@ -199,7 +240,7 @@ func ValidateSetSystemVar(name string, value string) (string, error, error) { return "ALL", nil, nil } return value, nil, ErrWrongValueForVar.GenByArgs(name, value) - case GeneralLog, AvoidTemporalUpgrade, BigTables, CheckProxyUsers, CoreFile, EndMakersInJSON, SQLLogBin: + case GeneralLog, AvoidTemporalUpgrade, BigTables, CheckProxyUsers, CoreFile, EndMakersInJSON, SQLLogBin, OfflineMode: if strings.EqualFold(value, "ON") || value == "1" { return "1", nil, nil } else if strings.EqualFold(value, "OFF") || value == "0" { From fb7929b3b961b45bddc5eb9142f62ef4bae4e555 Mon Sep 17 00:00:00 2001 From: spongedc Date: Sat, 21 Jul 2018 22:09:57 +0800 Subject: [PATCH 06/11] fix spell --- sessionctx/variable/sysvar.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 900ffdb8fccdd..e8742c13ba1d6 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -714,9 +714,9 @@ const ( SQLLogBin = "sql_log_bin" // MaxSortLength is the name for 'max_sort_length' system variable. MaxSortLength = "max_sort_length" - // MaxSortLength is the name for 'max_sp_recursion_depth' system variable. + // MaxSpRecursionDepth is the name for 'max_sp_recursion_depth' system variable. MaxSpRecursionDepth = "max_sp_recursion_depth" - // MaxSortLength is the name for 'max_user_connections' system variable. + // MaxUserConnections is the name for 'max_user_connections' system variable. MaxUserConnections = "max_user_connections" // OfflineMode is the name for 'offline_mode' system variable. OfflineMode = "offline_mode" From fba76e377b674f44e561db480626e983fbceae54 Mon Sep 17 00:00:00 2001 From: spongedc Date: Sun, 22 Jul 2018 23:16:07 +0800 Subject: [PATCH 07/11] 1. support set default 2. add more checks --- executor/set_test.go | 16 +++++-- sessionctx/variable/sysvar.go | 44 +++++++++++++----- sessionctx/variable/varsutil.go | 81 ++++++++++++++++++++++++++------- 3 files changed, 109 insertions(+), 32 deletions(-) diff --git a/executor/set_test.go b/executor/set_test.go index 780d5a2d7dcba..ab7d12e66d8f4 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -90,16 +90,16 @@ func (s *testSuite) TestSetVar(c *C) { // Set default // {ScopeGlobal | ScopeSession, "low_priority_updates", "OFF"}, // For global var - tk.MustQuery(`select @@global.low_priority_updates;`).Check(testkit.Rows("OFF")) + tk.MustQuery(`select @@global.low_priority_updates;`).Check(testkit.Rows("0")) tk.MustExec(`set @@global.low_priority_updates="ON";`) - tk.MustQuery(`select @@global.low_priority_updates;`).Check(testkit.Rows("ON")) + tk.MustQuery(`select @@global.low_priority_updates;`).Check(testkit.Rows("1")) tk.MustExec(`set @@global.low_priority_updates=DEFAULT;`) // It will be set to compiled-in default value. - tk.MustQuery(`select @@global.low_priority_updates;`).Check(testkit.Rows("OFF")) + tk.MustQuery(`select @@global.low_priority_updates;`).Check(testkit.Rows("0")) // For session - tk.MustQuery(`select @@session.low_priority_updates;`).Check(testkit.Rows("OFF")) + tk.MustQuery(`select @@session.low_priority_updates;`).Check(testkit.Rows("0")) tk.MustExec(`set @@global.low_priority_updates="ON";`) tk.MustExec(`set @@session.low_priority_updates=DEFAULT;`) // It will be set to global var value. - tk.MustQuery(`select @@session.low_priority_updates;`).Check(testkit.Rows("ON")) + tk.MustQuery(`select @@session.low_priority_updates;`).Check(testkit.Rows("1")) // For mysql jdbc driver issue. tk.MustQuery(`select @@session.tx_read_only;`).Check(testkit.Rows("0")) @@ -300,4 +300,10 @@ func (s *testSuite) TestValidateSetVar(c *C) { tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect default_week_format value: '9'")) result = tk.MustQuery("select @@default_week_format;") result.Check(testkit.Rows("7")) + + _, err = tk.Exec("set @@error_count = 0") + c.Assert(terror.ErrorEqual(err, variable.ErrReadOnly), IsTrue) + + _, err = tk.Exec("set @@warning_count = 0") + c.Assert(terror.ErrorEqual(err, variable.ErrReadOnly), IsTrue) } diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index e8742c13ba1d6..54bc6c7d83fb6 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -110,16 +110,16 @@ func boolToIntStr(b bool) string { // we only support MySQL now var defaultSysVars = []*SysVar{ {ScopeGlobal, "gtid_mode", "OFF"}, - {ScopeGlobal, "flush_time", "0"}, - {ScopeSession, "pseudo_slave_mode", ""}, + {ScopeGlobal, FlushTime, "0"}, + {ScopeSession, PseudoSlaveMode, ""}, {ScopeNone, "performance_schema_max_mutex_classes", "200"}, - {ScopeGlobal | ScopeSession, "low_priority_updates", "OFF"}, - {ScopeGlobal | ScopeSession, "session_track_gtids", ""}, + {ScopeGlobal | ScopeSession, LowPriorityUpdates, "0"}, + {ScopeGlobal | ScopeSession, SessionTrackGtids, "OFF"}, {ScopeGlobal | ScopeSession, "ndbinfo_max_rows", ""}, {ScopeGlobal | ScopeSession, "ndb_index_stat_option", ""}, - {ScopeGlobal | ScopeSession, "old_passwords", "0"}, + {ScopeGlobal | ScopeSession, OldPasswords, "0"}, {ScopeNone, "innodb_version", "5.6.25"}, - {ScopeGlobal, "max_connections", "151"}, + {ScopeGlobal, MaxConnections, "151"}, {ScopeGlobal | ScopeSession, BigTables, "0"}, {ScopeNone, "skip_external_locking", "ON"}, {ScopeGlobal, "slave_pending_jobs_size_max", "16777216"}, @@ -146,7 +146,7 @@ var defaultSysVars = []*SysVar{ {ScopeNone, "lc_messages_dir", "/usr/local/mysql-5.6.25-osx10.8-x86_64/share/"}, {ScopeGlobal, "ft_boolean_syntax", "+ -><()~*:\"\"&|"}, {ScopeGlobal, "table_definition_cache", "1400"}, - {ScopeNone, "skip_name_resolve", "OFF"}, + {ScopeNone, SkipNameResolve, "0"}, {ScopeNone, "performance_schema_max_file_handles", "32768"}, {ScopeSession, "transaction_allow_batching", ""}, {ScopeGlobal | ScopeSession, SQLModeVar, mysql.DefaultSQLMode}, @@ -220,7 +220,7 @@ var defaultSysVars = []*SysVar{ {ScopeNone, "innodb_autoinc_lock_mode", "1"}, {ScopeGlobal, "slave_net_timeout", "3600"}, {ScopeGlobal, "key_buffer_size", "8388608"}, - {ScopeGlobal | ScopeSession, "foreign_key_checks", "ON"}, + {ScopeGlobal | ScopeSession, ForeignKeyChecks, "1"}, {ScopeGlobal, "host_cache_size", "279"}, {ScopeGlobal, DelayKeyWrite, "ON"}, {ScopeNone, "metadata_locks_cache_size", "1024"}, @@ -600,7 +600,7 @@ var defaultSysVars = []*SysVar{ {ScopeNone, "innodb_undo_directory", "."}, {ScopeNone, "bind_address", "*"}, {ScopeGlobal, "innodb_sync_spin_loops", "30"}, - {ScopeGlobal | ScopeSession, "sql_safe_updates", "OFF"}, + {ScopeGlobal | ScopeSession, SQLSafeUpdates, "0"}, {ScopeNone, "tmpdir", "/var/tmp/"}, {ScopeGlobal, "innodb_thread_concurrency", "0"}, {ScopeGlobal, "slave_allow_batching", "OFF"}, @@ -615,8 +615,8 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal | ScopeSession, "min_examined_row_limit", "0"}, {ScopeGlobal, "sync_frm", "ON"}, {ScopeGlobal, "innodb_online_alter_log_max_size", "134217728"}, - {ScopeSession, "warning_count", "0"}, - {ScopeSession, "error_count", "0"}, + {ScopeSession, WarningCount, "0"}, + {ScopeSession, ErrorCount, "0"}, /* TiDB specific variables */ {ScopeSession, TiDBSnapshot, ""}, {ScopeSession, TiDBImportingData, "0"}, @@ -722,6 +722,28 @@ const ( OfflineMode = "offline_mode" // InteractiveTimeout is the name for 'interactive_timeout' system variable. InteractiveTimeout = "interactive_timeout" + // FlushTime is the name for 'flush_time' system variable. + FlushTime = "flush_time" + // PseudoSlaveMode is the name for 'pseudo_slave_mode' system variable. + PseudoSlaveMode = "pseudo_slave_mode" + // LowPriorityUpdates is the name for 'low_priority_updates' system variable. + LowPriorityUpdates = "low_priority_updates" + // SessionTrackGtids is the name for 'session_track_gtids' system variable. + SessionTrackGtids = "session_track_gtids" + // OldPasswords is the name for 'old_passwords' system variable. + OldPasswords = "old_passwords" + // MaxConnections is the name for 'max_connections' system variable. + MaxConnections = "max_connections" + // SkipNameResolve is the name for 'skip_name_resolve' system variable. + SkipNameResolve = "skip_name_resolve" + // ForeignKeyChecks is the name for 'foreign_key_checks' system variable. + ForeignKeyChecks = "foreign_key_checks" + // SQLSafeUpdates is the name for 'sql_safe_updates' system variable. + SQLSafeUpdates = "sql_safe_updates" + // WarningCount is the name for 'warning_count' system variable. + WarningCount = "warning_count" + // ErrorCount is the name for 'error_count' system variable. + ErrorCount = "error_count" ) // GlobalVarAccessor is the interface for accessing global scope system and status variables. diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 208c1dc76a195..ae6263238f92d 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -22,6 +22,7 @@ import ( "time" "github.com/juju/errors" + "github.com/pingcap/sessionctx/variable" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/types" @@ -167,6 +168,12 @@ func ValidateGetSystemVar(name string, isGlobal bool) error { // ValidateSetSystemVar checks if system variable satisfies specific restriction. func ValidateSetSystemVar(name string, value string) (string, error, error) { + if strings.EqualFold(value, "DEFAULT") { + if val := variable.GetSysVar(name); val != nil { + return val.Value, nil, nil + } + return value, nil, nil + } switch name { case DefaultWeekFormat: val, err := strconv.Atoi(value) @@ -179,6 +186,23 @@ func ValidateSetSystemVar(name string, value string) (string, error, error) { if val > 7 { return "7", ErrTruncatedWrongValue.GenByArgs(name, value), nil } + case DelayKeyWrite: + if strings.EqualFold(value, "ON") || value == "1" { + return "ON", nil, nil + } else if strings.EqualFold(value, "OFF") || value == "0" { + return "OFF", nil, nil + } else if strings.EqualFold(value, "ALL") || value == "2" { + return "ALL", nil, nil + } + return value, nil, ErrWrongValueForVar.GenByArgs(name, value) + case FlushTime: + val, err := strconv.Atoi(value) + if err != nil { + return value, nil, ErrWrongTypeForVar.GenByArgs(name) + } + if val < 0 { + return "0", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } case GroupConcatMaxLen: val, err := strconv.ParseUint(value, 10, 64) if err != nil { @@ -190,16 +214,24 @@ func ValidateSetSystemVar(name string, value string) (string, error, error) { if val > 18446744073709551615 { return "18446744073709551615", ErrTruncatedWrongValue.GenByArgs(name, value), nil } - case MaxUserConnections: - val, err := strconv.ParseUint(value, 10, 64) + case InteractiveTimeout: + val, err := strconv.Atoi(value) if err != nil { return value, nil, ErrWrongTypeForVar.GenByArgs(name) } - if val < 0 { - return "0", ErrTruncatedWrongValue.GenByArgs(name, value), nil + if val < 1 { + return "1", ErrTruncatedWrongValue.GenByArgs(name, value), nil } - if val > 4294967295 { - return "4294967295", ErrTruncatedWrongValue.GenByArgs(name, value), nil + case MaxConnections: + val, err := strconv.Atoi(value) + if err != nil { + return value, nil, ErrWrongTypeForVar.GenByArgs(name) + } + if val < 1 { + return "1", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } + if val > 100000 { + return "100000", ErrTruncatedWrongValue.GenByArgs(name, value), nil } case MaxSortLength: val, err := strconv.ParseInt(value, 10, 64) @@ -223,24 +255,41 @@ func ValidateSetSystemVar(name string, value string) (string, error, error) { if val > 255 { return "255", ErrTruncatedWrongValue.GenByArgs(name, value), nil } - case InteractiveTimeout: + case OldPasswords: val, err := strconv.Atoi(value) if err != nil { return value, nil, ErrWrongTypeForVar.GenByArgs(name) } - if val < 1 { - return "1", ErrTruncatedWrongValue.GenByArgs(name, value), nil + if val < 0 { + return "0", ErrTruncatedWrongValue.GenByArgs(name, value), nil } - case DelayKeyWrite: - if strings.EqualFold(value, "ON") || value == "1" { - return "ON", nil, nil - } else if strings.EqualFold(value, "OFF") || value == "0" { + if val > 2 { + return "2", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } + case MaxUserConnections: + val, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return value, nil, ErrWrongTypeForVar.GenByArgs(name) + } + if val < 0 { + return "0", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } + if val > 4294967295 { + return "4294967295", ErrTruncatedWrongValue.GenByArgs(name, value), nil + } + case SessionTrackGtids: + if strings.EqualFold(value, "OFF") || value == "0" { return "OFF", nil, nil - } else if strings.EqualFold(value, "ALL") || value == "2" { - return "ALL", nil, nil + } else if strings.EqualFold(value, "OWN_GTID") || value == "1" { + return "OWN_GTID", nil, nil + } else if strings.EqualFold(value, "ALL_GTIDS") || value == "2" { + return "ALL_GTIDS", nil, nil } return value, nil, ErrWrongValueForVar.GenByArgs(name, value) - case GeneralLog, AvoidTemporalUpgrade, BigTables, CheckProxyUsers, CoreFile, EndMakersInJSON, SQLLogBin, OfflineMode: + case WarningCount, ErrorCount: + return value, nil, ErrReadOnly.GenByArgs(name) + case GeneralLog, AvoidTemporalUpgrade, BigTables, CheckProxyUsers, CoreFile, EndMakersInJSON, SQLLogBin, OfflineMode, + PseudoSlaveMode, LowPriorityUpdates, SkipNameResolve, ForeignKeyChecks, SQLSafeUpdates: if strings.EqualFold(value, "ON") || value == "1" { return "1", nil, nil } else if strings.EqualFold(value, "OFF") || value == "0" { From 71ff2ac7af9c41a63c38e78e5259d2bf507ad6ef Mon Sep 17 00:00:00 2001 From: spongedc Date: Mon, 23 Jul 2018 11:23:11 +0800 Subject: [PATCH 08/11] Fix build --- sessionctx/variable/varsutil.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index ae6263238f92d..7178155bcf6d6 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -22,7 +22,6 @@ import ( "time" "github.com/juju/errors" - "github.com/pingcap/sessionctx/variable" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/types" @@ -169,7 +168,7 @@ func ValidateGetSystemVar(name string, isGlobal bool) error { // ValidateSetSystemVar checks if system variable satisfies specific restriction. func ValidateSetSystemVar(name string, value string) (string, error, error) { if strings.EqualFold(value, "DEFAULT") { - if val := variable.GetSysVar(name); val != nil { + if val := GetSysVar(name); val != nil { return val.Value, nil, nil } return value, nil, nil From 3273f75627db1ecd49492f0a7ace70d4bbc0af2a Mon Sep 17 00:00:00 2001 From: spongedc Date: Mon, 23 Jul 2018 20:46:41 +0800 Subject: [PATCH 09/11] refine ValidateSetSystemVar --- session/session.go | 7 +- sessionctx/variable/varsutil.go | 120 ++++++++++++++++++-------------- 2 files changed, 69 insertions(+), 58 deletions(-) diff --git a/session/session.go b/session/session.go index ff3ec7be54e2d..311953b787f87 100644 --- a/session/session.go +++ b/session/session.go @@ -697,14 +697,11 @@ func (s *session) SetGlobalSysVar(name, value string) error { } } var sVal string - var err, warn error - sVal, warn, err = variable.ValidateSetSystemVar(name, value) + var err error + sVal, err = variable.ValidateSetSystemVar(s.sessionVars, name, value) if err != nil { return errors.Trace(err) } - if warn != nil { - s.sessionVars.StmtCtx.AppendWarning(warn) - } sql := fmt.Sprintf(`REPLACE %s.%s VALUES ('%s', '%s');`, mysql.SystemDB, mysql.GlobalVariablesTable, strings.ToLower(name), sVal) _, _, err = s.ExecRestrictedSQL(s, sql) diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 7178155bcf6d6..5c14318613a34 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -131,18 +131,15 @@ func SetSessionSystemVar(vars *SessionVars, name string, value types.Datum) erro return vars.deleteSystemVar(name) } var sVal string - var err, warn error + var err error sVal, err = value.ToString() if err != nil { return errors.Trace(err) } - sVal, warn, err = ValidateSetSystemVar(name, sVal) + sVal, err = ValidateSetSystemVar(vars, name, sVal) if err != nil { return errors.Trace(err) } - if warn != nil { - vars.StmtCtx.AppendWarning(warn) - } return vars.SetSystemVar(name, sVal) } @@ -166,143 +163,160 @@ func ValidateGetSystemVar(name string, isGlobal bool) error { } // ValidateSetSystemVar checks if system variable satisfies specific restriction. -func ValidateSetSystemVar(name string, value string) (string, error, error) { +func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, error) { if strings.EqualFold(value, "DEFAULT") { if val := GetSysVar(name); val != nil { - return val.Value, nil, nil + return val.Value, nil } - return value, nil, nil + // should never happen + panic(fmt.Sprintf("Error happened when ValidateSetSystemVar. Invalid system variable: %s", name)) } switch name { case DefaultWeekFormat: val, err := strconv.Atoi(value) if err != nil { - return value, nil, ErrWrongTypeForVar.GenByArgs(name) + return value, ErrWrongTypeForVar.GenByArgs(name) } if val < 0 { - return "0", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "0", nil } if val > 7 { - return "7", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "7", nil } case DelayKeyWrite: if strings.EqualFold(value, "ON") || value == "1" { - return "ON", nil, nil + return "ON", nil } else if strings.EqualFold(value, "OFF") || value == "0" { - return "OFF", nil, nil + return "OFF", nil } else if strings.EqualFold(value, "ALL") || value == "2" { - return "ALL", nil, nil + return "ALL", nil } - return value, nil, ErrWrongValueForVar.GenByArgs(name, value) + return value, ErrWrongValueForVar.GenByArgs(name, value) case FlushTime: val, err := strconv.Atoi(value) if err != nil { - return value, nil, ErrWrongTypeForVar.GenByArgs(name) + return value, ErrWrongTypeForVar.GenByArgs(name) } if val < 0 { - return "0", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "0", nil } case GroupConcatMaxLen: val, err := strconv.ParseUint(value, 10, 64) if err != nil { - return value, nil, ErrWrongTypeForVar.GenByArgs(name) + return value, ErrWrongTypeForVar.GenByArgs(name) } if val < 4 { - return "4", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "4", nil } if val > 18446744073709551615 { - return "18446744073709551615", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "18446744073709551615", nil } case InteractiveTimeout: val, err := strconv.Atoi(value) if err != nil { - return value, nil, ErrWrongTypeForVar.GenByArgs(name) + return value, ErrWrongTypeForVar.GenByArgs(name) } if val < 1 { - return "1", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "1", nil } case MaxConnections: val, err := strconv.Atoi(value) if err != nil { - return value, nil, ErrWrongTypeForVar.GenByArgs(name) + return value, ErrWrongTypeForVar.GenByArgs(name) } if val < 1 { - return "1", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "1", nil } if val > 100000 { - return "100000", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "100000", nil } case MaxSortLength: val, err := strconv.ParseInt(value, 10, 64) if err != nil { - return value, nil, ErrWrongTypeForVar.GenByArgs(name) + return value, ErrWrongTypeForVar.GenByArgs(name) } if val < 4 { - return "4", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "4", nil } if val > 8388608 { - return "8388608", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "8388608", nil } case MaxSpRecursionDepth: val, err := strconv.ParseInt(value, 10, 64) if err != nil { - return value, nil, ErrWrongTypeForVar.GenByArgs(name) + return value, ErrWrongTypeForVar.GenByArgs(name) } if val < 0 { - return "0", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "0", nil } if val > 255 { - return "255", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "255", nil } case OldPasswords: val, err := strconv.Atoi(value) if err != nil { - return value, nil, ErrWrongTypeForVar.GenByArgs(name) + return value, ErrWrongTypeForVar.GenByArgs(name) } if val < 0 { - return "0", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "0", nil } if val > 2 { - return "2", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "2", nil } case MaxUserConnections: val, err := strconv.ParseUint(value, 10, 64) if err != nil { - return value, nil, ErrWrongTypeForVar.GenByArgs(name) + return value, ErrWrongTypeForVar.GenByArgs(name) } if val < 0 { - return "0", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "0", nil } if val > 4294967295 { - return "4294967295", ErrTruncatedWrongValue.GenByArgs(name, value), nil + vars.StmtCtx.AppendWarning(ErrTruncatedWrongValue.GenByArgs(name, value)) + return "4294967295", nil } case SessionTrackGtids: if strings.EqualFold(value, "OFF") || value == "0" { - return "OFF", nil, nil + return "OFF", nil } else if strings.EqualFold(value, "OWN_GTID") || value == "1" { - return "OWN_GTID", nil, nil + return "OWN_GTID", nil } else if strings.EqualFold(value, "ALL_GTIDS") || value == "2" { - return "ALL_GTIDS", nil, nil + return "ALL_GTIDS", nil } - return value, nil, ErrWrongValueForVar.GenByArgs(name, value) + return value, ErrWrongValueForVar.GenByArgs(name, value) case WarningCount, ErrorCount: - return value, nil, ErrReadOnly.GenByArgs(name) + return value, ErrReadOnly.GenByArgs(name) case GeneralLog, AvoidTemporalUpgrade, BigTables, CheckProxyUsers, CoreFile, EndMakersInJSON, SQLLogBin, OfflineMode, PseudoSlaveMode, LowPriorityUpdates, SkipNameResolve, ForeignKeyChecks, SQLSafeUpdates: if strings.EqualFold(value, "ON") || value == "1" { - return "1", nil, nil + return "1", nil } else if strings.EqualFold(value, "OFF") || value == "0" { - return "0", nil, nil + return "0", nil } - return value, nil, ErrWrongValueForVar.GenByArgs(name, value) + return value, ErrWrongValueForVar.GenByArgs(name, value) case AutocommitVar, TiDBImportingData, TiDBSkipUTF8Check, TiDBOptAggPushDown, TiDBOptInSubqUnFolding, TiDBEnableTablePartition, TiDBBatchInsert, TiDBDisableTxnAutoRetry, TiDBEnableStreaming, TiDBBatchDelete: if strings.EqualFold(value, "ON") || value == "1" || strings.EqualFold(value, "OFF") || value == "0" { - return value, nil, nil + return value, nil } - return value, nil, ErrWrongValueForVar.GenByArgs(name, value) + return value, ErrWrongValueForVar.GenByArgs(name, value) case TiDBIndexLookupConcurrency, TiDBIndexLookupJoinConcurrency, TiDBIndexJoinBatchSize, TiDBIndexLookupSize, TiDBHashJoinConcurrency, @@ -315,12 +329,12 @@ func ValidateSetSystemVar(name string, value string) (string, error, error) { TiDBGeneralLog: v, err := strconv.Atoi(value) if err != nil { - return value, nil, ErrWrongTypeForVar.GenByArgs(name) + return value, ErrWrongTypeForVar.GenByArgs(name) } if v <= 0 { - return value, nil, ErrWrongValueForVar.GenByArgs(name, value) + return value, ErrWrongValueForVar.GenByArgs(name, value) } - return value, nil, nil + return value, nil case TiDBProjectionConcurrency, TIDBMemQuotaQuery, TIDBMemQuotaHashJoin, @@ -333,11 +347,11 @@ func ValidateSetSystemVar(name string, value string) (string, error, error) { TiDBRetryLimit: _, err := strconv.ParseInt(value, 10, 64) if err != nil { - return value, nil, ErrWrongValueForVar.GenByArgs(name) + return value, ErrWrongValueForVar.GenByArgs(name) } - return value, nil, nil + return value, nil } - return value, nil, nil + return value, nil } // TiDBOptOn could be used for all tidb session variable options, we use "ON"/1 to turn on those options. From 57961c6bd02a7428119e11c296c594e2cbb3ac1e Mon Sep 17 00:00:00 2001 From: spongedc Date: Tue, 24 Jul 2018 11:12:32 +0800 Subject: [PATCH 10/11] return UnknownSystemVar error if sysvar not found --- sessionctx/variable/varsutil.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 5c14318613a34..f15fe20952500 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -168,8 +168,7 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, if val := GetSysVar(name); val != nil { return val.Value, nil } - // should never happen - panic(fmt.Sprintf("Error happened when ValidateSetSystemVar. Invalid system variable: %s", name)) + return value, UnknownSystemVar.GenByArgs(name) } switch name { case DefaultWeekFormat: From ab814bd7a5d3752f54874ef97826a8ec58968099 Mon Sep 17 00:00:00 2001 From: spongedc Date: Tue, 24 Jul 2018 16:15:40 +0800 Subject: [PATCH 11/11] Add check for time_zone --- executor/set_test.go | 5 +++++ sessionctx/variable/varsutil.go | 7 ++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/executor/set_test.go b/executor/set_test.go index ab7d12e66d8f4..bce2129bc7f6b 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -306,4 +306,9 @@ func (s *testSuite) TestValidateSetVar(c *C) { _, err = tk.Exec("set @@warning_count = 0") c.Assert(terror.ErrorEqual(err, variable.ErrReadOnly), IsTrue) + + tk.MustExec("set time_zone='SySTeM'") + result = tk.MustQuery("select @@time_zone;") + result.Check(testkit.Rows("SYSTEM")) + } diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index f15fe20952500..3ad457c76838a 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -298,6 +298,11 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, return "ALL_GTIDS", nil } return value, ErrWrongValueForVar.GenByArgs(name, value) + case TimeZone: + if strings.EqualFold(value, "SYSTEM") { + return "SYSTEM", nil + } + return value, nil case WarningCount, ErrorCount: return value, ErrReadOnly.GenByArgs(name) case GeneralLog, AvoidTemporalUpgrade, BigTables, CheckProxyUsers, CoreFile, EndMakersInJSON, SQLLogBin, OfflineMode, @@ -375,7 +380,7 @@ func tidbOptInt64(opt string, defaultVal int64) int64 { } func parseTimeZone(s string) (*time.Location, error) { - if s == "SYSTEM" { + if strings.EqualFold(s, "SYSTEM") { // TODO: Support global time_zone variable, it should be set to global time_zone value. return time.Local, nil }