Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the go:embed API to lookup templates #2262

Merged
merged 6 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion _examples/embedding/subdir/gendir/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion _examples/embedding/subdir/root_.generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion _examples/federation/accounts/graph/generated/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion _examples/federation/products/graph/generated/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion _examples/federation/reviews/graph/generated/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions codegen/generate.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package codegen

import (
"embed"
"errors"
"fmt"
"os"
Expand All @@ -13,6 +14,9 @@ import (
"github.com/vektah/gqlparser/v2/ast"
)

//go:embed *.gotpl
var codegenTemplates embed.FS

func GenerateCode(data *Data) error {
if !data.Config.Exec.IsDefined() {
return fmt.Errorf("missing exec config")
Expand All @@ -36,6 +40,7 @@ func generateSingleFile(data *Data) error {
RegionTags: true,
GeneratedHeader: true,
Packages: data.Config.Packages,
TemplateFS: codegenTemplates,
})
}

Expand Down Expand Up @@ -82,6 +87,7 @@ func generatePerSchema(data *Data) error {
RegionTags: true,
GeneratedHeader: true,
Packages: data.Config.Packages,
TemplateFS: codegenTemplates,
})
if err != nil {
return err
Expand Down Expand Up @@ -145,6 +151,7 @@ func generateRootFile(data *Data) error {
RegionTags: false,
GeneratedHeader: true,
Packages: data.Config.Packages,
TemplateFS: codegenTemplates,
})
}

Expand Down
91 changes: 49 additions & 42 deletions codegen/templates/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"go/types"
"io/fs"
"os"
"path/filepath"
"reflect"
Expand Down Expand Up @@ -35,6 +36,11 @@ type Options struct {
// the plugin processor will look for .gotpl files
// in the same directory of where you wrote the plugin.
Template string

// Use the go:embed API to collect all the template files you want to pass into Render
// this is an alternative to passing the Template option
TemplateFS fs.FS

// Filename is the name of the file that will be
// written to the system disk once the template is rendered.
Filename string
Expand Down Expand Up @@ -62,55 +68,27 @@ func Render(cfg Options) error {
}
CurrentImports = &Imports{packages: cfg.Packages, destDir: filepath.Dir(cfg.Filename)}

// load path relative to calling source file
_, callerFile, _, _ := runtime.Caller(1)
rootDir := filepath.Dir(callerFile)

funcs := Funcs()
for n, f := range cfg.Funcs {
funcs[n] = f
}

t := template.New("").Funcs(funcs)
t, err := parseTemplates(cfg, t)
if err != nil {
return err
}

var roots []string
if cfg.Template != "" {
var err error
t, err = t.New("template.gotpl").Parse(cfg.Template)
if err != nil {
return fmt.Errorf("error with provided template: %w", err)
roots := make([]string, 0, len(t.Templates()))
for _, template := range t.Templates() {
// templates that end with _.gotpl are special files we don't want to include
if strings.HasSuffix(template.Name(), "_.gotpl") ||
// filter out templates added with {{ template xxx }} syntax inside the template file
!strings.HasSuffix(template.Name(), ".gotpl") {
continue
}
roots = append(roots, "template.gotpl")
} else {
// load all the templates in the directory
err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator)))
if !strings.HasSuffix(info.Name(), ".gotpl") {
return nil
}
// omit any templates with "_" at the end of their name, which are meant for specific contexts only
if strings.HasSuffix(info.Name(), "_.gotpl") {
return nil
}
b, err := os.ReadFile(path)
if err != nil {
return err
}

t, err = t.New(name).Parse(string(b))
if err != nil {
return fmt.Errorf("%s: %w", cfg.Filename, err)
}

roots = append(roots, name)

return nil
})
if err != nil {
return fmt.Errorf("locating templates: %w", err)
}
roots = append(roots, template.Name())
}

// then execute all the important looking ones in order, adding them to the same file
Expand All @@ -124,6 +102,7 @@ func Render(cfg Options) error {
}
return roots[i] < roots[j]
})

