Skip to content

Commit

Permalink
schema validation and improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
covrom committed Jul 29, 2022
1 parent 660c9dc commit a445386
Show file tree
Hide file tree
Showing 10 changed files with 251 additions and 37 deletions.
2 changes: 1 addition & 1 deletion category_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type Category struct {
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt sql.NullTime
ParentID uuid.UUID
ParentID *uuid.UUID
Name string
IsDisabled bool
}
Expand Down
21 changes: 15 additions & 6 deletions cmd/goerd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ func main() {
}
f.Close()

qs := goerd.GenerateMigrationSQL(src, dst)
qs, err := goerd.GenerateMigrationSQL(src, dst)
if err != nil {
log.Fatal(err)
}
if cmdIsPrint {
for _, q := range qs {
if !*drop {
Expand Down Expand Up @@ -93,7 +96,7 @@ func main() {
}
f.Close()

wr, err := os.OpenFile(*to, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
wr, err := os.OpenFile(*to, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -122,7 +125,10 @@ func main() {
log.Fatal(err)
}

qs := goerd.GenerateMigrationSQL(src, dst)
qs, err := goerd.GenerateMigrationSQL(src, dst)
if err != nil {
log.Fatal(err)
}
if cmdIsPrint {
for _, q := range qs {
if !*drop {
Expand Down Expand Up @@ -181,7 +187,7 @@ func main() {
log.Fatal(err)
}

wr, err := os.OpenFile(*to, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
wr, err := os.OpenFile(*to, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
if err != nil {
log.Fatal(err)
}
Expand All @@ -198,7 +204,7 @@ func main() {
log.Fatal(err)
}

wr, err := os.OpenFile(*to, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
wr, err := os.OpenFile(*to, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
if err != nil {
log.Fatal(err)
}
Expand All @@ -219,7 +225,10 @@ func main() {
if err != nil {
log.Fatal(err)
}
qs := goerd.GenerateMigrationSQL(src, dst)
qs, err := goerd.GenerateMigrationSQL(src, dst)
if err != nil {
log.Fatal(err)
}
if cmdIsPrint {
for _, q := range qs {
if !*drop {
Expand Down
8 changes: 5 additions & 3 deletions goerd.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ func SchemaFromPostgresDB(db *sql.DB) (*schema.Schema, error) {

// GenerateMigrationSQL generates an array of SQL DDL queries
// for postgres that modify database tables, columns, indexes, etc.
func GenerateMigrationSQL(sfrom, sto *schema.Schema) []string {
func GenerateMigrationSQL(sfrom, sto *schema.Schema) ([]string, error) {
ptch := &schema.PatchSchema{CurrentSchema: sfrom.CurrentSchema}
ptch.Build(sfrom, sto)
return ptch.GenerateSQL()
if err := ptch.Build(sfrom, sto); err != nil {
return nil, err
}
return ptch.GenerateSQL(), nil
}

// SchemaToYAML saves the schema to a yaml file
Expand Down
15 changes: 9 additions & 6 deletions goerd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ func TestBasicUsage(t *testing.T) {

ctx := goerd.WithSqlxDb(context.Background(), db)

if err := prods.ProductToStore(ctx, p); err != nil {
t.Errorf("ProductToStore error: %s", err)
if err := cats.CategoryToStore(ctx, c); err != nil {
t.Errorf("CategoryToStore error: %s", err)
return
}

if err := cats.CategoryToStore(ctx, c); err != nil {
t.Errorf("CategoryToStore error: %s", err)
if err := prods.ProductToStore(ctx, p); err != nil {
t.Errorf("ProductToStore error: %s", err)
return
}

Expand Down Expand Up @@ -100,7 +100,10 @@ func Migrate(d *sqlx.DB, migsch *schema.Schema) error {
if err != nil {
return fmt.Errorf("cannot migrate database: %w", err)
}
qs := goerd.GenerateMigrationSQL(dbsch, migsch)
qs, err := goerd.GenerateMigrationSQL(dbsch, migsch)
if err != nil {
return err
}
tx, err := d.Begin()
if err != nil {
return fmt.Errorf("cannot migrate database: %w", err)
Expand Down Expand Up @@ -128,7 +131,7 @@ func Migrate(d *sqlx.DB, migsch *schema.Schema) error {
fmt.Println("target schema:")
migsch.SaveYaml(os.Stdout)

return fmt.Errorf("cannot migrate database: %w", err)
return fmt.Errorf("cannot migrate database %q: %w", q, err)
}
}
if err = tx.Commit(); err != nil {
Expand Down
5 changes: 4 additions & 1 deletion objmodel.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,10 @@ func (md *ModelSet) Migrate(d *sqlx.DB, dbSchema string) error {
return fmt.Errorf("cannot migrate database: %w", err)
}

qs := GenerateMigrationSQL(dbsch, migsch)
qs, err := GenerateMigrationSQL(dbsch, migsch)
if err != nil {
return fmt.Errorf("cannot migrate database: %w", err)
}
tx, err := d.Begin()
if err != nil {
return fmt.Errorf("cannot migrate database: %w", err)
Expand Down
15 changes: 13 additions & 2 deletions schema/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ func createIndexDDL(idx *Index) string {
func (i *PatchIndex) create() []string {
return []string{createIndexDDL(i.to)}
}

func (i *PatchIndex) alter() []string {
if i.from.MethodName == "" {
i.from.MethodName = "btree"
Expand All @@ -249,6 +250,7 @@ func (i *PatchIndex) alter() []string {
}
return append(i.drop(), i.create()...)
}

func (i *PatchIndex) drop() []string {
// always drop unused indexes
return []string{
Expand Down Expand Up @@ -361,13 +363,15 @@ func createRelationDDL(r *Relation) string {
func (r *PatchRelation) create() []string {
return []string{createRelationDDL(r.to)}
}

func (r *PatchRelation) alter() []string {
if strings.EqualFold(createRelationDDL(r.from),
createRelationDDL(r.to)) {
return nil
}
return append(r.drop(), r.create()...)
}

func (r *PatchRelation) drop() []string {
// TODO:
// declare r record;
Expand All @@ -391,7 +395,7 @@ type PatchSchema struct {
}

func (t *PatchSchema) GenerateSQL() (ret []string) {
// TODO: using CurrentSchema
// TODO: using CurrentSchema as schema prefix
for _, st := range t.tables {
ret = append(ret, st.GenerateSQL()...)
}
Expand All @@ -401,7 +405,13 @@ func (t *PatchSchema) GenerateSQL() (ret []string) {
return
}

func (s *PatchSchema) Build(from, to *Schema) {
func (s *PatchSchema) Build(from, to *Schema) error {
if err := from.Validate(); err != nil {
return fmt.Errorf("source schema validation error: %w", err)
}
if err := to.Validate(); err != nil {
return fmt.Errorf("target schema validation error: %w", err)
}
s.CurrentSchema = to.CurrentSchema
s.tables = make([]*PatchTable, 0, len(from.Tables)+len(to.Tables))
s.relations = make([]*PatchRelation, 0, len(from.Relations)+len(to.Relations))
Expand Down Expand Up @@ -609,4 +619,5 @@ func (s *PatchSchema) Build(from, to *Schema) {
s.relations = append(s.relations, pt)
}
}
return nil
}
41 changes: 24 additions & 17 deletions schema/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
)

func TestPatchSchema_BuildDropAndNew(t *testing.T) {

t.Run("1", func(t *testing.T) {
from := &Schema{
CurrentSchema: "public",
Expand Down Expand Up @@ -65,7 +64,10 @@ func TestPatchSchema_BuildDropAndNew(t *testing.T) {
}

s := &PatchSchema{}
s.Build(from, to)
if err := s.Build(from, to); err != nil {
t.Error(err)
return
}
qs := s.GenerateSQL()
qss := strings.Join(qs, "\n")
if qss != `DROP TABLE IF EXISTS table_old
Expand All @@ -79,11 +81,9 @@ CREATE INDEX table1_col3 ON table1(column3)` {
t.Error(qss)
}
})

}

func TestPatchSchema_BuildAddColIdx(t *testing.T) {

t.Run("1", func(t *testing.T) {
from := &Schema{
CurrentSchema: "public",
Expand Down Expand Up @@ -152,7 +152,10 @@ func TestPatchSchema_BuildAddColIdx(t *testing.T) {
}

s := &PatchSchema{}
s.Build(from, to)
if err := s.Build(from, to); err != nil {
t.Error(err)
return
}
qs := s.GenerateSQL()
qss := strings.Join(qs, "\n")
if qss != `ALTER TABLE table1 ADD COLUMN column3 uuid NOT NULL
Expand All @@ -161,11 +164,9 @@ ALTER TABLE table1 ADD CONSTRAINT table1_constraint_check CHECK (true)` {
t.Error(qss)
}
})

}

func TestPatchSchema_BuildEq(t *testing.T) {

t.Run("1", func(t *testing.T) {
from := &Schema{
CurrentSchema: "public",
Expand Down Expand Up @@ -220,18 +221,19 @@ func TestPatchSchema_BuildEq(t *testing.T) {
}

s := &PatchSchema{}
s.Build(from, to)
if err := s.Build(from, to); err != nil {
t.Error(err)
return
}
qs := s.GenerateSQL()
qss := strings.Join(qs, "\n")
if qss != `` {
t.Error(qss)
}
})

}

func TestPatchSchema_BuildChangeCol(t *testing.T) {

t.Run("1", func(t *testing.T) {
from := &Schema{
CurrentSchema: "public",
Expand Down Expand Up @@ -286,18 +288,19 @@ func TestPatchSchema_BuildChangeCol(t *testing.T) {
}

s := &PatchSchema{}
s.Build(from, to)
if err := s.Build(from, to); err != nil {
t.Error(err)
return
}
qs := s.GenerateSQL()
qss := strings.Join(qs, "\n")
if qss != `ALTER TABLE table1 ALTER COLUMN column2 TYPE text` {
t.Error(qss)
}
})

}

func TestPatchSchema_BuildChangeIdx(t *testing.T) {

t.Run("1", func(t *testing.T) {
from := &Schema{
CurrentSchema: "public",
Expand Down Expand Up @@ -356,7 +359,10 @@ func TestPatchSchema_BuildChangeIdx(t *testing.T) {
}

s := &PatchSchema{}
s.Build(from, to)
if err := s.Build(from, to); err != nil {
t.Error(err)
return
}
qs := s.GenerateSQL()
qss := strings.Join(qs, "\n")
if qss != `ALTER TABLE table1 ADD COLUMN column3 uuid NOT NULL
Expand All @@ -365,7 +371,6 @@ CREATE INDEX table1_col2 ON table1(column2,column3)` {
t.Error(qss)
}
})

}

func TestPatchSchema_BuildChangeIdx2(t *testing.T) {
Expand Down Expand Up @@ -428,12 +433,14 @@ func TestPatchSchema_BuildChangeIdx2(t *testing.T) {

t.Run("1", func(t *testing.T) {
s := &PatchSchema{}
s.Build(from, to)
if err := s.Build(from, to); err != nil {
t.Error(err)
return
}
qs := s.GenerateSQL()
qss := strings.Join(qs, "\n")
if qss != `` {
t.Error(qss)
}
})

}
Loading

0 comments on commit a445386

Please sign in to comment.