Skip to content

Commit

Permalink
Generate enums
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Mar 28, 2018
1 parent 71c4e26 commit 85a5126
Show file tree
Hide file tree
Showing 12 changed files with 247 additions and 35 deletions.
3 changes: 3 additions & 0 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type ModelBuild struct {
PackageName string
Imports Imports
Models []Model
Enums []Enum
}

// Create a list of models that need to be generated
Expand All @@ -45,6 +46,7 @@ func Models(schema *schema.Schema, userTypes map[string]string, destDir string)
return &ModelBuild{
PackageName: filepath.Base(destDir),
Models: models,
Enums: buildEnums(namedTypes, schema),
Imports: buildImports(namedTypes, destDir),
}
}
Expand All @@ -63,6 +65,7 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*

objects := buildObjects(namedTypes, schema, prog, imports)
inputs := buildInputs(namedTypes, schema, prog, imports)
buildEnums(namedTypes, schema)

b := &Build{
PackageName: filepath.Base(destDir),
Expand Down
12 changes: 12 additions & 0 deletions codegen/enum.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package codegen

type Enum struct {
*NamedType

Values []EnumValue
}

type EnumValue struct {
Name string
Description string
}
37 changes: 37 additions & 0 deletions codegen/enum_build.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package codegen

import (
"sort"
"strings"

"github.com/vektah/gqlgen/neelance/schema"
)

func buildEnums(types NamedTypes, s *schema.Schema) []Enum {
var enums []Enum

for _, typ := range s.Types {
if strings.HasPrefix(typ.TypeName(), "__") {
continue
}
if e, ok := typ.(*schema.Enum); ok {
var values []EnumValue
for _, v := range e.Values {
values = append(values, EnumValue{v.Name, v.Desc})
}

enum := Enum{
NamedType: types[e.TypeName()],
Values: values,
}
enum.GoType = ucFirst(enum.GQLType)
enums = append(enums, enum)
}
}

sort.Slice(enums, func(i, j int) bool {
return strings.Compare(enums[i].GQLType, enums[j].GQLType) == -1
})

return enums
}
2 changes: 1 addition & 1 deletion codegen/templates/data.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 39 additions & 0 deletions codegen/templates/models.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,42 @@ import (
}
{{- end }}
{{- end}}

{{ range $enum := .Enums }}
type {{.GoType}} string
const (
{{ range $value := .Values }}
{{$enum.GoType}}{{ .Name|toCamel }} {{$enum.GoType}} = {{.Name|quote}} {{with .Description}} // {{.}} {{end}}
{{- end }}
)

func (e {{.GoType}}) IsValid() bool {
switch e {
case {{ range $index, $element := .Values}}{{if $index}},{{end}}{{ $enum.GoType }}{{ $element.Name|toCamel }}{{end}}:
return true
}
return false
}

func (e {{.GoType}}) String() string {
return string(e)
}

func (e *{{.GoType}}) UnmarshalGQL(v interface{}) error {
str, ok := v.(string)
if !ok {
return fmt.Errorf("enums must be strings")
}

*e = {{.GoType}}(str)
if !e.IsValid() {
return fmt.Errorf("%s is not a valid {{.GQLType}}", str)
}
return nil
}

func (e {{.GoType}}) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}

{{- end }}
26 changes: 26 additions & 0 deletions codegen/templates/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ func Run(name string, tpldata interface{}) (*bytes.Buffer, error) {
"ucFirst": ucFirst,
"lcFirst": lcFirst,
"quote": strconv.Quote,
"toCamel": toCamel,
"dump": dump,
})

Expand Down Expand Up @@ -54,6 +55,31 @@ func lcFirst(s string) string {
return string(r)
}

func isDelimiter(c rune) bool {
return c == '-' || c == '_' || unicode.IsSpace(c)
}

func toCamel(s string) string {
buffer := make([]rune, 0, len(s))
upper := true

for _, c := range s {
if isDelimiter(c) {
upper = true
continue
}

if upper {
buffer = append(buffer, unicode.ToUpper(c))
} else {
buffer = append(buffer, unicode.ToLower(c))
}
upper = false
}

return string(buffer)
}

