Skip to content

Commit

Permalink
single packages.Load for NameForPackage
Browse files Browse the repository at this point in the history
  • Loading branch information
vikstrous committed Jan 9, 2020
1 parent 28c032d commit 31341b1
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 15 deletions.
8 changes: 8 additions & 0 deletions codegen/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import (
"sort"

"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/internal/code"
"github.com/pkg/errors"
"github.com/vektah/gqlparser/ast"
"github.com/vektah/gqlparser/formatter"
"golang.org/x/tools/go/packages"
)

// Data is a unified model of the code to be generated. Plugins may modify this structure to do things like implement
Expand Down Expand Up @@ -88,6 +90,12 @@ func BuildData(cfg *config.Config, plugins []SchemaMutator) (*Data, error) {
}
}

pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName}, cfg.Models.ReferencedPackages()...)
if err != nil {
return nil, errors.Wrap(err, "loading failed")
}
code.RecordPackagesList(pkgs)

s := Data{
Config: cfg,
Directives: dataDirectives,
Expand Down
7 changes: 7 additions & 0 deletions codegen/templates/import_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"os"
"testing"

"github.com/99designs/gqlgen/internal/code"
"github.com/stretchr/testify/require"
"golang.org/x/tools/go/packages"
)

func TestImports(t *testing.T) {
Expand All @@ -16,6 +18,11 @@ func TestImports(t *testing.T) {
bBar := "github.com/99designs/gqlgen/codegen/templates/testdata/b/bar"
mismatch := "github.com/99designs/gqlgen/codegen/templates/testdata/pkg_mismatch"

ps, err := packages.Load(nil, aBar, bBar, mismatch)
require.NoError(t, err)

code.RecordPackagesList(ps)

t.Run("multiple lookups is ok", func(t *testing.T) {
a := Imports{destDir: wd}

Expand Down
34 changes: 22 additions & 12 deletions internal/code/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package code

import (
"errors"
"fmt"
"go/build"
"go/parser"
"go/token"
Expand All @@ -14,7 +15,8 @@ import (
"golang.org/x/tools/go/packages"
)

var nameForPackageCache = sync.Map{}
var nameForPackageCacheLock sync.Mutex
var nameForPackageCache []*packages.Package

var gopaths []string

Expand Down Expand Up @@ -107,24 +109,32 @@ func ImportPathForDir(dir string) (res string) {

var modregex = regexp.MustCompile("module (.*)\n")

// RecordPackagesList records the list of packages to be used later by NameForPackage.
// It must be called exactly once during initialization, before NameForPackage is called.
func RecordPackagesList(newNameForPackageCache []*packages.Package) {
nameForPackageCache = newNameForPackageCache
}

// NameForPackage returns the package name for a given import path. This can be really slow.
func NameForPackage(importPath string) string {
if importPath == "" {
panic(errors.New("import path can not be empty"))
}
if v, ok := nameForPackageCache.Load(importPath); ok {
return v.(string)
if nameForPackageCache == nil {
panic(fmt.Errorf("NameForPackage called for %s before RecordPackagesList", importPath))
}
nameForPackageCacheLock.Lock()
defer nameForPackageCacheLock.Unlock()
var p *packages.Package
for _, pkg := range nameForPackageCache {
if pkg.PkgPath == importPath {
p = pkg
break
}
}
importPath = QualifyPackagePath(importPath)
p, _ := packages.Load(&packages.Config{
Mode: packages.NeedName,
}, importPath)

if len(p) != 1 || p[0].Name == "" {
if p == nil || p.Name == "" {
return SanitizePackageName(filepath.Base(importPath))
}

nameForPackageCache.Store(importPath, p[0].Name)

return p[0].Name
return p.Name
}
13 changes: 10 additions & 3 deletions internal/code/imports_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/tools/go/packages"
)

func TestImportPathForDir(t *testing.T) {
Expand All @@ -31,11 +32,17 @@ func TestImportPathForDir(t *testing.T) {
}

func TestNameForPackage(t *testing.T) {
assert.Equal(t, "api", NameForPackage("github.com/99designs/gqlgen/api"))
testPkg1 := "github.com/99designs/gqlgen/api"
testPkg2 := "github.com/99designs/gqlgen/docs"
testPkg3 := "github.com"
ps, err := packages.Load(nil, testPkg1, testPkg2, testPkg3)
require.NoError(t, err)
RecordPackagesList(ps)
assert.Equal(t, "api", NameForPackage(testPkg1))

// does not contain go code, should still give a valid name
assert.Equal(t, "docs", NameForPackage("github.com/99designs/gqlgen/docs"))
assert.Equal(t, "github_com", NameForPackage("github.com"))
assert.Equal(t, "docs", NameForPackage(testPkg2))
assert.Equal(t, "github_com", NameForPackage(testPkg3))
}

func TestNameForDir(t *testing.T) {
Expand Down
5 changes: 5 additions & 0 deletions internal/imports/prune_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@ import (
"io/ioutil"
"testing"

"github.com/99designs/gqlgen/internal/code"
"github.com/stretchr/testify/require"
"golang.org/x/tools/go/packages"
)

func TestPrune(t *testing.T) {
// prime the packages cache so that it's not considered uninitialized
code.RecordPackagesList([]*packages.Package{})

b, err := Prune("testdata/unused.go", mustReadFile("testdata/unused.go"))
require.NoError(t, err)
require.Equal(t, string(mustReadFile("testdata/unused.expected.go")), string(b))
Expand Down
9 changes: 9 additions & 0 deletions plugin/modelgen/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ import (

"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/99designs/gqlgen/internal/code"
"github.com/99designs/gqlgen/plugin"
"github.com/pkg/errors"
"github.com/vektah/gqlparser/ast"
"golang.org/x/tools/go/packages"
)

type BuildMutateHook = func(b *ModelBuild) *ModelBuild
Expand Down Expand Up @@ -246,6 +249,12 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
b = m.MutateHook(b)
}

pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName}, cfg.Models.ReferencedPackages()...)
if err != nil {
return errors.Wrap(err, "loading failed")
}
code.RecordPackagesList(pkgs)

return templates.Render(templates.Options{
PackageName: cfg.Model.Package,
Filename: cfg.Model.Filename,
Expand Down

0 comments on commit 31341b1

Please sign in to comment.