Skip to content

Commit

Permalink
feature: 添加参数disable_types可灵活配置不允许的数据类型(#400,#409)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanchuanchuan committed Jan 10, 2022
1 parent ddfe5a1 commit bf892ad
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 50 deletions.
4 changes: 3 additions & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ type Inc struct {
CheckReadOnly bool `toml:"check_read_only" json:"check_read_only"`

// 连接服务器的默认字符集,默认值为utf8mb4
DefaultCharset string `toml:"default_charset" json:"default_charset"`
DefaultCharset string `toml:"default_charset" json:"default_charset"`
// 禁用数据库类型,多个时以逗号分隔.该参数优先级低于enable_blob_type/enable_enum_set_bit等参数
DisableTypes string `toml:"disable_types" json:"disable_types"`
EnableAlterDatabase bool `toml:"enable_alter_database" json:"enable_alter_database"`
// 允许执行任意语法类型.该设置有安全要求,仅支持配置文件方式设置
EnableAnyStatement bool `toml:"enable_any_statement" json:"enable_any_statement"`
Expand Down
1 change: 1 addition & 0 deletions session/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ func (s *session) init() {

s.backupDBCacheList = make(map[string]bool)
s.backupTableCacheList = make(map[string]bool)
s.disableTypes = make(map[string]struct{})

s.inc = config.GetGlobalConfig().Inc
s.osc = config.GetGlobalConfig().Osc
Expand Down
6 changes: 2 additions & 4 deletions session/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ var ErrorsDefault = map[ErrorCode]string{
ER_ID_IS_UPER: "Identifier is not allowed to been upper-case.",
ErrUnknownCharset: "Unknown charset: '%s'.",
ER_UNKNOWN_COLLATION: "Unknown collation: '%s'.",
ER_INVALID_DATA_TYPE: "Not supported data type on field: '%s'.",
ER_INVALID_DATA_TYPE: "Not supported data type on field: '%s'(%s).",
ER_NOT_ALLOWED_NULLABLE: "Column '%s' in table '%s' is not allowed to been nullable.",
ER_DUP_FIELDNAME: "Duplicate column name '%s'.",
ER_WRONG_COLUMN_NAME: "Incorrect column name '%s'.",
Expand Down Expand Up @@ -434,7 +434,7 @@ var ErrorsChinese = map[ErrorCode]string{
ER_ID_IS_UPER: "标识符不允许大写.",
ErrUnknownCharset: "未知的字符集: '%s'.",
ER_UNKNOWN_COLLATION: "未知的排序规则: '%s'.",
ER_INVALID_DATA_TYPE: "列 '%s' 数据类型不支持.",
ER_INVALID_DATA_TYPE: "列 '%s' 数据类型(%s)不支持.",
ER_NOT_ALLOWED_NULLABLE: "列 '%s' 不允许为null(表 '%s').",
ER_DUP_FIELDNAME: "重复的列名: '%s'.",
ER_WRONG_COLUMN_NAME: "不正确的列名: '%s'.",
Expand Down Expand Up @@ -1200,10 +1200,8 @@ func TestCheckAuditSetting(cnf *config.Config) {

if !cnf.Inc.EnableEnumSetBit {
cnf.IncLevel.ER_USE_ENUM = int8(GetErrorLevel(ER_USE_ENUM))
cnf.IncLevel.ER_INVALID_DATA_TYPE = int8(GetErrorLevel(ER_INVALID_DATA_TYPE))
} else {
cnf.IncLevel.ER_USE_ENUM = 0
cnf.IncLevel.ER_INVALID_DATA_TYPE = 0
}

if cnf.Inc.CheckIndexPrefix {
Expand Down
4 changes: 4 additions & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ type Session interface {

// 用以测试
GetAlterTablePostPart(sql string, isPtOSC bool) string
InitDisableTypes()

LoadOptions(opt SourceOptions) error
Audit(ctx context.Context, sql string) ([]Record, error)
Expand Down Expand Up @@ -263,6 +264,9 @@ type session struct {

// masking 语法树解析功能
maskingFields []MaskingFieldInfo

// 统一处理禁用的数据类型
disableTypes map[string]struct{}
}

func (s *session) getMembufCap() int {
Expand Down
73 changes: 57 additions & 16 deletions session/session_inception.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ func (s *session) executeInc(ctx context.Context, sql string) (recordSets []sqle
s.sqlFingerprint = make(map[string]*Record, 64)
}

s.initDisableTypes()
continue
case *ast.InceptionCommitStmt:

Expand Down Expand Up @@ -645,6 +646,8 @@ func (s *session) processCommand(ctx context.Context, stmtNode ast.StmtNode,
_, err := s.executeInceptionSet(node, currentSql)
if err != nil {
s.appendErrorMessage(err.Error())
} else {
s.initDisableTypes()
}
} else {
return s.executeInceptionSet(node, currentSql)
Expand Down Expand Up @@ -3971,18 +3974,19 @@ func (s *session) checkVarcharLength(t *TableInfo, colDef *ast.ColumnDef) {
}
}
}

func (s *session) mysqlCheckField(t *TableInfo, field *ast.ColumnDef, alterTableType ast.AlterTableType) {
log.Debug("mysqlCheckField")

tableName := t.Name
if !s.inc.EnableEnumSetBit && (field.Tp.Tp == mysql.TypeEnum ||
field.Tp.Tp == mysql.TypeSet ||
field.Tp.Tp == mysql.TypeBit) {
s.appendErrorNo(ER_INVALID_DATA_TYPE, field.Name.Name)
}

if field.Tp.Tp == mysql.TypeTimestamp && !s.inc.EnableTimeStampType {
s.appendErrorNo(ER_INVALID_DATA_TYPE, field.Name.Name)
if len(s.disableTypes) > 0 {
fieldType := types.TypeToStr(field.Tp.Tp, field.Tp.Charset)
for typeStr := range s.disableTypes {
if typeStr == fieldType {
s.appendErrorNo(ER_INVALID_DATA_TYPE, field.Name.Name, typeStr)
break
}
}
}

if field.Tp.Tp == mysql.TypeString && (s.inc.MaxCharLength > 0 && field.Tp.Flen > int(s.inc.MaxCharLength)) {
Expand Down Expand Up @@ -4120,14 +4124,15 @@ func (s *session) mysqlCheckField(t *TableInfo, field *ast.ColumnDef, alterTable
}
//是否使用 text\blob\json 字段类型
//当EnableNullable=false,不强制text\blob\json使用NOT NULL
if types.IsTypeBlob(field.Tp.Tp) {
s.appendErrorNo(ER_USE_TEXT_OR_BLOB, field.Name.Name)
} else if field.Tp.Tp == mysql.TypeJSON {
s.appendErrorNo(ErrJsonTypeSupport, field.Name.Name)
} else {
if !notNullFlag && !hasGenerated {
s.appendErrorNo(ER_NOT_ALLOWED_NULLABLE, field.Name.Name, tableName)
}

// 类型限制统一由disableTypes处理
// if types.IsTypeBlob(field.Tp.Tp) {
// s.appendErrorNo(ER_USE_TEXT_OR_BLOB, field.Name.Name)
// } else if field.Tp.Tp == mysql.TypeJSON {
// s.appendErrorNo(ErrJsonTypeSupport, field.Name.Name)
// }
if !notNullFlag && !hasGenerated {
s.appendErrorNo(ER_NOT_ALLOWED_NULLABLE, field.Name.Name, tableName)
}

// 审核所有指定了charset或collate的字段
Expand Down Expand Up @@ -8228,3 +8233,39 @@ func (s *session) checkVaildWhere(expr ast.ExprNode) bool {
}
return true
}

func (s *session) initDisableTypes() {
log.Debug("initDisableTypes")
s.disableTypes = make(map[string]struct{})
if !s.inc.EnableBlobType {
s.disableTypes["tinytext"] = struct{}{}
s.disableTypes["mediumtext"] = struct{}{}
s.disableTypes["longtext"] = struct{}{}
s.disableTypes["text"] = struct{}{}
s.disableTypes["tinyblob"] = struct{}{}
s.disableTypes["mediumblob"] = struct{}{}
s.disableTypes["longblob"] = struct{}{}
s.disableTypes["blob"] = struct{}{}
}
if !s.inc.EnableTimeStampType {
s.disableTypes["timestamp"] = struct{}{}
}
if !s.inc.EnableJsonType {
s.disableTypes["json"] = struct{}{}
}
if !s.inc.EnableEnumSetBit {
s.disableTypes["enum"] = struct{}{}
s.disableTypes["set"] = struct{}{}
s.disableTypes["bit"] = struct{}{}
}
for _, typeStr := range strings.Split(s.inc.DisableTypes, ",") {
key := strings.ToLower(strings.TrimSpace(typeStr))
if key != "" {
s.disableTypes[key] = struct{}{}
}
}
}

func (s *session) InitDisableTypes() {
s.initDisableTypes()
}
78 changes: 49 additions & 29 deletions session/session_inception_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,20 +381,20 @@ func (s *testSessionIncSuite) TestCreateTable(c *C) {
// 数据类型 警告
sql = "create table t1(id int,c1 bit);"
s.testErrorCode(c, sql,
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1"))
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1", "bit"))

sql = "create table t1(id int,c1 enum('red', 'blue', 'black'));"
s.testErrorCode(c, sql,
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1"))
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1", "enum"))

sql = "create table t1(id int,c1 set('red', 'blue', 'black'));"
s.testErrorCode(c, sql,
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1"))
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1", "set"))

config.GetGlobalConfig().Inc.EnableTimeStampType = false
sql = "create table t1(id int,c1 timestamp);"
s.testErrorCode(c, sql,
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1"))
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1", "timestamp"))

config.GetGlobalConfig().Inc.EnableTimeStampType = true
sql = "create table t1(id int,c1 timestamp);"
Expand Down Expand Up @@ -460,18 +460,38 @@ func (s *testSessionIncSuite) TestCreateTable(c *C) {

// blob/text字段
config.GetGlobalConfig().Inc.EnableBlobType = false
sql = ("create table t1(id int,c1 blob, c2 text);")
sql = "create table t111(id int,c1 blob, c2 text);"
s.testErrorCode(c, sql,
session.NewErr(session.ER_USE_TEXT_OR_BLOB, "c1"),
session.NewErr(session.ER_USE_TEXT_OR_BLOB, "c2"),
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1", "blob"),
session.NewErr(session.ER_INVALID_DATA_TYPE, "c2", "text"),
)

config.GetGlobalConfig().Inc.DisableTypes = "blob,text"
sql = "create table t111(id int,c1 blob, c2 text);"
s.testErrorCode(c, sql,
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1", "blob"),
session.NewErr(session.ER_INVALID_DATA_TYPE, "c2", "text"),
)
config.GetGlobalConfig().Inc.DisableTypes = ""

config.GetGlobalConfig().Inc.EnableBlobType = true
sql = ("create table t1(id int,c1 blob not null);")
s.testErrorCode(c, sql,
session.NewErr(session.ER_TEXT_NOT_NULLABLE_ERROR, "c1", "t1"),
)

// 指定类型禁用
config.GetGlobalConfig().Inc.DisableTypes = "bit"
sql = "create table t1(id int,c1 bit default b'0');"
s.testErrorCode(c, sql,
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1", "bit"))
config.GetGlobalConfig().Inc.DisableTypes = ""

config.GetGlobalConfig().Inc.EnableEnumSetBit = false
sql = "create table t1(id int,c1 bit default b'0');"
s.testErrorCode(c, sql,
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1", "bit"))

// 检查默认值
config.GetGlobalConfig().Inc.CheckColumnDefaultValue = true
sql = "create table t1(c1 varchar(10));"
Expand Down Expand Up @@ -677,7 +697,7 @@ func (s *testSessionIncSuite) TestCreateTable(c *C) {
config.GetGlobalConfig().Inc.CheckIndexPrefix = false
sql = "create table test_error_code_3(pt text ,primary key (pt));"
s.testErrorCode(c, sql,
session.NewErr(session.ER_USE_TEXT_OR_BLOB, "pt"),
session.NewErr(session.ER_INVALID_DATA_TYPE, "pt", "text"),
session.NewErr(session.ER_TOO_LONG_KEY, "PRIMARY", indexMaxLength))

config.GetGlobalConfig().Inc.EnableBlobType = true
Expand Down Expand Up @@ -1055,15 +1075,15 @@ func (s *testSessionIncSuite) TestAlterTableAddColumn(c *C) {
// 数据类型 警告
sql = "drop table if exists t1;create table t1(id int);alter table t1 add column c2 bit;"
s.testErrorCode(c, sql,
session.NewErr(session.ER_INVALID_DATA_TYPE, "c2"))
session.NewErr(session.ER_INVALID_DATA_TYPE, "c2", "bit"))

sql = "drop table if exists t1;create table t1(id int);alter table t1 add column c2 enum('red', 'blue', 'black');"
s.testErrorCode(c, sql,
session.NewErr(session.ER_INVALID_DATA_TYPE, "c2"))
session.NewErr(session.ER_INVALID_DATA_TYPE, "c2", "enum"))

sql = "drop table if exists t1;create table t1(id int);alter table t1 add column c2 set('red', 'blue', 'black');"
s.testErrorCode(c, sql,
session.NewErr(session.ER_INVALID_DATA_TYPE, "c2"))
session.NewErr(session.ER_INVALID_DATA_TYPE, "c2", "set"))

// char列建议
config.GetGlobalConfig().Inc.MaxCharLength = 100
Expand Down Expand Up @@ -1116,8 +1136,8 @@ func (s *testSessionIncSuite) TestAlterTableAddColumn(c *C) {
config.GetGlobalConfig().Inc.EnableBlobType = false
sql = ("drop table if exists t1;create table t1(id int);alter table t1 add column c1 blob;alter table t1 add column c2 text;")
s.testManyErrors(c, sql,
session.NewErr(session.ER_USE_TEXT_OR_BLOB, "c1"),
session.NewErr(session.ER_USE_TEXT_OR_BLOB, "c2"),
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1", "blob"),
session.NewErr(session.ER_INVALID_DATA_TYPE, "c2", "text"),
)

config.GetGlobalConfig().Inc.EnableBlobType = true
Expand Down Expand Up @@ -1169,7 +1189,7 @@ func (s *testSessionIncSuite) TestAlterTableAddColumn(c *C) {
config.GetGlobalConfig().Inc.EnableJsonType = false
sql = "drop table if exists t1;create table t1 (c1 int primary key);alter table t1 add c2 json;"
s.testErrorCode(c, sql,
session.NewErr(session.ErrJsonTypeSupport, "c2"))
session.NewErr(session.ER_INVALID_DATA_TYPE, "c2", "json"))
}

sql = "drop table if exists t1;create table t1 (id int primary key);alter table t1 add column (c1 int,c2 varchar(20));"
Expand Down Expand Up @@ -1260,15 +1280,15 @@ func (s *testSessionIncSuite) TestAlterTableModifyColumn(c *C) {
// 数据类型 警告
sql = "create table t1(id bit);alter table t1 modify column id bit;"
s.testErrorCode(c, sql,
session.NewErr(session.ER_INVALID_DATA_TYPE, "id"))
session.NewErr(session.ER_INVALID_DATA_TYPE, "id", "bit"))

sql = "create table t1(id enum('red', 'blue'));alter table t1 modify column id enum('red', 'blue', 'black');"
s.testErrorCode(c, sql,
session.NewErr(session.ER_INVALID_DATA_TYPE, "id"))
session.NewErr(session.ER_INVALID_DATA_TYPE, "id", "enum"))

sql = "create table t1(id set('red'));alter table t1 modify column id set('red', 'blue', 'black');"
s.testErrorCode(c, sql,
session.NewErr(session.ER_INVALID_DATA_TYPE, "id"))
session.NewErr(session.ER_INVALID_DATA_TYPE, "id", "set"))

// char列建议
config.GetGlobalConfig().Inc.MaxCharLength = 100
Expand Down Expand Up @@ -1301,10 +1321,10 @@ func (s *testSessionIncSuite) TestAlterTableModifyColumn(c *C) {

// blob/text字段
config.GetGlobalConfig().Inc.EnableBlobType = false
sql = ("create table t1(id int,c1 varchar(10));alter table t1 modify column c1 blob;alter table t1 modify column c1 text;")
config.GetGlobalConfig().Inc.CheckColumnTypeChange = false
sql = ("create table t1(id int,c1 varchar(10));alter table t1 modify column c1 blob;")
s.testManyErrors(c, sql,
session.NewErr(session.ER_USE_TEXT_OR_BLOB, "c1"),
session.NewErr(session.ER_USE_TEXT_OR_BLOB, "c1"),
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1", "blob"),
)

config.GetGlobalConfig().Inc.EnableBlobType = true
Expand Down Expand Up @@ -2910,7 +2930,7 @@ func (s *testSessionIncSuite) TestTimestampType(c *C) {
// sql = `create table t4 (id int unsigned not null auto_increment primary key comment 'primary key', a timestamp not null default 0 comment 'a') comment 'test';`
sql = `create table t4 (id int unsigned not null auto_increment primary key comment 'primary key', a timestamp not null comment 'a') comment 'test';`
s.testErrorCode(c, sql,
session.NewErr(session.ER_INVALID_DATA_TYPE, "a"))
session.NewErr(session.ER_INVALID_DATA_TYPE, "a", "timestamp"))
config.GetGlobalConfig().Inc.EnableTimeStampType = true
}

Expand Down Expand Up @@ -3295,21 +3315,21 @@ func (s *testSessionIncSuite) TestBlobAndText(c *C) {
c3 mediumblob,
c4 longblob);`
s.testErrorCode(c, sql,
session.NewErr(session.ER_USE_TEXT_OR_BLOB, "c1"),
session.NewErr(session.ER_USE_TEXT_OR_BLOB, "c2"),
session.NewErr(session.ER_USE_TEXT_OR_BLOB, "c3"),
session.NewErr(session.ER_USE_TEXT_OR_BLOB, "c4"))
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1", "tinyblob"),
session.NewErr(session.ER_INVALID_DATA_TYPE, "c2", "blob"),
session.NewErr(session.ER_INVALID_DATA_TYPE, "c3", "mediumblob"),
session.NewErr(session.ER_INVALID_DATA_TYPE, "c4", "longblob"))

sql = `create table t2(id int primary key,
c1 tinytext ,
c2 text,
c3 mediumtext,
c4 longtext);`
s.testErrorCode(c, sql,
session.NewErr(session.ER_USE_TEXT_OR_BLOB, "c1"),
session.NewErr(session.ER_USE_TEXT_OR_BLOB, "c2"),
session.NewErr(session.ER_USE_TEXT_OR_BLOB, "c3"),
session.NewErr(session.ER_USE_TEXT_OR_BLOB, "c4"))
session.NewErr(session.ER_INVALID_DATA_TYPE, "c1", "tinytext"),
session.NewErr(session.ER_INVALID_DATA_TYPE, "c2", "text"),
session.NewErr(session.ER_INVALID_DATA_TYPE, "c3", "mediumtext"),
session.NewErr(session.ER_INVALID_DATA_TYPE, "c4", "longtext"))

}

Expand Down

0 comments on commit bf892ad

Please sign in to comment.