Skip to content

Commit

Permalink
Merge pull request #13 from vektah/autocast
Browse files Browse the repository at this point in the history
Automatically add type conversions around wrapped types
  • Loading branch information
vektah authored Feb 19, 2018
2 parents c8c2e40 + 85fa63b commit 9d896f4
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 74 deletions.
10 changes: 5 additions & 5 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*

b := &Build{
PackageName: filepath.Base(destDir),
Objects: buildObjects(namedTypes, schema, prog),
Objects: buildObjects(namedTypes, schema, prog, imports),
Interfaces: buildInterfaces(namedTypes, schema),
Inputs: buildInputs(namedTypes, schema, prog),
Inputs: buildInputs(namedTypes, schema, prog, imports),
Imports: imports,
}

Expand All @@ -56,19 +56,19 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*
// Poke a few magic methods into query
q := b.Objects.ByName(b.QueryRoot.GQLType)
q.Fields = append(q.Fields, Field{
Type: &Type{namedTypes["__Schema"], []string{modPtr}},
Type: &Type{namedTypes["__Schema"], []string{modPtr}, ""},
GQLName: "__schema",
NoErr: true,
GoMethodName: "ec.introspectSchema",
Object: q,
})
q.Fields = append(q.Fields, Field{
Type: &Type{namedTypes["__Type"], []string{modPtr}},
Type: &Type{namedTypes["__Type"], []string{modPtr}, ""},
GQLName: "__type",
NoErr: true,
GoMethodName: "ec.introspectType",
Args: []FieldArgument{
{GQLName: "name", Type: &Type{namedTypes["String"], []string{}}},
{GQLName: "name", Type: &Type{namedTypes["String"], []string{}, ""}},
},
Object: q,
})
Expand Down
4 changes: 2 additions & 2 deletions codegen/input_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"golang.org/x/tools/go/loader"
)

func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program) Objects {
func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program, imports Imports) Objects {
var inputs Objects

for _, typ := range s.Types {
Expand All @@ -25,7 +25,7 @@ func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program)
}
if def != nil {
input.Marshaler = buildInputMarshaler(typ, def)
bindObject(def.Type(), input)
bindObject(def.Type(), input, imports)
}

