diff --git a/protoc-gen-grpc-gateway/gengateway/generator.go b/protoc-gen-grpc-gateway/gengateway/generator.go index d08609bc729..cb2f5e14a9f 100644 --- a/protoc-gen-grpc-gateway/gengateway/generator.go +++ b/protoc-gen-grpc-gateway/gengateway/generator.go @@ -78,6 +78,9 @@ func (g *generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGenerato return nil, err } name := file.GetName() + if file.GoPkg.Path != "" { + name = fmt.Sprintf("%s/%s", file.GoPkg.Path, filepath.Base(name)) + } ext := filepath.Ext(name) base := strings.TrimSuffix(name, ext) output := fmt.Sprintf("%s.pb.gw.go", base) diff --git a/protoc-gen-grpc-gateway/gengateway/generator_test.go b/protoc-gen-grpc-gateway/gengateway/generator_test.go index 755a09236e2..986ff4151e5 100644 --- a/protoc-gen-grpc-gateway/gengateway/generator_test.go +++ b/protoc-gen-grpc-gateway/gengateway/generator_test.go @@ -1,6 +1,7 @@ package gengateway import ( + "path/filepath" "strings" "testing" @@ -9,7 +10,16 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor" ) -func TestGenerateServiceWithoutBindings(t *testing.T) { +func newExampleFileDescriptor() *descriptor.File { + return newExampleFileDescriptorWithGoPkg( + &descriptor.GoPackage{ + Path: "example.com/path/to/example/example.pb", + Name: "example_pb", + }, + ) +} + +func newExampleFileDescriptorWithGoPkg(gp *descriptor.GoPackage) *descriptor.File { msgdesc := &protodescriptor.DescriptorProto{ Name: proto.String("ExampleMessage"), } @@ -39,7 +49,7 @@ func TestGenerateServiceWithoutBindings(t *testing.T) { Name: proto.String("ExampleService"), Method: []*protodescriptor.MethodDescriptorProto{meth, meth1}, } - file := descriptor.File{ + return &descriptor.File{ FileDescriptorProto: &protodescriptor.FileDescriptorProto{ Name: proto.String("example.proto"), Package: proto.String("example"), @@ -47,10 +57,7 @@ func TestGenerateServiceWithoutBindings(t *testing.T) { MessageType: []*protodescriptor.DescriptorProto{msgdesc}, Service: []*protodescriptor.ServiceDescriptorProto{svc}, }, - GoPkg: descriptor.GoPackage{ - Path: "example.com/path/to/example/example.pb", - Name: "example_pb", - }, + GoPkg: *gp, Messages: []*descriptor.Message{msg}, Services: []*descriptor.Service{ { @@ -76,8 +83,12 @@ func TestGenerateServiceWithoutBindings(t *testing.T) { }, }, } +} + +func TestGenerateServiceWithoutBindings(t *testing.T) { + file := newExampleFileDescriptor() g := &generator{} - got, err := g.generate(crossLinkFixture(&file)) + got, err := g.generate(crossLinkFixture(file)) if err != nil { t.Errorf("generate(%#v) failed with %v; want success", file, err) return @@ -86,3 +97,57 @@ func TestGenerateServiceWithoutBindings(t *testing.T) { t.Errorf("generate(%#v) = %s; does not want to contain %s", file, got, notwanted) } } + +func TestGenerateOutputPath(t *testing.T) { + cases := []struct { + file *descriptor.File + expected string + }{ + { + file: newExampleFileDescriptorWithGoPkg( + &descriptor.GoPackage{ + Path: "example.com/path/to/example", + Name: "example_pb", + }, + ), + expected: "example.com/path/to/example", + }, + { + file: newExampleFileDescriptorWithGoPkg( + &descriptor.GoPackage{ + Path: "example", + Name: "example_pb", + }, + ), + expected: "example", + }, + } + + g := &generator{} + for _, c := range cases { + file := c.file + gots, err := g.Generate([]*descriptor.File{crossLinkFixture(file)}) + if err != nil { + t.Errorf("Generate(%#v) failed with %v; wants success", file, err) + return + } + + if len(gots) != 1 { + t.Errorf("Generate(%#v) failed; expects on result got %d", file, len(gots)) + return + } + + got := gots[0] + if got.Name == nil { + t.Errorf("Generate(%#v) failed; expects non-nil Name(%v)", file, got.Name) + return + } + + gotPath := filepath.Dir(*got.Name) + expectedPath := c.expected + if gotPath != expectedPath { + t.Errorf("Generate(%#v) failed; got path: %s expected path: %s", file, gotPath, expectedPath) + return + } + } +}