Skip to content

Commit

Permalink
session, sessionctx/variable: fix validation recursion bug (#30293)
Browse files Browse the repository at this point in the history
  • Loading branch information
morgo authored Dec 2, 2021
1 parent 2242f9c commit cbe5240
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
8 changes: 3 additions & 5 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1115,11 +1115,9 @@ func (s *session) GetGlobalSysVar(name string) (string, error) {
return sv.Value, nil
}
}
// It might have been written from an earlier TiDB version, so we should run the validation function
// To normalize the value to be safe for this version of TiDB. This also happens for session scoped
// variables in loadCommonGlobalVariablesIfNeeded -> SetSystemVarWithRelaxedValidation
sysVar = sv.ValidateWithRelaxedValidation(s.GetSessionVars(), sysVar, variable.ScopeGlobal)
return sysVar, nil
// It might have been written from an earlier TiDB version, so we should do type validation
// See https://github.com/pingcap/tidb/issues/30255 for why we don't do full validation.
return sv.ValidateFromType(s.GetSessionVars(), sysVar, variable.ScopeGlobal)
}

// SetGlobalSysVar implements GlobalVarAccessor.SetGlobalSysVar interface.
Expand Down
16 changes: 16 additions & 0 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2212,6 +2212,22 @@ func (s *testSchemaSerialSuite) TestLoadSchemaFailed(c *C) {
tk2.MustExec("commit")
}

func (s *testSchemaSerialSuite) TestValidationRecursion(c *C) {
// We have to expect that validation functions will call GlobalVarsAccessor.GetGlobalSysVar().
// This tests for a regression where GetGlobalSysVar() can not safely call the validation
// function because it might cause infinite recursion.
// See: https://github.com/pingcap/tidb/issues/30255
sv := variable.SysVar{Scope: variable.ScopeGlobal, Name: "mynewsysvar", Value: "test", Validation: func(vars *variable.SessionVars, normalizedValue string, originalValue string, scope variable.ScopeFlag) (string, error) {
return vars.GlobalVarsAccessor.GetGlobalSysVar("mynewsysvar")
}}
variable.RegisterSysVar(&sv)

tk := testkit.NewTestKitWithInit(c, s.store)
val, err := sv.Validate(tk.Se.GetSessionVars(), "test2", variable.ScopeGlobal)
c.Assert(err, IsNil)
c.Assert(val, Equals, "test")
}

func (s *testSchemaSerialSuite) TestSchemaCheckerSQL(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
tk1 := testkit.NewTestKitWithInit(c, s.store)
Expand Down
10 changes: 5 additions & 5 deletions sessionctx/variable/sysvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ func (sv *SysVar) Validate(vars *SessionVars, value string, scope ScopeFlag) (st
}
// Normalize the value and apply validation based on type.
// i.e. TypeBool converts 1/on/ON to ON.
normalizedValue, err := sv.validateFromType(vars, value, scope)
normalizedValue, err := sv.ValidateFromType(vars, value, scope)
if err != nil {
return normalizedValue, err
}
Expand All @@ -276,8 +276,8 @@ func (sv *SysVar) Validate(vars *SessionVars, value string, scope ScopeFlag) (st
return normalizedValue, nil
}

// validateFromType provides automatic validation based on the SysVar's type
func (sv *SysVar) validateFromType(vars *SessionVars, value string, scope ScopeFlag) (string, error) {
// ValidateFromType provides automatic validation based on the SysVar's type
func (sv *SysVar) ValidateFromType(vars *SessionVars, value string, scope ScopeFlag) (string, error) {
// The string "DEFAULT" is a special keyword in MySQL, which restores
// the compiled sysvar value. In which case we can skip further validation.
if strings.EqualFold(value, "DEFAULT") {
Expand Down Expand Up @@ -329,9 +329,9 @@ func (sv *SysVar) validateScope(scope ScopeFlag) error {
func (sv *SysVar) ValidateWithRelaxedValidation(vars *SessionVars, value string, scope ScopeFlag) string {
warns := vars.StmtCtx.GetWarnings()
defer func() {
vars.StmtCtx.SetWarnings(warns) // RelaxedVaidation = trim warnings too.
vars.StmtCtx.SetWarnings(warns) // RelaxedValidation = trim warnings too.
}()
normalizedValue, err := sv.validateFromType(vars, value, scope)
normalizedValue, err := sv.ValidateFromType(vars, value, scope)
if err != nil {
return normalizedValue
}
Expand Down

0 comments on commit cbe5240

Please sign in to comment.