diff --git a/generator.go b/generator.go index 3d8fa2f..cb68bb2 100644 --- a/generator.go +++ b/generator.go @@ -109,8 +109,7 @@ func (g *Generator) generateUserTypes(outdir string, api *design.APIDefinition) codegen.SimpleImport("github.com/goadesign/goa"), codegen.SimpleImport("github.com/jinzhu/gorm"), codegen.SimpleImport("golang.org/x/net/context"), - codegen.SimpleImport("golang.org/x/net/context"), - codegen.SimpleImport("github.com/goadesign/goa/uuid"), + codegen.NewImport("uuid", "github.com/satori/go.uuid"), } if model.Cached { @@ -168,8 +167,7 @@ func (g *Generator) generateUserHelpers(outdir string, api *design.APIDefinition codegen.SimpleImport("github.com/goadesign/goa"), codegen.SimpleImport("github.com/jinzhu/gorm"), codegen.SimpleImport("golang.org/x/net/context"), - codegen.SimpleImport("golang.org/x/net/context"), - codegen.SimpleImport("github.com/goadesign/goa/uuid"), + codegen.NewImport("uuid", "github.com/satori/go.uuid"), } if model.Cached { diff --git a/relationalfield.go b/relationalfield.go index f3aa15f..a6b72f4 100644 --- a/relationalfield.go +++ b/relationalfield.go @@ -104,11 +104,13 @@ func goDatatype(f *RelationalFieldDefinition, includePtr bool) string { case Timestamp, NullableTimestamp: return ptr + "time.Time" case BelongsTo: - return ptr + "int" + return ptr + belongsToIDType(f, includePtr) case HasMany: return fmt.Sprintf("[]%s", f.HasMany) - case HasManyKey, HasOneKey: - return ptr + "int" + case HasManyKey: + return ptr + hasManyIDType(f, includePtr) + case HasOneKey: + return ptr + hasOneIDType(f, includePtr) case HasOne: return fmt.Sprintf("%s", f.HasOne) default: @@ -121,6 +123,51 @@ func goDatatype(f *RelationalFieldDefinition, includePtr bool) string { return "UNKNOWN TYPE" } +func goDatatypeByModel(m *RelationalModelDefinition, belongsToModelName string) string { + f := m.RelationalFields[belongsToModelName+"ID"] + if f == nil { + return "int" + } + return belongsToIDType(f, true) +} + +func belongsToIDType(f *RelationalFieldDefinition, includePtr bool) string { + if f.Parent == nil { + return "int" + } + modelName := strings.Replace(f.FieldName, "ID", "", -1) + model := f.Parent.BelongsTo[modelName] + return relatedIDType(model, includePtr) +} + +func hasOneIDType(f *RelationalFieldDefinition, includePtr bool) string { + if f.Parent == nil { + return "int" + } + modelName := strings.Replace(f.FieldName, "ID", "", -1) + model := f.Parent.HasOne[modelName] + return relatedIDType(model, includePtr) +} + +func hasManyIDType(f *RelationalFieldDefinition, includePtr bool) string { + if f.Parent == nil { + return "int" + } + modelName := strings.Replace(f.FieldName, "ID", "", -1) + model := f.Parent.HasMany[modelName] + return relatedIDType(model, includePtr) +} + +func relatedIDType(m *RelationalModelDefinition, includePtr bool) string { + if m == nil { + return "int" + } + if len(m.PrimaryKeys) > 1 { + panic("Can't determine field Type when using multiple primary keys") + } + return goDatatype(m.PrimaryKeys[0], includePtr) +} + func tags(f *RelationalFieldDefinition) string { var sqltags []string if f.SQLTag != "" { diff --git a/writers.go b/writers.go index b43a831..91bf0ac 100644 --- a/writers.go +++ b/writers.go @@ -355,6 +355,7 @@ func (w *UserTypesWriter) Execute(data *UserTypeTemplateData) error { fm["viewFields"] = viewFields fm["viewFieldNames"] = viewFieldNames fm["goDatatype"] = goDatatype + fm["goDatatypeByModel"] = goDatatypeByModel fm["plural"] = inflect.Pluralize fm["gtt"] = codegen.GoTypeTransform fm["gttn"] = codegen.GoTypeTransformName @@ -450,7 +451,7 @@ func (m *{{$ut.ModelName}}DB) TableName() string { // Belongs To Relationships // {{$ut.ModelName}}FilterBy{{$bt.ModelName}} is a gorm filter for a Belongs To relationship. -func {{$ut.ModelName}}FilterBy{{$bt.ModelName}}({{goify (printf "%s%s" $bt.ModelName "ID") false}} int, originaldb *gorm.DB) func(db *gorm.DB) *gorm.DB { +func {{$ut.ModelName}}FilterBy{{$bt.ModelName}}({{goify (printf "%s%s" $bt.ModelName "ID") false}} {{ goDatatypeByModel $ut $bt.ModelName }}, originaldb *gorm.DB) func(db *gorm.DB) *gorm.DB { if {{goify (printf "%s%s" $bt.ModelName "ID") false}} > 0 { return func(db *gorm.DB) *gorm.DB { return db.Where("{{if $bt.RelationalFields.ID.DatabaseFieldName}}{{ if ne $bt.RelationalFields.ID.DatabaseFieldName "id" }}{{$bt.RelationalFields.ID.DatabaseFieldName}} = ?", {{goify (printf "%s%s" $bt.ModelName "ID") false}}){{else}}{{$bt.LowerName}}_id = ?", {{goify (printf "%s%s" $bt.ModelName "ID") false}}){{end}}