Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: create reflect source #1139

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions source/reflect/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Reflect driver

This driver allows you to define your up/down statements in a struct (even an anonymous struct). The driver uses reflection to examine the fields of the struct and return the correct statements.

Struct fields must end with '_up' or '_down' and the driver pairs matching fields automatically. The order of execution is the same as the order of definition. There doesn't need to be a matching down for each up statement, orphaned down statements are ignored.

The migration version is calculated automatically by default but if you want to, you can specify it manually by adding a `migrate` tag to each field of the struct. See example #2.

## Example 1 - auto version

```
migrations := &struct {
Users_up string
Users_down string
Folders_up string
Posts_up string
Posts_down string
}{

Users_up: `
create table users (
id int not null primary key,
created_at timestamp not null,
email text not null,
password text not null);`,

Users_down: `drop table users;`,

Folders_up: `create table folders (
id int not null primary key,
created_at timestamp not null,
label text not null);`,

Posts_up: `
create table posts (
id int not null primary key,
created_at timestamp not null,
user_id int not null,
body text not null);`,

Posts_down: `drop table posts;`,
}

driver, err := New(migrations)
if err != nil {
log.Fatal(err)
}

driver, err = driver.Open("reflect://")
if err != nil {
log.Fatal(err)
}
...

```

## Example 2 - struct tags

```
migrations := &struct {
Table1_up string `migrate:"1"`
Table1_down string `migrate:"1"`
Table2_up string `migrate:"3"`
Table3_up string `migrate:"4"`
Table3_down string `migrate:"4"`
Table4_down string `migrate:"5"`
Table5_down string `migrate:"7"`
Table5_up string `migrate:"7"`
}{
Table1_up: `test statement 1`,
Table1_down: `test statement 2`,
Table2_up: `test statement 3`,
Table3_up: `test statement 4`,
Table3_down: `test statement 5`,
Table4_down: `test statement 6`,
Table5_up: `test statement 7`,
Table5_down: `test statement 8`,
}
```
163 changes: 163 additions & 0 deletions source/reflect/reflect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package reflect

import (
"bytes"
"fmt"
"io"
"os"
"reflect"
"strconv"
"strings"

"github.com/golang-migrate/migrate/v4/source"
)

func init() {
source.Register("reflect", &reflectSource{})
}

type reflectSource struct {
target any
labels []string
up []string
down []string
}

func New(target any) (source.Driver, error) {
driver := &reflectSource{target: target}
return driver.Open("")
}

func (r *reflectSource) Open(url string) (source.Driver, error) {
if r.target == nil {
return nil, fmt.Errorf("no target. source must be created with reflect.New()")
}

// already opened
if len(r.labels) > 0 {
return r, nil
}

// get the fields, these will always be in the same order as defined
t := reflect.TypeOf(r.target).Elem()

// the next step is to match the up clauses and down clauses
// keep track of them in a map so they can be found later
stubs := map[string]int{}

r.up = make([]string, t.NumField()+1)
r.down = make([]string, t.NumField()+1)
r.labels = make([]string, t.NumField()+1)
version := 1
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)

last := strings.LastIndexByte(f.Name, '_')
if last < 0 {
return nil, fmt.Errorf("struct field must end with '_up' or '_down': %s", f.Name)
}

tag := f.Tag.Get("migrate")
if tag != "" {
v, err := strconv.Atoi(tag)
if err != nil {
return nil, fmt.Errorf("invalid tag, expected number: %s", tag)
}
version = v
}

switch f.Name[last:] {
case "_up":
r.up[version] = f.Name
prefix := f.Name[:last]
stubs[prefix] = version
r.labels[version] = prefix
version++
case "_down":
prefix := f.Name[:last]
if ix, ok := stubs[prefix]; ok {
r.down[ix] = f.Name
} else {
r.down[version] = f.Name
stubs[prefix] = version
r.labels[version] = prefix
version++
}
default:
return nil, fmt.Errorf("struct field must end with '_up' or '_down': %s", f.Name)
}
}

return r, nil
}

func (r *reflectSource) Close() error {
// no-op
return nil
}

func (r *reflectSource) First() (version uint, err error) {
return 1, nil
}

