Skip to content

Commit

Permalink
feat: expose invalid argument error to clients in bidirectional strea…
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Lee authored and Ian Lee committed Oct 9, 2024
1 parent 435362f commit 10f6e39
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
28 changes: 19 additions & 9 deletions protoc-gen-grpc-gateway/internal/gengateway/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ var _ = metadata.Join

_ = template.Must(handlerTemplate.New("request-func-signature").Parse(strings.ReplaceAll(`
{{if .Method.GetServerStreaming}}
func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.InstanceName}}_{{.Method.GetName}}Client, runtime.ServerMetadata, error)
func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.InstanceName}}_{{.Method.GetName}}Client, runtime.ServerMetadata, chan error, error)
{{else}}
func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, marshaler runtime.Marshaler, client {{.Method.Service.InstanceName}}Client, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error)
{{end}}`, "\n", "")))
Expand Down Expand Up @@ -439,10 +439,11 @@ var (
_ = template.Must(handlerTemplate.New("bidi-streaming-request-func").Parse(`
{{template "request-func-signature" .}} {
var metadata runtime.ServerMetadata
errChan := make(chan error)
stream, err := client.{{.Method.GetName}}(ctx)
if err != nil {
grpclog.Errorf("Failed to start streaming: %v", err)
return nil, metadata, err
return nil, metadata, errChan, err
}
dec := marshaler.NewDecoder(req.Body)
handleSend := func() error {
Expand All @@ -453,7 +454,7 @@ var (
}
if err != nil {
grpclog.Errorf("Failed to decode request: %v", err)
return err
return status.Errorf(codes.InvalidArgument, "Failed to decode request: %v", err)
}
if err := stream.Send(&protoReq); err != nil {
grpclog.Errorf("Failed to send request: %v", err)
Expand All @@ -464,20 +465,18 @@ var (
go func() {
for {
if err := handleSend(); err != nil {
errChan <- err
break
}
}
if err := stream.CloseSend(); err != nil {
grpclog.Errorf("Failed to terminate client stream: %v", err)
}
}()
header, err := stream.Header()
if err != nil {
grpclog.Errorf("Failed to get header from client: %v", err)
return nil, metadata, err
return nil, metadata, errChan, err
}
metadata.HeaderMD = header
return stream, metadata, nil
return stream, metadata, errChan, nil
}
`))

Expand Down Expand Up @@ -727,12 +726,23 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context,
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, inboundMarshaler, client, req, pathParams)
resp, md, reqErrChan, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, inboundMarshaler, client, req, pathParams)
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
go func() {
for {
err := <-reqErrChan
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
if err := resp.CloseSend(); err != nil {
grpclog.Errorf("Failed to terminate client stream: %v", err)
}
}
}
}()
{{if $m.GetServerStreaming}}
{{ if $b.ResponseBody }}
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, func() (proto.Message, error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) {
},
{
serverStreaming: true,
sigWant: `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, error) {`,
sigWant: `func request_ExampleService_Echo_0(ctx context.Context, marshaler runtime.Marshaler, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, chan error, error) {`,
},
} {
meth.ServerStreaming = proto.Bool(spec.serverStreaming)
Expand Down

0 comments on commit 10f6e39

Please sign in to comment.