Skip to content

Commit

Permalink
Use importPath to set package name rather than package path. (#537)
Browse files Browse the repository at this point in the history
* Prioritise go_package option over import_path argument.
* Adding registry tests for SetImportPath method.
  • Loading branch information
rwlincoln authored and achew22 committed Feb 23, 2018
1 parent 424b8e1 commit 5c79dbf
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 8 deletions.
19 changes: 11 additions & 8 deletions protoc-gen-grpc-gateway/descriptor/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (r *Registry) Load(req *plugin.CodeGeneratorRequest) error {
if target == nil {
return fmt.Errorf("no such file: %s", name)
}
name := packageIdentityName(target.FileDescriptorProto)
name := r.packageIdentityName(target.FileDescriptorProto)
if targetPkg == "" {
targetPkg = name
} else {
Expand All @@ -83,7 +83,7 @@ func (r *Registry) Load(req *plugin.CodeGeneratorRequest) error {
func (r *Registry) loadFile(file *descriptor.FileDescriptorProto) {
pkg := GoPackage{
Path: r.goPackagePath(file),
Name: defaultGoPackageName(file),
Name: r.defaultGoPackageName(file),
}
if err := r.ReserveGoPackageAlias(pkg.Name, pkg.Path); err != nil {
for i := 0; ; i++ {
Expand Down Expand Up @@ -247,9 +247,6 @@ func (r *Registry) goPackagePath(f *descriptor.FileDescriptorProto) string {
}

gopkg := f.Options.GetGoPackage()
if len(gopkg) == 0 {
gopkg = r.importPath
}
idx := strings.LastIndex(gopkg, "/")
if idx >= 0 {
if sc := strings.LastIndex(gopkg, ";"); sc > 0 {
Expand Down Expand Up @@ -295,15 +292,15 @@ func sanitizePackageName(pkgName string) string {

// defaultGoPackageName returns the default go package name to be used for go files generated from "f".
// You might need to use an unique alias for the package when you import it. Use ReserveGoPackageAlias to get a unique alias.
func defaultGoPackageName(f *descriptor.FileDescriptorProto) string {
name := packageIdentityName(f)
func (r *Registry) defaultGoPackageName(f *descriptor.FileDescriptorProto) string {
name := r.packageIdentityName(f)
return sanitizePackageName(name)
}

// packageIdentityName returns the identity of packages.
// protoc-gen-grpc-gateway rejects CodeGenerationRequests which contains more than one packages
// as protoc-gen-go does.
func packageIdentityName(f *descriptor.FileDescriptorProto) string {
func (r *Registry) packageIdentityName(f *descriptor.FileDescriptorProto) string {
if f.Options != nil && f.Options.GoPackage != nil {
gopkg := f.Options.GetGoPackage()
idx := strings.LastIndex(gopkg, "/")
Expand All @@ -321,6 +318,12 @@ func packageIdentityName(f *descriptor.FileDescriptorProto) string {
}
return sanitizePackageName(gopkg[sc+1:])
}
if p := r.importPath; len(p) != 0 {
if i := strings.LastIndex(p, "/"); i >= 0 {
p = p[i+1:]
}
return p
}

if f.Package == nil {
base := filepath.Base(f.GetName())
Expand Down
37 changes: 37 additions & 0 deletions protoc-gen-grpc-gateway/descriptor/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,3 +549,40 @@ func TestLoadOverridedPackageName(t *testing.T) {
t.Errorf("file.GoPkg = %#v; want %#v", got, want)
}
}

func TestLoadSetInputPath(t *testing.T) {
reg := NewRegistry()
reg.SetImportPath("foo/examplepb")
loadFile(t, reg, `
name: 'example.proto'
package: 'example'
`)
file := reg.files["example.proto"]
if file == nil {
t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto")
return
}
wantPkg := GoPackage{Path: ".", Name: "examplepb"}
if got, want := file.GoPkg, wantPkg; got != want {
t.Errorf("file.GoPkg = %#v; want %#v", got, want)
}
}

func TestLoadGoPackageInputPath(t *testing.T) {
reg := NewRegistry()
reg.SetImportPath("examplepb")
loadFile(t, reg, `
name: 'example.proto'
package: 'example'
options < go_package: 'example.com/xyz;pb' >
`)
file := reg.files["example.proto"]
if file == nil {
t.Errorf("reg.files[%q] = nil; want non-nil", "example.proto")
return
}
wantPkg := GoPackage{Path: "example.com/xyz", Name: "pb"}
if got, want := file.GoPkg, wantPkg; got != want {
t.Errorf("file.GoPkg = %#v; want %#v", got, want)
}
}

0 comments on commit 5c79dbf

Please sign in to comment.