var buf bytes.Buffer
for _, root := range roots {
if cfg.RegionTags {
Expand Down Expand Up @@ -155,7 +134,7 @@ func Render(cfg Options) error {
result.WriteString("import (\n")
result.WriteString(CurrentImports.String())
result.WriteString(")\n")
_, err := buf.WriteTo(&result)
_, err = buf.WriteTo(&result)
if err != nil {
return err
}
Expand All @@ -170,6 +149,34 @@ func Render(cfg Options) error {
return nil
}

func parseTemplates(cfg Options, t *template.Template) (*template.Template, error) {
if cfg.Template != "" {
var err error
t, err = t.New("template.gotpl").Parse(cfg.Template)
if err != nil {
return nil, fmt.Errorf("error with provided template: %w", err)
}
return t, nil
}

var fileSystem fs.FS
if cfg.TemplateFS != nil {
fileSystem = cfg.TemplateFS
} else {
// load path relative to calling source file
_, callerFile, _, _ := runtime.Caller(1)
rootDir := filepath.Dir(callerFile)
fileSystem = os.DirFS(rootDir)
}

t, err := t.ParseFS(fileSystem, "*.gotpl")
if err != nil {
return nil, fmt.Errorf("locating templates: %w", err)
}

return t, nil
}

func center(width int, pad string, s string) string {
if len(s)+2 > width {
return s
Expand Down
33 changes: 33 additions & 0 deletions codegen/templates/templates_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
package templates

import (
"embed"
"os"
"path/filepath"
"testing"

"github.com/99designs/gqlgen/internal/code"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

//go:embed *.gotpl
var templateFS embed.FS

func TestToGo(t *testing.T) {
require.Equal(t, "ToCamel", ToGo("TO_CAMEL"))
require.Equal(t, "ToCamel", ToGo("to_camel"))
Expand Down Expand Up @@ -119,3 +125,30 @@ func TestTemplateOverride(t *testing.T) {
t.Fatal(err)
}
}

func TestRenderFS(t *testing.T) {

tempDir := t.TempDir()

outDir := filepath.Join(tempDir, "output")

_ = os.Mkdir(outDir, 0o755)

f, err := os.CreateTemp(outDir, "gqlgen.go")
if err != nil {
t.Fatal(err)
}
defer f.Close()
defer os.RemoveAll(f.Name())
err = Render(Options{TemplateFS: templateFS, Filename: f.Name(), Packages: &code.Packages{}})
if err != nil {
t.Fatal(err)
}

expectedString := "package \n\nimport (\n)\nthis is my test package"
actualContents, _ := os.ReadFile(f.Name())
actualContentsStr := string(actualContents)

// don't look at last character since it's \n on Linux and \r\n on Windows
assert.Equal(t, expectedString, actualContentsStr[:len(expectedString)])
}
1 change: 1 addition & 0 deletions codegen/templates/test.gotpl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
this is my test package
1 change: 1 addition & 0 deletions codegen/templates/test_.gotpl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
this will not be included
7 changes: 6 additions & 1 deletion plugin/federation/federation.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package federation

import (
_ "embed"
"fmt"
"sort"
"strings"
Expand All @@ -14,6 +15,9 @@ import (
"github.com/99designs/gqlgen/plugin/federation/fieldset"
)

//go:embed federation.gotpl
var federationTemplate string

type federation struct {
Entities []*Entity
Version int
Expand Down Expand Up @@ -85,7 +89,7 @@ func (f *federation) InjectSourceEarly() *ast.Source {
input := `
scalar _Any
scalar _FieldSet

directive @external on FIELD_DEFINITION
directive @requires(fields: _FieldSet!) on FIELD_DEFINITION
directive @provides(fields: _FieldSet!) on FIELD_DEFINITION
Expand Down Expand Up @@ -274,6 +278,7 @@ func (f *federation) GenerateCode(data *codegen.Data) error {
Data: f,
GeneratedHeader: true,
Packages: data.Config.Packages,
Template: federationTemplate,
})
}

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions plugin/modelgen/models.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package modelgen

import (
_ "embed"
"fmt"
"go/types"
"sort"
Expand All @@ -12,6 +13,9 @@ import (
"github.com/vektah/gqlparser/v2/ast"
)

//go:embed models.gotpl
var modelTemplate string

type BuildMutateHook = func(b *ModelBuild) *ModelBuild

type FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error)
Expand Down Expand Up @@ -269,6 +273,7 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
Data: b,
GeneratedHeader: true,
Packages: cfg.Packages,
Template: modelTemplate,
})
if err != nil {
return err
Expand Down
6 changes: 6 additions & 0 deletions plugin/resolvergen/resolver.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resolvergen

import (
_ "embed"
"errors"
"io/fs"
"os"
Expand All @@ -14,6 +15,9 @@ import (
"github.com/99designs/gqlgen/plugin"
)

//go:embed resolver.gotpl
var resolverTemplate string

func New() plugin.Plugin {
return &Plugin{}
}
Expand Down Expand Up @@ -76,6 +80,7 @@ func (m *Plugin) generateSingleFile(data *codegen.Data) error {
Filename: data.Config.Resolver.Filename,
Data: resolverBuild,
Packages: data.Config.Packages,
Template: resolverTemplate,
})
}

Expand Down Expand Up @@ -143,6 +148,7 @@ func (m *Plugin) generatePerSchema(data *codegen.Data) error {
Filename: filename,
Data: resolverBuild,
Packages: data.Config.Packages,
Template: resolverTemplate,
})
if err != nil {
return err
Expand Down
Loading