Skip to content

Commit

Permalink
remove unnecessary judgments and fix misspellings (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
demoManito authored Nov 11, 2022
1 parent ae06135 commit 4a51687
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 42 deletions.
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.14

require (
github.com/go-sql-driver/mysql v1.6.0
github.com/jinzhu/now v1.1.5 // indirect
gorm.io/gorm v1.23.8
)

require github.com/jinzhu/now v1.1.5 // indirect
6 changes: 2 additions & 4 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,8 @@ func groupByIndexName(indexList []*Index) map[string][]*Index {
}

func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (string, string) {
if strings.Contains(table, ".") {
if tables := strings.Split(table, `.`); len(tables) == 2 {
return tables[0], tables[1]
}
if tables := strings.Split(table, `.`); len(tables) == 2 {
return tables[0], tables[1]
}
m.DB = m.DB.Table(table)
return m.CurrentDatabase(), table
Expand Down
77 changes: 40 additions & 37 deletions mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/go-sql-driver/mysql"

"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
Expand Down Expand Up @@ -76,22 +77,20 @@ func (dialector Dialector) NowFunc(n int) func() time.Time {
}

func (dialector Dialector) Apply(config *gorm.Config) error {
if config.NowFunc == nil {
if dialector.DefaultDatetimePrecision == nil {
dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
}

// while maintaining the readability of the code, separate the business logic from
// the general part and leave it to the function to do it here.
config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision)
if config.NowFunc != nil {
return nil
}

if dialector.DefaultDatetimePrecision == nil {
dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
}
// while maintaining the readability of the code, separate the business logic from
// the general part and leave it to the function to do it here.
config.NowFunc = dialector.NowFunc(*dialector.DefaultDatetimePrecision)
return nil
}

func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
ctx := context.Background()

if dialector.DriverName == "" {
dialector.DriverName = "mysql"
}
Expand All @@ -111,7 +110,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {

withReturning := false
if !dialector.Config.SkipInitializeWithVersion {
err = db.ConnPool.QueryRowContext(ctx, "SELECT VERSION()").Scan(&dialector.ServerVersion)
err = db.ConnPool.QueryRowContext(context.Background(), "SELECT VERSION()").Scan(&dialector.ServerVersion)
if err != nil {
return err
}
Expand All @@ -121,9 +120,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
dialector.Config.DontSupportRenameColumn = true
dialector.Config.DontSupportForShareClause = true
dialector.Config.DontSupportNullAsDefaultValue = true
if checkVersion(dialector.ServerVersion, "10.5") {
withReturning = true
}
withReturning = checkVersion(dialector.ServerVersion, "10.5")
} else if strings.HasPrefix(dialector.ServerVersion, "5.6.") {
dialector.Config.DontSupportRenameIndex = true
dialector.Config.DontSupportRenameColumn = true
Expand Down Expand Up @@ -176,7 +173,7 @@ const (
ClauseOnConflict = "ON CONFLICT"
// ClauseValues for clause.ClauseBuilder VALUES key
ClauseValues = "VALUES"
// ClauseValues for clause.ClauseBuilder FOR key
// ClauseFor for clause.ClauseBuilder FOR key
ClauseFor = "FOR"
)

Expand Down Expand Up @@ -393,11 +390,11 @@ func (dialector Dialector) getSchemaStringType(field *schema.Field) string {
}

func (dialector Dialector) getSchemaTimeType(field *schema.Field) string {
precision := ""
if !dialector.DisableDatetimePrecision && field.Precision == 0 {
field.Precision = *dialector.DefaultDatetimePrecision
}

var precision string
if field.Precision > 0 {
precision = fmt.Sprintf("(%d)", field.Precision)
}
Expand All @@ -421,27 +418,31 @@ func (dialector Dialector) getSchemaBytesType(field *schema.Field) string {
}

func (dialector Dialector) getSchemaIntAndUnitType(field *schema.Field) string {
sqlType := "bigint"
constraint := func(sqlType string) string {
if field.DataType == schema.Uint {
sqlType += " unsigned"
}
if field.NotNull {
sqlType += " NOT NULL"
}
if field.AutoIncrement {
sqlType += " AUTO_INCREMENT"
}
return sqlType
}

switch {
case field.Size <= 8:
sqlType = "tinyint"
return constraint("tinyint")
case field.Size <= 16:
sqlType = "smallint"
return constraint("smallint")
case field.Size <= 24:
sqlType = "mediumint"
return constraint("mediumint")
case field.Size <= 32:
sqlType = "int"
}

if field.DataType == schema.Uint {
sqlType += " unsigned"
}

if field.AutoIncrement {
sqlType += " AUTO_INCREMENT"
return constraint("int")
default:
return constraint("bigint")
}

return sqlType
}

func (dialector Dialector) getSchemaCustomType(field *schema.Field) string {
Expand All @@ -462,23 +463,25 @@ func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
return tx.Exec("ROLLBACK TO SAVEPOINT " + name).Error
}

var versionTrimerRegexp = regexp.MustCompile(`^(\d+).*$`)

// checkVersion newer or equal returns true, old returns false
func checkVersion(newVersion, oldVersion string) bool {
if newVersion == oldVersion {
return true
}

newVersions := strings.Split(newVersion, ".")
oldVersions := strings.Split(oldVersion, ".")
var (
versionTrimmerRegexp = regexp.MustCompile(`^(\d+).*$`)

newVersions = strings.Split(newVersion, ".")
oldVersions = strings.Split(oldVersion, ".")
)
for idx, nv := range newVersions {
if len(oldVersions) <= idx {
return true
}

nvi, _ := strconv.Atoi(versionTrimerRegexp.ReplaceAllString(nv, "$1"))
ovi, _ := strconv.Atoi(versionTrimerRegexp.ReplaceAllString(oldVersions[idx], "$1"))
nvi, _ := strconv.Atoi(versionTrimmerRegexp.ReplaceAllString(nv, "$1"))
ovi, _ := strconv.Atoi(versionTrimmerRegexp.ReplaceAllString(oldVersions[idx], "$1"))
if nvi == ovi {
continue
}
Expand Down

0 comments on commit 4a51687

Please sign in to comment.