diff --git a/connect_ext_test.go b/connect_ext_test.go index f1cef5fb..63053090 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -396,6 +396,14 @@ func TestServer(t *testing.T) { connect.WithSendGzip(), ) }) + t.Run("json_get", func(t *testing.T) { + run( + t, + connect.WithProtoJSON(), + connect.WithHTTPGet(), + connect.WithHTTPGetMaxURLSize(1024, true), + ) + }) }) t.Run("grpc", func(t *testing.T) { t.Run("proto", func(t *testing.T) { diff --git a/error_writer.go b/error_writer.go index 773aa4e9..c708f426 100644 --- a/error_writer.go +++ b/error_writer.go @@ -21,6 +21,17 @@ import ( "strings" ) +// protocolType is one of the supported RPC protocols. +type protocolType uint8 + +const ( + unknownProtocol protocolType = iota + connectUnaryProtocol + connectStreamProtocol + grpcProtocol + grpcWebProtocol +) + // An ErrorWriter writes errors to an [http.ResponseWriter] in the format // expected by an RPC client. This is especially useful in server-side net/http // middleware, where you may wish to handle requests from RPC and non-RPC @@ -30,7 +41,6 @@ import ( type ErrorWriter struct { bufferPool *bufferPool protobuf Codec - allContentTypes map[string]struct{} grpcContentTypes map[string]struct{} grpcWebContentTypes map[string]struct{} unaryConnectContentTypes map[string]struct{} @@ -46,7 +56,6 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter { writer := &ErrorWriter{ bufferPool: config.BufferPool, protobuf: newReadOnlyCodecs(config.Codecs).Protobuf(), - allContentTypes: make(map[string]struct{}), grpcContentTypes: make(map[string]struct{}), grpcWebContentTypes: make(map[string]struct{}), unaryConnectContentTypes: make(map[string]struct{}), @@ -54,66 +63,89 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter { } for name := range config.Codecs { unary := connectContentTypeFromCodecName(StreamTypeUnary, name) - writer.allContentTypes[unary] = struct{}{} writer.unaryConnectContentTypes[unary] = struct{}{} streaming := connectContentTypeFromCodecName(StreamTypeBidi, name) writer.streamingConnectContentTypes[streaming] = struct{}{} - writer.allContentTypes[streaming] = struct{}{} } if config.HandleGRPC { writer.grpcContentTypes[grpcContentTypeDefault] = struct{}{} - writer.allContentTypes[grpcContentTypeDefault] = struct{}{} for name := range config.Codecs { ct := grpcContentTypeFromCodecName(false /* web */, name) writer.grpcContentTypes[ct] = struct{}{} - writer.allContentTypes[ct] = struct{}{} } } if config.HandleGRPCWeb { writer.grpcWebContentTypes[grpcWebContentTypeDefault] = struct{}{} - writer.allContentTypes[grpcWebContentTypeDefault] = struct{}{} for name := range config.Codecs { ct := grpcContentTypeFromCodecName(true /* web */, name) writer.grpcWebContentTypes[ct] = struct{}{} - writer.allContentTypes[ct] = struct{}{} } } return writer } +func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType { + ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType)) + if _, ok := w.unaryConnectContentTypes[ctype]; ok { + return connectUnaryProtocol + } + if _, ok := w.streamingConnectContentTypes[ctype]; ok { + return connectStreamProtocol + } + if _, ok := w.grpcContentTypes[ctype]; ok { + return grpcProtocol + } + if _, ok := w.grpcWebContentTypes[ctype]; ok { + return grpcWebProtocol + } + // Check for Connect-Protocol-Version header or connect protocol query + // parameter to support connect GET requests. + if request.Method == http.MethodGet { + connectVersion := getHeaderCanonical(request.Header, connectProtocolVersion) + if connectVersion == connectProtocolVersion { + return connectUnaryProtocol + } + connectVersion = request.URL.Query().Get(connectUnaryConnectQueryParameter) + if connectVersion == connectUnaryConnectQueryValue { + return connectUnaryProtocol + } + } + return unknownProtocol +} + // IsSupported checks whether a request is using one of the ErrorWriter's // supported RPC protocols. func (w *ErrorWriter) IsSupported(request *http.Request) bool { - ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType)) - _, ok := w.allContentTypes[ctype] - return ok + return w.classifyRequest(request) != unknownProtocol } // Write an error, using the format appropriate for the RPC protocol in use. // Callers should first use IsSupported to verify that the request is using one -// of the ErrorWriter's supported RPC protocols. +// of the ErrorWriter's supported RPC protocols. If the protocol is unknown, +// Write will send the error as unprefixed, Connect-formatted JSON. // // Write does not read or close the request body. func (w *ErrorWriter) Write(response http.ResponseWriter, request *http.Request, err error) error { ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType)) - if _, ok := w.unaryConnectContentTypes[ctype]; ok { - // Unary errors are always JSON. - setHeaderCanonical(response.Header(), headerContentType, connectUnaryContentTypeJSON) - return w.writeConnectUnary(response, err) - } - if _, ok := w.streamingConnectContentTypes[ctype]; ok { + switch protocolType := w.classifyRequest(request); protocolType { + case connectStreamProtocol: setHeaderCanonical(response.Header(), headerContentType, ctype) return w.writeConnectStreaming(response, err) - } - if _, ok := w.grpcContentTypes[ctype]; ok { + case grpcProtocol: setHeaderCanonical(response.Header(), headerContentType, ctype) return w.writeGRPC(response, err) - } - if _, ok := w.grpcWebContentTypes[ctype]; ok { + case grpcWebProtocol: setHeaderCanonical(response.Header(), headerContentType, ctype) return w.writeGRPCWeb(response, err) + case unknownProtocol, connectUnaryProtocol: + fallthrough + default: + // Unary errors are always JSON. Unknown protocols are treated as unary + // because they are likely to be Connect clients and will still be able to + // parse the error as it's in a human-readable format. + setHeaderCanonical(response.Header(), headerContentType, connectUnaryContentTypeJSON) + return w.writeConnectUnary(response, err) } - return fmt.Errorf("unsupported Content-Type %q", ctype) } func (w *ErrorWriter) writeConnectUnary(response http.ResponseWriter, err error) error {