func (r *reflectSource) Prev(version uint) (prevVersion uint, err error) {
v := int(version)
if v < 1 || v >= len(r.up) {
return 0, os.ErrNotExist
}
if r.up[v] == "" && r.down[v] == "" {
return 0, os.ErrNotExist
}
v--
for v > 0 {
if r.up[v] != "" {
return uint(v), nil
}
if r.down[v] != "" {
return uint(v), nil
}
v--
}
return 0, os.ErrNotExist
}

func (r *reflectSource) Next(version uint) (nextVersion uint, err error) {
v := int(version)
if v < 1 || v >= len(r.up) {
return 0, os.ErrNotExist
}
if r.up[v] == "" && r.down[v] == "" {
return 0, os.ErrNotExist
}
v++
for v < len(r.up) {
if r.up[v] != "" {
return uint(v), nil
}
if r.down[v] != "" {
return uint(v), nil
}
v++
}
return 0, os.ErrNotExist
}

func (r *reflectSource) ReadUp(version uint) (io.ReadCloser, string, error) {
ix := int(version)
if r.up[ix] == "" {
return nil, r.labels[ix], os.ErrNotExist
}
val := reflect.ValueOf(r.target).Elem()
field := val.FieldByName(r.up[ix])
return io.NopCloser(bytes.NewBufferString(field.String())), r.labels[ix], nil
}

func (r *reflectSource) ReadDown(version uint) (io.ReadCloser, string, error) {
ix := int(version)
if r.down[ix] == "" {
return nil, r.labels[ix], os.ErrNotExist
}
val := reflect.ValueOf(r.target).Elem()
field := val.FieldByName(r.down[ix])
return io.NopCloser(bytes.NewBufferString(field.String())), r.labels[ix], nil
}
118 changes: 118 additions & 0 deletions source/reflect/reflect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package reflect

import (
"io"
"log"
"testing"

st "github.com/golang-migrate/migrate/v4/source/testing"
)

func TestExample(t *testing.T) {
migrations := &struct {
Users_up string // v1
Users_down string // v1
Folders_up string // v2
Posts_up string // v3
Posts_down string // v3
}{

Users_up: `
create table users (
id int not null primary key,
created_at timestamp not null,
email text not null,
password text not null);`,

Users_down: `drop table users;`,

Folders_up: `create table folders (
id int not null primary key,
created_at timestamp not null,
label text not null);`,

Posts_up: `
create table posts (
id int not null primary key,
created_at timestamp not null,
user_id int not null,
body text not null);`,

Posts_down: `drop table posts;`,
}

driver, err := New(migrations)
if err != nil {
log.Fatal(err)
}

if driver == nil {
log.Fatal("driver should not be nil")
}

rdr, label, err := driver.ReadUp(3)
if err != nil {
log.Fatal(err)
}

txt, err := io.ReadAll(rdr)
if err != nil {
log.Fatal(err)
}

if string(txt) != migrations.Posts_up {
log.Fatal("unexpected text: " + string(txt))
}

if label != "Posts" {
log.Fatal("unexpected label")
}

rdr, label, err = driver.ReadDown(3)
if err != nil {
log.Fatal(err)
}

txt, err = io.ReadAll(rdr)
if err != nil {
log.Fatal(err)
}

if string(txt) != migrations.Posts_down {
log.Fatal("unexpected text: " + string(txt))
}

if label != "Posts" {
log.Fatal("unexpected label")
}

}

func Test(t *testing.T) {
migrations := &struct {
Table1_up string `migrate:"1"`
Table1_down string `migrate:"1"`
Table2_up string `migrate:"3"`
Table3_up string `migrate:"4"`
Table3_down string `migrate:"4"`
Table4_down string `migrate:"5"`
Table5_down string `migrate:"7"`
Table5_up string `migrate:"7"`
}{
Table1_up: `test statement 1`,
Table1_down: `test statement 2`,
Table2_up: `test statement 3`,
Table3_up: `test statement 4`,
Table3_down: `test statement 5`,
Table4_down: `test statement 6`,
Table5_up: `test statement 7`,
Table5_down: `test statement 8`,
}

driver, err := New(migrations)
if err != nil {
log.Fatal(err)
}

st.Test(t, driver)
}
Loading