Skip to content

Commit

Permalink
Support vendor imports
Browse files Browse the repository at this point in the history
 * allow specifying directory of import
  • Loading branch information
boz committed Jun 7, 2017
1 parent 43306ba commit 2d98cef
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 16 deletions.
42 changes: 29 additions & 13 deletions impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main

import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/build"
Expand All @@ -19,29 +20,36 @@ import (
"golang.org/x/tools/imports"
)

const usage = `impl <recv> <iface>
const usage = `impl [-dir directory] <recv> <iface>
impl generates method stubs for recv to implement iface.
Examples:
impl 'f *File' io.Reader
impl Murmur hash.Hash
impl Murmur hash.Hash -dir $GOPATH/src/github.com/josharian/impl
Don't forget the single quotes around the receiver type
to prevent shell globbing.
`

var (
sourceDirectory = flag.String("dir", "", "package source directory")
)

// findInterface returns the import path and identifier of an interface.
// For example, given "http.ResponseWriter", findInterface returns
// "net/http", "ResponseWriter".
// If a fully qualified interface is given, such as "net/http.ResponseWriter",
// it simply parses the input.
func findInterface(iface string) (path string, id string, err error) {
func findInterface(iface string, srcDir string) (path string, id string, err error) {
if len(strings.Fields(iface)) != 1 {
return "", "", fmt.Errorf("couldn't parse interface: %s", iface)
}

srcPath := filepath.Join(srcDir, "__go_impl__.go")

if slash := strings.LastIndex(iface, "/"); slash > -1 {
// package path provided
dot := strings.LastIndex(iface, ".")
Expand All @@ -63,15 +71,15 @@ func findInterface(iface string) (path string, id string, err error) {
src := []byte("package hack\n" + "var i " + iface)
// If we couldn't determine the import path, goimports will
// auto fix the import path.
imp, err := imports.Process(".", src, nil)
imp, err := imports.Process(srcPath, src, nil)
if err != nil {
return "", "", fmt.Errorf("couldn't parse interface: %s", iface)
}

// imp should now contain an appropriate import.
// Parse out the import and the identifier.
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "", imp, 0)
f, err := parser.ParseFile(fset, srcPath, imp, 0)
if err != nil {
panic(err)
}
Expand All @@ -97,8 +105,8 @@ type Pkg struct {
}

// typeSpec locates the *ast.TypeSpec for type id in the import path.
func typeSpec(path string, id string) (Pkg, *ast.TypeSpec, error) {
pkg, err := build.Import(path, "", 0)
func typeSpec(path string, id string, srcDir string) (Pkg, *ast.TypeSpec, error) {
pkg, err := build.Import(path, srcDir, 0)
if err != nil {
return Pkg{}, nil, fmt.Errorf("couldn't find package %s: %v", path, err)
}
Expand Down Expand Up @@ -210,15 +218,15 @@ func (p Pkg) funcsig(f *ast.Field) Func {
// funcs returns the set of methods required to implement iface.
// It is called funcs rather than methods because the
// function descriptions are functions; there is no receiver.
func funcs(iface string) ([]Func, error) {
func funcs(iface string, srcDir string) ([]Func, error) {
// Locate the interface.
path, id, err := findInterface(iface)
path, id, err := findInterface(iface, srcDir)
if err != nil {
return nil, err
}

// Parse the package and find the interface declaration.
p, spec, err := typeSpec(path, id)
p, spec, err := typeSpec(path, id, srcDir)
if err != nil {
return nil, fmt.Errorf("interface %s not found: %s", iface, err)
}
Expand All @@ -235,7 +243,7 @@ func funcs(iface string) ([]Func, error) {
for _, fndecl := range idecl.Methods.List {
if len(fndecl.Names) == 0 {
// Embedded interface: recurse
embedded, err := funcs(p.fullType(fndecl.Type))
embedded, err := funcs(p.fullType(fndecl.Type), srcDir)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -287,16 +295,24 @@ func validReceiver(recv string) bool {
}

func main() {
if len(os.Args) != 3 {
flag.Parse()

if len(flag.Args()) < 2 {
fmt.Fprint(os.Stderr, usage)
os.Exit(2)
}
recv, iface := os.Args[1], os.Args[2]

recv, iface := flag.Args()[0], flag.Args()[1]
if !validReceiver(recv) {
fatal(fmt.Sprintf("invalid receiver: %q", recv))
}

fns, err := funcs(iface)
if *sourceDirectory == "" {
*sourceDirectory, _ = os.Getwd()
}
fmt.Println(*sourceDirectory)

fns, err := funcs(iface, *sourceDirectory)
if err != nil {
fatal(err)
}
Expand Down
6 changes: 3 additions & 3 deletions impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestFindInterface(t *testing.T) {
}

for _, tt := range cases {
path, id, err := findInterface(tt.iface)
path, id, err := findInterface(tt.iface, ".")
gotErr := err != nil
if tt.wantErr != gotErr {
t.Errorf("findInterface(%q).err=%v want %s", tt.iface, err, errBool(tt.wantErr))
Expand All @@ -61,7 +61,7 @@ func TestTypeSpec(t *testing.T) {
}

for _, tt := range cases {
p, spec, err := typeSpec(tt.path, tt.id)
p, spec, err := typeSpec(tt.path, tt.id, "")
gotErr := err != nil
if tt.wantErr != gotErr {
t.Errorf("typeSpec(%q, %q).err=%v want %s", tt.path, tt.id, err, errBool(tt.wantErr))
Expand Down Expand Up @@ -182,7 +182,7 @@ func TestFuncs(t *testing.T) {
}

for _, tt := range cases {
fns, err := funcs(tt.iface)
fns, err := funcs(tt.iface, "")
gotErr := err != nil
if tt.wantErr != gotErr {
t.Errorf("funcs(%q).err=%v want %s", tt.iface, err, errBool(tt.wantErr))
Expand Down

0 comments on commit 2d98cef

Please sign in to comment.