diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 65a8abb..7e5fc24 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -25,6 +25,12 @@ jobs: MSSQL_PASSWORD: LoremIpsum86 ports: - 9930:1433 + options: >- + --health-cmd="/opt/mssql-tools18/bin/sqlcmd -C -S localhost -U sa -P LoremIpsum86 -l 30 -Q \"SELECT 1\" || exit 1" + --health-start-period 10s + --health-interval 10s + --health-timeout 5s + --health-retries 10 steps: - name: Set up Go 1.x diff --git a/migrator.go b/migrator.go index 3f80515..7ed0fd6 100644 --- a/migrator.go +++ b/migrator.go @@ -36,6 +36,58 @@ func (m Migrator) GetTables() (tableList []string, err error) { return tableList, m.DB.Raw("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_CATALOG = ?", m.CurrentDatabase()).Scan(&tableList).Error } +func (m Migrator) CreateTable(values ...interface{}) (err error) { + if err = m.Migrator.CreateTable(values...); err != nil { + return + } + for _, value := range m.ReorderModels(values, false) { + if err = m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { + if stmt.Schema == nil { + return + } + for _, fieldName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[fieldName] + if field.Comment == "" { + continue + } + if err = m.setColumnComment(stmt, field, true); err != nil { + return + } + } + return + }); err != nil { + return + } + } + return +} + +func (m Migrator) setColumnComment(stmt *gorm.Statement, field *schema.Field, add bool) error { + schemaName := m.getTableSchemaName(stmt.Schema) + // add field comment + if add { + return m.DB.Exec( + "EXEC sp_addextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?", + field.Comment, schemaName, stmt.Table, field.DBName, + ).Error + } + // update field comment + return m.DB.Exec( + "EXEC sp_updateextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?", + field.Comment, schemaName, stmt.Table, field.DBName, + ).Error +} + +func (m Migrator) getTableSchemaName(schema *schema.Schema) string { + // return the schema name if it is explicitly provided in the table name + // otherwise return default schema name + schemaName := getTableSchemaName(schema) + if schemaName == "" { + schemaName = m.DefaultSchema() + } + return schemaName +} + func getTableSchemaName(schema *schema.Schema) string { // return the schema name if it is explicitly provided in the table name // otherwise return a sql wildcard -> use any table_schema @@ -141,6 +193,26 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error { ).Error } +func (m Migrator) AddColumn(value interface{}, name string) error { + if err := m.Migrator.AddColumn(value, name); err != nil { + return err + } + + return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(name); field != nil { + if field.Comment == "" { + return + } + if err = m.setColumnComment(stmt, field, true); err != nil { + return + } + } + } + return + }) +} + func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -200,6 +272,36 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error }) } +func (m Migrator) GetColumnComment(stmt *gorm.Statement, fieldDBName string) (description string) { + queryTx := m.DB + if m.DB.DryRun { + queryTx = m.DB.Session(&gorm.Session{}) + queryTx.DryRun = false + } + schemaName := m.getTableSchemaName(stmt.Schema) + queryTx.Raw("SELECT value FROM ?.sys.fn_listextendedproperty('MS_Description', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?)", + gorm.Expr(m.CurrentDatabase()), schemaName, stmt.Table, fieldDBName).Scan(&description) + return +} + +func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + if err := m.Migrator.MigrateColumn(value, field, columnType); err != nil { + return err + } + + return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { + description := m.GetColumnComment(stmt, field.DBName) + if field.Comment != description { + if description == "" { + err = m.setColumnComment(stmt, field, true) + } else { + err = m.setColumnComment(stmt, field, false) + } + } + return + }) +} + var defaultValueTrimRegexp = regexp.MustCompile("^\\('?([^']*)'?\\)$") // ColumnTypes return columnTypes []gorm.ColumnType and execErr error diff --git a/migrator_test.go b/migrator_test.go index fb01944..8f7ba8f 100644 --- a/migrator_test.go +++ b/migrator_test.go @@ -188,3 +188,60 @@ func testGetMigrateColumns(db *gorm.DB, dst interface{}) (columnsWithDefault, co } return } + +type TestTableFieldComment struct { + ID string `gorm:"column:id;primaryKey"` + Name string `gorm:"column:name;comment:姓名"` + Age uint `gorm:"column:age;comment:年龄"` +} + +func (*TestTableFieldComment) TableName() string { return "test_table_field_comment" } + +type TestTableFieldCommentUpdate struct { + ID string `gorm:"column:id;primaryKey"` + Name string `gorm:"column:name;comment:姓名"` + Age uint `gorm:"column:age;comment:周岁"` + Birthday *time.Time `gorm:"column:birthday;comment:生日"` +} + +func (*TestTableFieldCommentUpdate) TableName() string { return "test_table_field_comment" } + +func TestMigrator_MigrateColumnComment(t *testing.T) { + db, err := gorm.Open(sqlserver.Open(sqlserverDSN)) + if err != nil { + t.Error(err) + } + migrator := db.Debug().Migrator() + + tableModel := new(TestTableFieldComment) + defer func() { + if err = migrator.DropTable(tableModel); err != nil { + t.Errorf("couldn't drop table %q, got error: %v", tableModel.TableName(), err) + } + }() + + if err = migrator.AutoMigrate(tableModel); err != nil { + t.Fatal(err) + } + tableModelUpdate := new(TestTableFieldCommentUpdate) + if err = migrator.AutoMigrate(tableModelUpdate); err != nil { + t.Error(err) + } + + if m, ok := migrator.(sqlserver.Migrator); ok { + stmt := db.Model(tableModelUpdate).Find(nil).Statement + if stmt == nil || stmt.Schema == nil { + t.Fatal("expected Statement.Schema, got nil") + } + wantComments := []string{"", "姓名", "周岁", "生日"} + gotComments := make([]string, len(stmt.Schema.DBNames)) + for i, fieldDBName := range stmt.Schema.DBNames { + comment := m.GetColumnComment(stmt, fieldDBName) + gotComments[i] = comment + } + if !reflect.DeepEqual(wantComments, gotComments) { + t.Fatalf("expected comments %#v, got %#v", wantComments, gotComments) + } + t.Logf("got comments: %#v", gotComments) + } +}