Skip to content

Commit

Permalink
reduce packages.Load calls
Browse files Browse the repository at this point in the history
  • Loading branch information
vikstrous committed Nov 26, 2019
1 parent d3f6384 commit 697dd11
Show file tree
Hide file tree
Showing 18 changed files with 154 additions and 72 deletions.
2 changes: 1 addition & 1 deletion api/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func validate(cfg *config.Config) error {
if cfg.Resolver.IsDefined() {
roots = append(roots, cfg.Resolver.ImportPath())
}
_, err := packages.Load(&packages.Config{Mode: packages.LoadTypes | packages.LoadSyntax}, roots...)
_, err := packages.Load(&packages.Config{Mode: packages.LoadSyntax}, roots...)
if err != nil {
return errors.Wrap(err, "validation failed")
}
Expand Down
7 changes: 1 addition & 6 deletions codegen/config/binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@ type Binder struct {
SawInvalid bool
}

func (c *Config) NewBinder(s *ast.Schema) (*Binder, error) {
pkgs, err := packages.Load(&packages.Config{Mode: packages.LoadTypes | packages.LoadSyntax}, c.Models.ReferencedPackages()...)
if err != nil {
return nil, err
}

func (c *Config) NewBinder(s *ast.Schema, pkgs []*packages.Package) (*Binder, error) {
mp := map[string]*packages.Package{}
var pkgErrs PkgErrors
for _, p := range pkgs {
Expand Down
7 changes: 6 additions & 1 deletion codegen/config/binder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/vektah/gqlparser"
"github.com/vektah/gqlparser/ast"
"golang.org/x/tools/go/packages"
)

func TestBindingToInvalid(t *testing.T) {
Expand Down Expand Up @@ -58,7 +59,11 @@ func createBinder(cfg Config) (*Binder, *ast.Schema) {
}
`})

b, err := cfg.NewBinder(s)
pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesInfo}, "github.com/99designs/gqlgen/example/chat")
if err != nil {
panic(err)
}
b, err := cfg.NewBinder(s, pkgs)
if err != nil {
panic(err)
}
Expand Down
27 changes: 22 additions & 5 deletions codegen/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,21 +385,35 @@ func (c *Config) normalize() error {
return nil
}

func (c *Config) Autobind(s *ast.Schema) error {
func (c *Config) isAutobind(pkg *packages.Package) bool {
for _, ab := range c.AutoBind {
if strings.HasSuffix(ab, "/...") {
abPrefix := strings.TrimSuffix(ab, "/...")
if strings.HasPrefix(pkg.PkgPath, abPrefix) {
return true
}
}
if pkg.PkgPath == ab {
return true
}
}
return false
}

func (c *Config) Autobind(s *ast.Schema, ps []*packages.Package) error {
if len(c.AutoBind) == 0 {
return nil
}
ps, err := packages.Load(&packages.Config{Mode: packages.LoadTypes}, c.AutoBind...)
if err != nil {
return err
}

for _, t := range s.Types {
if c.Models.UserDefined(t.Name) {
continue
}

for _, p := range ps {
if !c.isAutobind(p) {
continue
}
if t := p.Types.Scope().Lookup(t.Name); t != nil {
c.Models.Add(t.Name(), t.Pkg().Path()+"."+t.Name())
break
Expand All @@ -417,6 +431,9 @@ func (c *Config) Autobind(s *ast.Schema) error {
}

for _, p := range ps {
if !c.isAutobind(p) {
continue
}
if p.Name != pkg {
continue
}
Expand Down
6 changes: 5 additions & 1 deletion codegen/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/vektah/gqlparser"
"github.com/vektah/gqlparser/ast"
"golang.org/x/tools/go/packages"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -132,7 +133,10 @@ func TestAutobinding(t *testing.T) {
type Message { id: ID }
`})

require.NoError(t, cfg.Autobind(s))
ps, err := packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedTypes}, cfg.AutoBind...)
require.NoError(t, err)

require.NoError(t, cfg.Autobind(s, ps))

require.Equal(t, "github.com/99designs/gqlgen/example/scalars/model.Banned", cfg.Models["Banned"].Model[0])
require.Equal(t, "github.com/99designs/gqlgen/example/chat.Message", cfg.Models["Message"].Model[0])
Expand Down
31 changes: 22 additions & 9 deletions codegen/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"sort"

