Skip to content

Commit

Permalink
feat: deal with the unique field migrator (#105)
Browse files Browse the repository at this point in the history
* feat: deal with the unique field migrator

* feat: rebuild the AlterColumn function
  • Loading branch information
Cheese authored Feb 8, 2023
1 parent e1a37d1 commit 70d48fe
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
42 changes: 39 additions & 3 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mysql
import (
"database/sql"
"fmt"
"strconv"
"strings"

"gorm.io/gorm"
Expand Down Expand Up @@ -49,15 +50,50 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) clause.Expr {
func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
fullDataType := m.FullDataTypeOf(field)
if m.Dialector.DontSupportRenameColumnUnique {
fullDataType.SQL = strings.Replace(fullDataType.SQL, " UNIQUE ", " ", 1)
}

return m.DB.Exec(
"ALTER TABLE ? MODIFY COLUMN ? ?",
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.FullDataTypeOf(field),
clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, fullDataType,
).Error
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}

func (m Migrator) TiDBVersion() (isTiDB bool, major, minor, patch int, err error) {
// TiDB version string looks like:
// "5.7.25-TiDB-v6.5.0" or "5.7.25-TiDB-v6.4.0-serverless"
tidbVersionArray := strings.Split(m.Dialector.ServerVersion, "-")
if len(tidbVersionArray) < 3 || tidbVersionArray[1] != "TiDB" {
// It isn't TiDB
return
}

rawVersion := strings.TrimPrefix(tidbVersionArray[2], "v")
realVersionArray := strings.Split(rawVersion, ".")
if major, err = strconv.Atoi(realVersionArray[0]); err != nil {
err = fmt.Errorf("failed to parse the version of TiDB, the major version is: %s", realVersionArray[0])
return
}

if minor, err = strconv.Atoi(realVersionArray[1]); err != nil {
err = fmt.Errorf("failed to parse the version of TiDB, the minor version is: %s", realVersionArray[0])
return
}

if patch, err = strconv.Atoi(realVersionArray[2]); err != nil {
err = fmt.Errorf("failed to parse the version of TiDB, the patch version is: %s", realVersionArray[0])
return
}

isTiDB = true
return
}

func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if !m.Dialector.DontSupportRenameColumn {
Expand Down Expand Up @@ -173,11 +209,11 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
}

rawColumnTypes, err := rows.ColumnTypes()

if err != nil {
return err
}

if err := rows.Close(); err != nil {
return err
}
Expand Down
5 changes: 5 additions & 0 deletions mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type Config struct {
DontSupportRenameColumn bool
DontSupportForShareClause bool
DontSupportNullAsDefaultValue bool
DontSupportRenameColumnUnique bool
}

type Dialector struct {
Expand Down Expand Up @@ -138,6 +139,10 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
dialector.Config.DontSupportRenameColumn = true
dialector.Config.DontSupportForShareClause = true
}

if strings.Contains(dialector.ServerVersion, "TiDB") {
dialector.Config.DontSupportRenameColumnUnique = true
}
}

// register callbacks
Expand Down

0 comments on commit 70d48fe

Please sign in to comment.