inputs = append(inputs, input)
Expand Down
9 changes: 0 additions & 9 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,6 @@ type FieldArgument struct {

type Objects []*Object

func (o *Object) GetField(name string) *Field {
for i, field := range o.Fields {
if strings.EqualFold(field.GQLName, name) {
return &o.Fields[i]
}
}
return nil
}

func (o *Object) Implementors() string {
satisfiedBy := strconv.Quote(o.GQLType)
for _, s := range o.Satisfies {
Expand Down
4 changes: 2 additions & 2 deletions codegen/object_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"golang.org/x/tools/go/loader"
)

func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Objects {
func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program, imports Imports) Objects {
var objects Objects

for _, typ := range s.Types {
Expand All @@ -23,7 +23,7 @@ func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Obje
fmt.Fprintf(os.Stderr, err.Error())
}
if def != nil {
bindObject(def.Type(), obj)
bindObject(def.Type(), obj, imports)
}

objects = append(objects, obj)
Expand Down
30 changes: 27 additions & 3 deletions codegen/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type Type struct {
*NamedType

Modifiers []string
CastType string // the type to cast to when unmarshalling
}

const (
Expand All @@ -46,6 +47,15 @@ func (t Type) Signature() string {
return strings.Join(t.Modifiers, "") + t.FullName()
}

func (t Type) FullSignature() string {
pkg := ""
if t.Package != "" {
pkg = t.Package + "."
}

return strings.Join(t.Modifiers, "") + pkg + t.GoType
}

func (t Type) IsPtr() bool {
return len(t.Modifiers) > 0 && t.Modifiers[0] == modPtr
}
Expand All @@ -59,18 +69,32 @@ func (t NamedType) IsMarshaled() bool {
}

func (t Type) Unmarshal(result, raw string) string {
if t.Marshaler != nil {
return result + ", err := " + t.Marshaler.pkgDot() + "Unmarshal" + t.Marshaler.GoType + "(" + raw + ")"
realResult := result
if t.CastType != "" {
result = "castTmp"
}
return tpl(`var {{.result}} {{.type}}
ret := tpl(`var {{.result}} {{.type}}
err := (&{{.result}}).UnmarshalGQL({{.raw}})`, map[string]interface{}{
"result": result,
"raw": raw,
"type": t.FullName(),
})

if t.Marshaler != nil {
ret = result + ", err := " + t.Marshaler.pkgDot() + "Unmarshal" + t.Marshaler.GoType + "(" + raw + ")"
}

if t.CastType != "" {
ret += "\n" + realResult + " := " + t.CastType + "(castTmp)"
}
return ret
}

func (t Type) Marshal(result, val string) string {
if t.CastType != "" {
val = t.GoType + "(" + val + ")"
}

if t.Marshaler != nil {
return result + " = " + t.Marshaler.pkgDot() + "Marshal" + t.Marshaler.GoType + "(" + val + ")"
}
Expand Down
129 changes: 82 additions & 47 deletions codegen/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,69 +47,104 @@ func isMethod(t types.Object) bool {
return f.Type().(*types.Signature).Recv() != nil
}

func bindObject(t types.Type, object *Object) bool {
switch t := t.(type) {
case *types.Named:
for i := 0; i < t.NumMethods(); i++ {
method := t.Method(i)
if !method.Exported() {
continue
}
func findMethod(typ *types.Named, name string) *types.Func {
for i := 0; i < typ.NumMethods(); i++ {
method := typ.Method(i)
if !method.Exported() {
continue
}

if strings.EqualFold(method.Name(), name) {
return method
}
}
return nil
}

func findField(typ *types.Struct, name string) *types.Var {
for i := 0; i < typ.NumFields(); i++ {
field := typ.Field(i)
if !field.Exported() {
continue
}

if methodField := object.GetField(method.Name()); methodField != nil {
methodField.GoMethodName = "it." + method.Name()
sig := method.Type().(*types.Signature)
if strings.EqualFold(field.Name(), name) {
return field
}
}
return nil
}

methodField.Type.Modifiers = modifiersFromGoType(sig.Results().At(0).Type())
func bindObject(t types.Type, object *Object, imports Imports) {
namedType, ok := t.(*types.Named)
if !ok {
fmt.Fprintf(os.Stderr, "expected %s to be a named struct, instead found %s", object.FullName(), t.String())
return
}

// check arg order matches code, not gql
underlying, ok := t.Underlying().(*types.Struct)
if !ok {
fmt.Fprintf(os.Stderr, "expected %s to be a named struct, instead found %s", object.FullName(), t.String())
return
}

var newArgs []FieldArgument
l2:
for j := 0; j < sig.Params().Len(); j++ {
param := sig.Params().At(j)
for _, oldArg := range methodField.Args {
if strings.EqualFold(oldArg.GQLName, param.Name()) {
oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
newArgs = append(newArgs, oldArg)
continue l2
}
for i := range object.Fields {
field := &object.Fields[i]
if method := findMethod(namedType, field.GQLName); method != nil {
sig := method.Type().(*types.Signature)
field.GoMethodName = "it." + method.Name()
field.Type.Modifiers = modifiersFromGoType(sig.Results().At(0).Type())

// check arg order matches code, not gql
var newArgs []FieldArgument
l2:
for j := 0; j < sig.Params().Len(); j++ {
param := sig.Params().At(j)
for _, oldArg := range field.Args {
if strings.EqualFold(oldArg.GQLName, param.Name()) {
oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
newArgs = append(newArgs, oldArg)
continue l2
}
fmt.Fprintln(os.Stderr, "cannot match argument "+param.Name()+" to any argument in "+t.String())
}
methodField.Args = newArgs
fmt.Fprintln(os.Stderr, "cannot match argument "+param.Name()+" to any argument in "+t.String())
}
field.Args = newArgs

if sig.Results().Len() == 1 {
methodField.NoErr = true
} else if sig.Results().Len() != 2 {
fmt.Fprintf(os.Stderr, "weird number of results on %s. expected either (result), or (result, error)\n", method.Name())
}
if sig.Results().Len() == 1 {
field.NoErr = true
} else if sig.Results().Len() != 2 {
fmt.Fprintf(os.Stderr, "weird number of results on %s. expected either (result), or (result, error)\n", method.Name())
}
continue
}

bindObject(t.Underlying(), object)
return true
if structField := findField(underlying, field.GQLName); structField != nil {
field.Type.Modifiers = modifiersFromGoType(structField.Type())
field.GoVarName = "it." + structField.Name()

case *types.Struct:
for i := 0; i < t.NumFields(); i++ {
field := t.Field(i)
// Todo: struct tags, name and - at least
switch field.Type.FullSignature() {
case structField.Type().String():
// everything is fine

if !field.Exported() {
continue
}
case structField.Type().Underlying().String():
pkg, typ := pkgAndType(structField.Type().String())
imp := imports.findByPkg(pkg)
field.CastType = typ
if imp.Name != "" {
field.CastType = imp.Name + "." + typ
}

// Todo: check for type matches before binding too?
if objectField := object.GetField(field.Name()); objectField != nil {
objectField.GoVarName = "it." + field.Name()
objectField.Type.Modifiers = modifiersFromGoType(field.Type())
default:
fmt.Fprintf(os.Stderr, "type mismatch on %s.%s, expected %s got %s\n", object.GQLType, field.GQLName, field.Type.FullSignature(), structField.Type())
}
continue
}
t.Underlying()
return true
}

return false
if field.IsScalar {
fmt.Fprintf(os.Stderr, "unable to bind %s.%s to anything, %s has no suitable fields or methods\n", object.GQLType, field.GQLName, namedType.String())
}
}
}

func modifiersFromGoType(t types.Type) []string {
Expand Down
17 changes: 16 additions & 1 deletion example/scalars/generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,14 @@ func (ec *executionContext) _user(sel []query.Selection, it *User) graphql.Marsh
res := it.Location

out.Values[i] = res
case "isBanned":
badArgs := false
if badArgs {
continue
}
res := it.IsBanned

out.Values[i] = graphql.MarshalBoolean(bool(res))
default:
panic("unknown field " + strconv.Quote(field.Name))
}
Expand Down Expand Up @@ -807,13 +815,20 @@ func UnmarshalSearchArgs(v interface{}) (SearchArgs, error) {
return it, err
}
it.CreatedAfter = &val
case "isBanned":
castTmp, err := graphql.UnmarshalBoolean(v)
val := Banned(castTmp)
if err != nil {
return it, err
}
it.IsBanned = val
}
}

return it, nil
}

var parsedSchema = schema.MustParse("schema {\n query: Query\n}\n\ntype Query {\n user(id: ID!): User\n search(input: SearchArgs!): [User!]!\n}\n\ntype User {\n id: ID!\n name: String!\n created: Timestamp\n location: Point\n}\n\ninput SearchArgs {\n location: Point\n createdAfter: Timestamp\n}\n\nscalar Timestamp\nscalar Point\n")
var parsedSchema = schema.MustParse("schema {\n query: Query\n}\n\ntype Query {\n user(id: ID!): User\n search(input: SearchArgs!): [User!]!\n}\n\ntype User {\n id: ID!\n name: String!\n created: Timestamp\n location: Point\n isBanned: Boolean!\n}\n\ninput SearchArgs {\n location: Point\n createdAfter: Timestamp\n isBanned: Boolean\n}\n\nscalar Timestamp\nscalar Point\n")

func (ec *executionContext) introspectSchema() *introspection.Schema {
return introspection.WrapSchema(parsedSchema)
Expand Down
4 changes: 4 additions & 0 deletions example/scalars/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@ import (
"github.com/vektah/gqlgen/graphql"
)

type Banned bool

type User struct {
ID string
Name string
Location Point // custom scalar types
Created time.Time // direct binding to builtin types with external Marshal/Unmarshal methods
IsBanned Banned // aliased primitive
}

// Point is serialized as a simple array, eg [1, 2]
Expand Down Expand Up @@ -71,4 +74,5 @@ func UnmarshalTimestamp(v interface{}) (time.Time, error) {
type SearchArgs struct {
Location *Point
CreatedAfter *time.Time
IsBanned Banned
}
2 changes: 2 additions & 0 deletions example/scalars/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ type User {
name: String!
created: Timestamp
location: Point
isBanned: Boolean!
}

input SearchArgs {
location: Point
createdAfter: Timestamp
isBanned: Boolean
}

scalar Timestamp
Expand Down
2 changes: 1 addition & 1 deletion example/starwars/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (h *Human) Height(unit string) float64 {
type Starship struct {
ID string
Name string
History [][2]int
History [][]int
lengthMeters float64
}

Expand Down
Loading

0 comments on commit 9d896f4

Please sign in to comment.