"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/internal/code"
"github.com/pkg/errors"
"github.com/vektah/gqlparser/ast"
"golang.org/x/tools/go/packages"
)

// Data is a unified model of the code to be generated. Plugins may modify this structure to do things like implement
Expand All @@ -25,6 +27,9 @@ type Data struct {
QueryRoot *Object
MutationRoot *Object
SubscriptionRoot *Object

// This is important for looking up packages during code generation
NameForPackage code.NameForPackage
}

type builder struct {
Expand All @@ -51,14 +56,21 @@ func BuildData(cfg *config.Config) (*Data, error) {
return nil, err
}

err = cfg.Autobind(b.Schema)
// Thist must be before loading packages so that the built in packages are loaded
cfg.InjectBuiltins(b.Schema)

packageNames := append(cfg.AutoBind, cfg.Models.ReferencedPackages()...)
pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedDeps | packages.NeedName | packages.NeedImports | packages.NeedTypes | packages.NeedTypesInfo}, packageNames...)
if err != nil {
return nil, err
return nil, errors.Wrap(err, "loading failed")
}

cfg.InjectBuiltins(b.Schema)
err = cfg.Autobind(b.Schema, pkgs)
if err != nil {
return nil, err
}

b.Binder, err = b.Config.NewBinder(b.Schema)
b.Binder, err = b.Config.NewBinder(b.Schema, pkgs)
if err != nil {
return nil, err
}
Expand All @@ -76,11 +88,12 @@ func BuildData(cfg *config.Config) (*Data, error) {
}

s := Data{
Config: cfg,
Directives: dataDirectives,
Schema: b.Schema,
SchemaStr: b.SchemaStr,
Interfaces: map[string]*Interface{},
Config: cfg,
Directives: dataDirectives,
Schema: b.Schema,
SchemaStr: b.SchemaStr,
Interfaces: map[string]*Interface{},
NameForPackage: code.NewNameForPackage(pkgs),
}

for _, schemaType := range b.Schema.Types {
Expand Down
1 change: 1 addition & 0 deletions codegen/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ func GenerateCode(data *Data) error {
Data: data,
RegionTags: true,
GeneratedHeader: true,
NameForPackage: data.NameForPackage,
})
}
23 changes: 13 additions & 10 deletions codegen/templates/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ import (
)

type Import struct {
Name string
Path string
Alias string
NameForPackage code.NameForPackage
Name string
Path string
Alias string
}

type Imports struct {
imports []*Import
destDir string
nameForPackage code.NameForPackage
imports []*Import
destDir string
}