func dump(val interface{}) string {
switch val := val.(type) {
case int:
Expand Down
36 changes: 18 additions & 18 deletions example/starwars/generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ type Resolvers interface {
Human_friendsConnection(ctx context.Context, obj *Human, first *int, after *string) (FriendsConnection, error)

Human_starships(ctx context.Context, obj *Human) ([]Starship, error)
Mutation_createReview(ctx context.Context, episode string, review Review) (*Review, error)
Mutation_createReview(ctx context.Context, episode Episode, review Review) (*Review, error)

Query_hero(ctx context.Context, episode string) (Character, error)
Query_reviews(ctx context.Context, episode string, since *time.Time) ([]Review, error)
Query_hero(ctx context.Context, episode Episode) (Character, error)
Query_reviews(ctx context.Context, episode Episode, since *time.Time) ([]Review, error)
Query_search(ctx context.Context, text string) ([]SearchResult, error)
Query_character(ctx context.Context, id string) (Character, error)
Query_droid(ctx context.Context, id string) (*Droid, error)
Expand Down Expand Up @@ -202,7 +202,7 @@ func (ec *executionContext) _Droid_appearsIn(ctx context.Context, field graphql.
res := obj.AppearsIn
arr1 := graphql.Array{}
for idx1 := range res {
arr1 = append(arr1, func() graphql.Marshaler { return graphql.MarshalString(res[idx1]) }())
arr1 = append(arr1, func() graphql.Marshaler { return res[idx1] }())
}
return arr1
}
Expand Down Expand Up @@ -377,18 +377,18 @@ func (ec *executionContext) _Human_name(ctx context.Context, field graphql.Colle
}

func (ec *executionContext) _Human_height(ctx context.Context, field graphql.CollectedField, obj *Human) graphql.Marshaler {
var arg0 string
var arg0 LengthUnit
if tmp, ok := field.Args["unit"]; ok {
var err error
arg0, err = graphql.UnmarshalString(tmp)
err = (&arg0).UnmarshalGQL(tmp)
if err != nil {
ec.Error(err)
return graphql.Null
}
} else {
var tmp interface{} = "METER"
var err error
arg0, err = graphql.UnmarshalString(tmp)
err = (&arg0).UnmarshalGQL(tmp)
if err != nil {
ec.Error(err)
return graphql.Null
Expand Down Expand Up @@ -478,7 +478,7 @@ func (ec *executionContext) _Human_appearsIn(ctx context.Context, field graphql.
res := obj.AppearsIn
arr1 := graphql.Array{}
for idx1 := range res {
arr1 = append(arr1, func() graphql.Marshaler { return graphql.MarshalString(res[idx1]) }())
arr1 = append(arr1, func() graphql.Marshaler { return res[idx1] }())
}
return arr1
}
Expand Down Expand Up @@ -529,10 +529,10 @@ func (ec *executionContext) _Mutation(ctx context.Context, sel []query.Selection
}

func (ec *executionContext) _Mutation_createReview(ctx context.Context, field graphql.CollectedField) graphql.Marshaler {
var arg0 string
var arg0 Episode
if tmp, ok := field.Args["episode"]; ok {
var err error
arg0, err = graphql.UnmarshalString(tmp)
err = (&arg0).UnmarshalGQL(tmp)
if err != nil {
ec.Error(err)
return graphql.Null
Expand Down Expand Up @@ -639,18 +639,18 @@ func (ec *executionContext) _Query(ctx context.Context, sel []query.Selection) g
}

func (ec *executionContext) _Query_hero(ctx context.Context, field graphql.CollectedField) graphql.Marshaler {
var arg0 string
var arg0 Episode
if tmp, ok := field.Args["episode"]; ok {
var err error
arg0, err = graphql.UnmarshalString(tmp)
err = (&arg0).UnmarshalGQL(tmp)
if err != nil {
ec.Error(err)
return graphql.Null
}
} else {
var tmp interface{} = "NEWHOPE"
var err error
arg0, err = graphql.UnmarshalString(tmp)
err = (&arg0).UnmarshalGQL(tmp)
if err != nil {
ec.Error(err)
return graphql.Null
Expand All @@ -676,10 +676,10 @@ func (ec *executionContext) _Query_hero(ctx context.Context, field graphql.Colle
}

func (ec *executionContext) _Query_reviews(ctx context.Context, field graphql.CollectedField) graphql.Marshaler {
var arg0 string
var arg0 Episode
if tmp, ok := field.Args["episode"]; ok {
var err error
arg0, err = graphql.UnmarshalString(tmp)
err = (&arg0).UnmarshalGQL(tmp)
if err != nil {
ec.Error(err)
return graphql.Null
Expand Down Expand Up @@ -982,18 +982,18 @@ func (ec *executionContext) _Starship_name(ctx context.Context, field graphql.Co
}

func (ec *executionContext) _Starship_length(ctx context.Context, field graphql.CollectedField, obj *Starship) graphql.Marshaler {
var arg0 string
var arg0 LengthUnit
if tmp, ok := field.Args["unit"]; ok {
var err error
arg0, err = graphql.UnmarshalString(tmp)
err = (&arg0).UnmarshalGQL(tmp)
if err != nil {
ec.Error(err)
return graphql.Null
}
} else {
var tmp interface{} = "METER"
var err error
arg0, err = graphql.UnmarshalString(tmp)
err = (&arg0).UnmarshalGQL(tmp)
if err != nil {
ec.Error(err)
return graphql.Null
Expand Down
6 changes: 3 additions & 3 deletions example/starwars/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type CharacterFields struct {
ID string
Name string
FriendIds []string
AppearsIn []string
AppearsIn []Episode
}

type Human struct {
Expand All @@ -23,7 +23,7 @@ type Human struct {
Mass float64
}

func (h *Human) Height(unit string) float64 {
func (h *Human) Height(unit LengthUnit) float64 {
switch unit {
case "METER", "":
return h.heightMeters
Expand All @@ -41,7 +41,7 @@ type Starship struct {
lengthMeters float64
}

func (s *Starship) Length(unit string) float64 {
func (s *Starship) Length(unit LengthUnit) float64 {
switch unit {
case "METER", "":
return s.lengthMeters
Expand Down
Loading

0 comments on commit 85a5126

Please sign in to comment.