func (i *Import) String() string {
Expand Down Expand Up @@ -49,7 +51,7 @@ func (s *Imports) Reserve(path string, aliases ...string) (string, error) {
return "", nil
}

name := code.NameForPackage(path)
name := s.nameForPackage.Get(path)
var alias string
if len(aliases) != 1 {
alias = name
Expand All @@ -69,9 +71,10 @@ func (s *Imports) Reserve(path string, aliases ...string) (string, error) {
}

s.imports = append(s.imports, &Import{
Name: name,
Path: path,
Alias: alias,
NameForPackage: s.nameForPackage,
Name: name,
Path: path,
Alias: alias,
})

return "", nil
Expand All @@ -94,7 +97,7 @@ func (s *Imports) Lookup(path string) string {
}

imp := &Import{
Name: code.NameForPackage(path),
Name: s.nameForPackage.Get(path),
Path: path,
}
s.imports = append(s.imports, imp)
Expand Down
19 changes: 13 additions & 6 deletions codegen/templates/import_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"os"
"testing"

"github.com/99designs/gqlgen/internal/code"
"github.com/stretchr/testify/require"
"golang.org/x/tools/go/packages"
)

func TestImports(t *testing.T) {
Expand All @@ -16,15 +18,20 @@ func TestImports(t *testing.T) {
bBar := "github.com/99designs/gqlgen/codegen/templates/testdata/b/bar"
mismatch := "github.com/99designs/gqlgen/codegen/templates/testdata/pkg_mismatch"

ps, err := packages.Load(nil, aBar, bBar, mismatch)
require.NoError(t, err)

nameForPackage := code.NewNameForPackage(ps)

t.Run("multiple lookups is ok", func(t *testing.T) {
a := Imports{destDir: wd}
a := Imports{nameForPackage: nameForPackage, destDir: wd}

require.Equal(t, "bar", a.Lookup(aBar))
require.Equal(t, "bar", a.Lookup(aBar))
})

t.Run("lookup by type", func(t *testing.T) {
a := Imports{destDir: wd}
a := Imports{nameForPackage: nameForPackage, destDir: wd}

pkg := types.NewPackage("github.com/99designs/gqlgen/codegen/templates/testdata/b/bar", "bar")
typ := types.NewNamed(types.NewTypeName(0, pkg, "Boolean", types.Typ[types.Bool]), types.Typ[types.Bool], nil)
Expand All @@ -33,7 +40,7 @@ func TestImports(t *testing.T) {
})

t.Run("duplicates are decollisioned", func(t *testing.T) {
a := Imports{destDir: wd}
a := Imports{nameForPackage: nameForPackage, destDir: wd}

require.Equal(t, "bar", a.Lookup(aBar))
require.Equal(t, "bar1", a.Lookup(bBar))
Expand All @@ -44,13 +51,13 @@ func TestImports(t *testing.T) {
})

t.Run("package name defined in code will be used", func(t *testing.T) {
a := Imports{destDir: wd}
a := Imports{nameForPackage: nameForPackage, destDir: wd}

require.Equal(t, "turtles", a.Lookup(mismatch))
})

t.Run("string printing for import block", func(t *testing.T) {
a := Imports{destDir: wd}
a := Imports{nameForPackage: nameForPackage, destDir: wd}
a.Lookup(aBar)
a.Lookup(bBar)
a.Lookup(mismatch)
Expand All @@ -65,7 +72,7 @@ turtles "github.com/99designs/gqlgen/codegen/templates/testdata/pkg_mismatch"`,
})

t.Run("aliased imports will not collide", func(t *testing.T) {
a := Imports{destDir: wd}
a := Imports{nameForPackage: nameForPackage, destDir: wd}

_, _ = a.Reserve(aBar, "abar")
_, _ = a.Reserve(bBar, "bbar")
Expand Down
12 changes: 7 additions & 5 deletions codegen/templates/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"text/template"
"unicode"

"github.com/99designs/gqlgen/internal/code"
"github.com/99designs/gqlgen/internal/imports"
"github.com/pkg/errors"
)
Expand Down Expand Up @@ -43,6 +44,8 @@ type Options struct {
// Data will be passed to the template execution.
Data interface{}
Funcs template.FuncMap
// Lookups for pre-cached package names
NameForPackage code.NameForPackage
}

// Render renders a gql plugin template from the given Options. Render is an
Expand All @@ -53,7 +56,7 @@ func Render(cfg Options) error {
if CurrentImports != nil {
panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected"))
}
CurrentImports = &Imports{destDir: filepath.Dir(cfg.Filename)}
CurrentImports = &Imports{nameForPackage: cfg.NameForPackage, destDir: filepath.Dir(cfg.Filename)}

// load path relative to calling source file
_, callerFile, _, _ := runtime.Caller(1)
Expand Down Expand Up @@ -143,7 +146,7 @@ func Render(cfg Options) error {
}
CurrentImports = nil

return write(cfg.Filename, result.Bytes())
return write(cfg.Filename, result.Bytes(), cfg.NameForPackage)
}

func center(width int, pad string, s string) string {
Expand Down Expand Up @@ -551,13 +554,12 @@ func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
return buf, t.Execute(buf, tpldata)
}

func write(filename string, b []byte) error {
func write(filename string, b []byte, nameForPackage code.NameForPackage) error {
err := os.MkdirAll(filepath.Dir(filename), 0755)
if err != nil {
return errors.Wrap(err, "failed to create directory")
}

formatted, err := imports.Prune(filename, b)
formatted, err := imports.Prune(filename, b, nameForPackage)
if err != nil {
fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error())
formatted = b
Expand Down
Loading

0 comments on commit 697dd11

Please sign in to comment.