From 9410da471ab53f71cf8aea1e5a2acb9eb0c27c53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Phil=20Cal=C3=A7ado?= Date: Thu, 25 Jan 2018 11:49:38 -0500 Subject: [PATCH] Better error handling for Tap (#177) Previously, running `$conduit tap` would return a `Unexpected EOF` error when the server wasn't available. This was due to a few problems with the way we were handling errors all the way down the tap server. This change fixes that and cleans some of the protobuf-over-HTTP code. - first step towards #49 - closes #106 --- cli/cmd/root.go | 2 +- cli/cmd/tap.go | 6 +- controller/api/public/client.go | 137 +++-- controller/api/public/client_test.go | 83 +-- controller/api/public/grpc_server.go | 7 +- controller/api/public/http_server.go | 217 ++++++++ controller/api/public/http_server_test.go | 207 ++++++++ controller/api/public/proto_over_http.go | 150 ++++++ controller/api/public/proto_over_http_test.go | 498 ++++++++++++++++++ controller/api/public/server.go | 273 ---------- controller/k8s/pods.go | 2 +- controller/k8s/replicasets.go | 2 +- controller/tap/server.go | 6 +- 13 files changed, 1176 insertions(+), 414 deletions(-) create mode 100644 controller/api/public/http_server.go create mode 100644 controller/api/public/http_server_test.go create mode 100644 controller/api/public/proto_over_http.go create mode 100644 controller/api/public/proto_over_http_test.go delete mode 100644 controller/api/public/server.go diff --git a/cli/cmd/root.go b/cli/cmd/root.go index fe7debccb5441..de5d13aa5a129 100644 --- a/cli/cmd/root.go +++ b/cli/cmd/root.go @@ -40,8 +40,8 @@ func init() { func addControlPlaneNetworkingArgs(cmd *cobra.Command) { // Use the same argument name as `kubectl` (see the output of `kubectl options`). + //TODO: move these to init() as they are globally applicable cmd.PersistentFlags().StringVar(&kubeconfigPath, "kubeconfig", "", "Path to the kubeconfig file to use for CLI requests") - cmd.PersistentFlags().StringVar(&apiAddr, "api-addr", "", "Override kubeconfig and communicate directly with the control plane at host:port (mostly for testing)") } diff --git a/cli/cmd/tap.go b/cli/cmd/tap.go index b5f717638e932..a83e242972902 100644 --- a/cli/cmd/tap.go +++ b/cli/cmd/tap.go @@ -38,7 +38,7 @@ var tapCmd = &cobra.Command{ Valid targets include: * Pods (default/hello-world-h4fb2) * Deployments (default/hello-world)`, - RunE: func(cmd *cobra.Command, args []string) error { + Run: exitSilentlyOnError(func(cmd *cobra.Command, args []string) error { if len(args) != 2 { return errors.New("please specify a target") } @@ -68,7 +68,7 @@ Valid targets include: } return requestTapFromApi(os.Stdout, client, args[1], validatedResourceType, partialReq) - }, + }), } func init() { @@ -102,6 +102,7 @@ func requestTapFromApi(w io.Writer, client pb.ApiClient, targetName string, reso rsp, err := client.Tap(context.Background(), req) if err != nil { + fmt.Fprintln(w, err.Error()) return err } @@ -117,7 +118,6 @@ func renderTap(w io.Writer, tapClient pb.Api_TapClient) error { tableWriter.Flush() return nil - } func writeTapEventsToBuffer(tapClient pb.Api_TapClient, w *tabwriter.Writer) error { diff --git a/controller/api/public/client.go b/controller/api/public/client.go index 5281f25bc14b4..6ee4be385a06f 100644 --- a/controller/api/public/client.go +++ b/controller/api/public/client.go @@ -3,9 +3,7 @@ package public import ( "bufio" "bytes" - "encoding/binary" "fmt" - "io" "net/http" "net/url" @@ -23,10 +21,7 @@ import ( const ( ApiRoot = "/" // Must be absolute (with a leading slash). ApiVersion = "v1" - JsonContentType = "application/json" ApiPrefix = "api/" + ApiVersion + "/" // Must be relative (without a leading slash). - ProtobufContentType = "application/octet-stream" - ErrorHeader = "conduit-error" ConduitApiSubsystemName = "conduit-api" ) @@ -35,11 +30,6 @@ type grpcOverHttpClient struct { httpClient *http.Client } -type tapClient struct { - ctx context.Context - reader *bufio.Reader -} - func (c *grpcOverHttpClient) Stat(ctx context.Context, req *pb.MetricRequest, _ ...grpc.CallOption) (*pb.MetricResponse, error) { var msg pb.MetricResponse err := c.apiRequest(ctx, "Stat", req, &msg) @@ -66,35 +56,26 @@ func (c *grpcOverHttpClient) ListPods(ctx context.Context, req *pb.Empty, _ ...g func (c *grpcOverHttpClient) Tap(ctx context.Context, req *pb.TapRequest, _ ...grpc.CallOption) (pb.Api_TapClient, error) { url := c.endpointNameToPublicApiUrl("Tap") - log.Debugf("Making streaming gRPC-over-HTTP call to [%s]", url.String()) - rsp, err := c.post(ctx, url, req) + httpRsp, err := c.post(ctx, url, req) if err != nil { return nil, err } + if err = checkIfResponseHasConduitError(httpRsp); err != nil { + httpRsp.Body.Close() + return nil, err + } + go func() { <-ctx.Done() log.Debug("Closing response body after context marked as done") - rsp.Body.Close() + httpRsp.Body.Close() }() - return &tapClient{ctx: ctx, reader: bufio.NewReader(rsp.Body)}, nil + return &tapClient{ctx: ctx, reader: bufio.NewReader(httpRsp.Body)}, nil } -func (c tapClient) Recv() (*common.TapEvent, error) { - var msg common.TapEvent - err := fromByteStreamToProtocolBuffers(c.reader, "", &msg) - return &msg, err -} - -// satisfy the pb.Api_TapClient interface -func (c tapClient) Header() (metadata.MD, error) { return nil, nil } -func (c tapClient) Trailer() metadata.MD { return nil } -func (c tapClient) CloseSend() error { return nil } -func (c tapClient) Context() context.Context { return c.ctx } -func (c tapClient) SendMsg(interface{}) error { return nil } -func (c tapClient) RecvMsg(interface{}) error { return nil } -func (c *grpcOverHttpClient) apiRequest(ctx context.Context, endpoint string, req proto.Message, rsp proto.Message) error { +func (c *grpcOverHttpClient) apiRequest(ctx context.Context, endpoint string, req proto.Message, protoResponse proto.Message) error { url := c.endpointNameToPublicApiUrl(endpoint) log.Debugf("Making gRPC-over-HTTP call to [%s]", url.String()) @@ -102,18 +83,20 @@ func (c *grpcOverHttpClient) apiRequest(ctx context.Context, endpoint string, re if err != nil { return err } - + defer httpRsp.Body.Close() log.Debugf("gRPC-over-HTTP call returned status [%s] and content length [%d]", httpRsp.Status, httpRsp.ContentLength) + clientSideErrorStatusCode := httpRsp.StatusCode >= 400 && httpRsp.StatusCode <= 499 if clientSideErrorStatusCode { return fmt.Errorf("POST to Conduit API endpoint [%s] returned HTTP status [%s]", url, httpRsp.Status) } - defer httpRsp.Body.Close() + if err = checkIfResponseHasConduitError(httpRsp); err != nil { + return err + } reader := bufio.NewReader(httpRsp.Body) - errorMsg := httpRsp.Header.Get(ErrorHeader) - return fromByteStreamToProtocolBuffers(reader, errorMsg, rsp) + return fromByteStreamToProtocolBuffers(reader, protoResponse) } func (c *grpcOverHttpClient) post(ctx context.Context, url *url.URL, req proto.Message) (*http.Response, error) { @@ -131,35 +114,51 @@ func (c *grpcOverHttpClient) post(ctx context.Context, url *url.URL, req proto.M return nil, err } - return c.httpClient.Do(httpReq.WithContext(ctx)) + rsp, err := c.httpClient.Do(httpReq.WithContext(ctx)) + if err != nil { + log.Debugf("Error invoking [%s]: %v", url.String(), err) + } else { + log.Debugf("Response from [%s] had headers: %v", url.String(), rsp.Header) + } + + return rsp, err } func (c *grpcOverHttpClient) endpointNameToPublicApiUrl(endpoint string) *url.URL { return c.serverURL.ResolveReference(&url.URL{Path: endpoint}) } -func NewInternalClient(kubernetesApiHost string) (pb.ApiClient, error) { - apiURL := &url.URL{ - Scheme: "http", - Host: kubernetesApiHost, - Path: "/", - } +type tapClient struct { + ctx context.Context + reader *bufio.Reader +} - return newClient(apiURL, http.DefaultClient) +func (c tapClient) Recv() (*common.TapEvent, error) { + var msg common.TapEvent + err := fromByteStreamToProtocolBuffers(c.reader, &msg) + return &msg, err } -func NewExternalClient(controlPlaneNamespace string, kubeApi k8s.KubernetesApi) (pb.ApiClient, error) { - apiURL, err := kubeApi.UrlFor(controlPlaneNamespace, "/services/http:api:http/proxy/") +// satisfy the pb.Api_TapClient interface +func (c tapClient) Header() (metadata.MD, error) { return nil, nil } +func (c tapClient) Trailer() metadata.MD { return nil } +func (c tapClient) CloseSend() error { return nil } +func (c tapClient) Context() context.Context { return c.ctx } +func (c tapClient) SendMsg(interface{}) error { return nil } +func (c tapClient) RecvMsg(interface{}) error { return nil } + +func fromByteStreamToProtocolBuffers(byteStreamContainingMessage *bufio.Reader, out proto.Message) error { + messageAsBytes, err := deserializePayloadFromReader(byteStreamContainingMessage) if err != nil { - return nil, err + return fmt.Errorf("error reading byte stream header: %v", err) } - httpClientToUse, err := kubeApi.NewClient() + err = proto.Unmarshal(messageAsBytes, out) if err != nil { - return nil, err + return fmt.Errorf("error unmarshalling array of [%d] bytes error: %v", len(messageAsBytes), err) } - return newClient(apiURL, httpClientToUse) + return nil } func newClient(apiURL *url.URL, httpClientToUse *http.Client) (pb.ApiClient, error) { @@ -168,42 +167,36 @@ func newClient(apiURL *url.URL, httpClientToUse *http.Client) (pb.ApiClient, err return nil, fmt.Errorf("server URL must be absolute, was [%s]", apiURL.String()) } + serverUrl := apiURL.ResolveReference(&url.URL{Path: ApiPrefix}) + + log.Debugf("Expecting Conduit Public API to be served over [%s]", serverUrl) + return &grpcOverHttpClient{ - serverURL: apiURL.ResolveReference(&url.URL{Path: ApiPrefix}), + serverURL: serverUrl, httpClient: httpClientToUse, }, nil } -func fromByteStreamToProtocolBuffers(byteStreamContainingMessage *bufio.Reader, errorMessageReturnedAsMetadata string, out proto.Message) error { - //TODO: why the magic number 4? - byteSize := make([]byte, 4) - - //TODO: why is this necessary? - _, err := byteStreamContainingMessage.Read(byteSize) - if err != nil { - return fmt.Errorf("error reading byte stream header: %v", err) +func NewInternalClient(kubernetesApiHost string) (pb.ApiClient, error) { + apiURL := &url.URL{ + Scheme: "http", + Host: kubernetesApiHost, + Path: "/", } - size := binary.LittleEndian.Uint32(byteSize) - bytes := make([]byte, size) - _, err = io.ReadFull(byteStreamContainingMessage, bytes) - if err != nil { - return fmt.Errorf("error reading byte stream content: %v", err) - } + return newClient(apiURL, http.DefaultClient) +} - if errorMessageReturnedAsMetadata != "" { - var apiError pb.ApiError - err = proto.Unmarshal(bytes, &apiError) - if err != nil { - return fmt.Errorf("error unmarshalling error from byte stream: %v", err) - } - return fmt.Errorf("%s: %s", errorMessageReturnedAsMetadata, apiError.Error) +func NewExternalClient(controlPlaneNamespace string, kubeApi k8s.KubernetesApi) (pb.ApiClient, error) { + apiURL, err := kubeApi.UrlFor(controlPlaneNamespace, "/services/http:api:http/proxy/") + if err != nil { + return nil, err } - err = proto.Unmarshal(bytes, out) + httpClientToUse, err := kubeApi.NewClient() if err != nil { - return fmt.Errorf("error unmarshalling bytes: %v", err) - } else { - return nil + return nil, err } + + return newClient(apiURL, httpClientToUse) } diff --git a/controller/api/public/client_test.go b/controller/api/public/client_test.go index c434348d6a169..01ac0e04511b3 100644 --- a/controller/api/public/client_test.go +++ b/controller/api/public/client_test.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "context" - "encoding/binary" "io/ioutil" "net/http" "net/url" @@ -60,6 +59,7 @@ func TestNewInternalClient(t *testing.T) { } func TestFromByteStreamToProtocolBuffers(t *testing.T) { + t.Run("Correctly marshalls an valid object", func(t *testing.T) { versionInfo := pb.VersionInfo{ GoVersion: "1.9.1", @@ -70,9 +70,9 @@ func TestFromByteStreamToProtocolBuffers(t *testing.T) { var protobufMessageToBeFilledWithData pb.VersionInfo reader := bufferedReader(t, &versionInfo) - err := fromByteStreamToProtocolBuffers(reader, "", &protobufMessageToBeFilledWithData) + err := fromByteStreamToProtocolBuffers(reader, &protobufMessageToBeFilledWithData) if err != nil { - t.Fatal(err.Error()) + t.Fatalf("Unexpected error: %v", err) } if protobufMessageToBeFilledWithData != versionInfo { @@ -80,14 +80,14 @@ func TestFromByteStreamToProtocolBuffers(t *testing.T) { } }) - t.Run("Correctly marshalls a large byte arrey", func(t *testing.T) { + t.Run("Correctly marshalls a large byte array", func(t *testing.T) { series := pb.MetricSeries{ Name: pb.MetricName_REQUEST_RATE, Metadata: &pb.MetricMetadata{}, Datapoints: make([]*pb.MetricDatapoint, 0), } - numberOfDatapointsInMessage := 1000 + numberOfDatapointsInMessage := 400 for i := 0; i < numberOfDatapointsInMessage; i++ { datapoint := pb.MetricDatapoint{ Value: &pb.MetricValue{Value: &pb.MetricValue_Gauge{Gauge: float64(i)}}, @@ -95,12 +95,12 @@ func TestFromByteStreamToProtocolBuffers(t *testing.T) { } series.Datapoints = append(series.Datapoints, &datapoint) } - - var protobufMessageToBeFilledWithData pb.MetricSeries reader := bufferedReader(t, &series) - err := fromByteStreamToProtocolBuffers(reader, "", &protobufMessageToBeFilledWithData) + + protobufMessageToBeFilledWithData := &pb.MetricSeries{} + err := fromByteStreamToProtocolBuffers(reader, protobufMessageToBeFilledWithData) if err != nil { - t.Fatal(err.Error()) + t.Fatalf("Unexpected error: %v", err) } actualNumberOfDatapointsMarshalled := len(protobufMessageToBeFilledWithData.Datapoints) @@ -109,29 +109,12 @@ func TestFromByteStreamToProtocolBuffers(t *testing.T) { } }) - t.Run("When error, uses both byte array and supplied message to return error", func(t *testing.T) { - apiError := pb.ApiError{Error: "an error occurred"} - - var protobufMessageToBeFilledWithData pb.VersionInfo - reader := bufferedReader(t, &apiError) - err := fromByteStreamToProtocolBuffers(reader, "Bad Request", &protobufMessageToBeFilledWithData) - if err == nil { - t.Fatal("expected error") - } - - expectedErrorMessage := "Bad Request: an error occurred" - actualErrorMessage := err.Error() - if actualErrorMessage != expectedErrorMessage { - t.Fatalf("Expecting returned error message to be [%s], but got [%s]", expectedErrorMessage, actualErrorMessage) - } - }) - - t.Run("When byte array contains error but no message was supplied, treats stream as regular object", func(t *testing.T) { + t.Run("When byte array contains error, treats stream as regular protobuf object", func(t *testing.T) { apiError := pb.ApiError{Error: "an error occurred"} var protobufMessageToBeFilledWithData pb.ApiError reader := bufferedReader(t, &apiError) - err := fromByteStreamToProtocolBuffers(reader, "", &protobufMessageToBeFilledWithData) + err := fromByteStreamToProtocolBuffers(reader, &protobufMessageToBeFilledWithData) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -143,39 +126,17 @@ func TestFromByteStreamToProtocolBuffers(t *testing.T) { } }) - t.Run("When byte array does not contain error but a message was supplied, returns error", func(t *testing.T) { - versionInfo := pb.VersionInfo{ - GoVersion: "1.9.1", - BuildDate: "2017.11.17", - ReleaseVersion: "1.2.3", - } - - expectedErrorMessage := "supplied error message here" - var protobufMessageToBeFilledWithData pb.VersionInfo - reader := bufferedReader(t, &versionInfo) - - err := fromByteStreamToProtocolBuffers(reader, expectedErrorMessage, &protobufMessageToBeFilledWithData) - if err == nil { - t.Fatal("Expecting error, got nothing") - } - - actualErrorMessage := err.Error() - if !strings.Contains(actualErrorMessage, expectedErrorMessage) { - t.Fatalf("Expected object to contain message [%s], but got [%s]", expectedErrorMessage, actualErrorMessage) - } - }) - - t.Run("Correctly marshalls an valid object", func(t *testing.T) { - versionInfo := pb.VersionInfo{ + t.Run("Returns error if byte stream contains wrong object", func(t *testing.T) { + versionInfo := &pb.VersionInfo{ GoVersion: "1.9.1", BuildDate: "2017.11.17", ReleaseVersion: "1.2.3", } - var protobufMessageToBeFilledWithData pb.MetricSeries - reader := bufferedReader(t, &versionInfo) + reader := bufferedReader(t, versionInfo) - err := fromByteStreamToProtocolBuffers(reader, "", &protobufMessageToBeFilledWithData) + protobufMessageToBeFilledWithData := &pb.MetricSeries{} + err := fromByteStreamToProtocolBuffers(reader, protobufMessageToBeFilledWithData) if err == nil { t.Fatal("Expecting error, got nothing") } @@ -185,9 +146,13 @@ func TestFromByteStreamToProtocolBuffers(t *testing.T) { func bufferedReader(t *testing.T, msg proto.Message) *bufio.Reader { msgBytes, err := proto.Marshal(msg) if err != nil { - t.Fatal(err.Error()) + t.Fatalf("Unexpected error: %v", err) } - sizeBytes := make([]byte, 4) - binary.LittleEndian.PutUint32(sizeBytes, uint32(len(msgBytes))) - return bufio.NewReader(bytes.NewReader(append(sizeBytes, msgBytes...))) + + payload, err := serializeAsPayload(msgBytes) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + return bufio.NewReader(bytes.NewReader(payload)) } diff --git a/controller/api/public/grpc_server.go b/controller/api/public/grpc_server.go index 048c9798edde9..6532419db5e0f 100644 --- a/controller/api/public/grpc_server.go +++ b/controller/api/public/grpc_server.go @@ -14,6 +14,7 @@ import ( tapPb "github.com/runconduit/conduit/controller/gen/controller/tap" telemPb "github.com/runconduit/conduit/controller/gen/controller/telemetry" pb "github.com/runconduit/conduit/controller/gen/public" + log "github.com/sirupsen/logrus" "golang.org/x/net/context" ) @@ -156,8 +157,10 @@ func (s *grpcServer) SelfCheck(ctx context.Context, in *healthcheckPb.SelfCheckR // Pass through to tap service func (s *grpcServer) Tap(req *pb.TapRequest, stream pb.Api_TapServer) error { tapStream := stream.(tapServer) - rsp, err := s.tapClient.Tap(tapStream.Context(), req) + tapClient, err := s.tapClient.Tap(tapStream.Context(), req) if err != nil { + //TODO: why not return the error? + log.Errorf("Unexpected error tapping [%v]: %v", req, err) return nil } for { @@ -165,7 +168,7 @@ func (s *grpcServer) Tap(req *pb.TapRequest, stream pb.Api_TapServer) error { case <-tapStream.Context().Done(): return nil default: - event, err := rsp.Recv() + event, err := tapClient.Recv() if err != nil { return err } diff --git a/controller/api/public/http_server.go b/controller/api/public/http_server.go new file mode 100644 index 0000000000000..1b9bbc6a5fe82 --- /dev/null +++ b/controller/api/public/http_server.go @@ -0,0 +1,217 @@ +package public + +import ( + "fmt" + "net/http" + + "github.com/golang/protobuf/jsonpb" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + common "github.com/runconduit/conduit/controller/gen/common" + healthcheckPb "github.com/runconduit/conduit/controller/gen/common/healthcheck" + tapPb "github.com/runconduit/conduit/controller/gen/controller/tap" + telemPb "github.com/runconduit/conduit/controller/gen/controller/telemetry" + pb "github.com/runconduit/conduit/controller/gen/public" + log "github.com/sirupsen/logrus" + "golang.org/x/net/context" + "google.golang.org/grpc/metadata" +) + +var ( + jsonMarshaler = jsonpb.Marshaler{EmitDefaults: true} + jsonUnmarshaler = jsonpb.Unmarshaler{} + statPath = fullUrlPathFor("Stat") + versionPath = fullUrlPathFor("Version") + listPodsPath = fullUrlPathFor("ListPods") + tapPath = fullUrlPathFor("Tap") + selfCheckPath = fullUrlPathFor("SelfCheck") +) + +type handler struct { + grpcServer pb.ApiServer +} + +func (h *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + log.Debugf("Serving %s %s", req.Method, req.URL.Path) + // Validate request method + if req.Method != http.MethodPost { + writeErrorToHttpResponse(w, fmt.Errorf("POST required")) + return + } + + // Serve request + switch req.URL.Path { + case statPath: + h.handleStat(w, req) + case versionPath: + h.handleVersion(w, req) + case listPodsPath: + h.handleListPods(w, req) + case tapPath: + h.handleTap(w, req) + case selfCheckPath: + h.handleSelfCheck(w, req) + default: + http.NotFound(w, req) + } + +} + +func (h *handler) handleStat(w http.ResponseWriter, req *http.Request) { + var protoRequest pb.MetricRequest + err := httpRequestToProto(req, &protoRequest) + if err != nil { + writeErrorToHttpResponse(w, err) + return + } + + rsp, err := h.grpcServer.Stat(req.Context(), &protoRequest) + if err != nil { + writeErrorToHttpResponse(w, err) + return + } + + err = writeProtoToHttpResponse(w, rsp) + if err != nil { + writeErrorToHttpResponse(w, err) + return + } +} + +func (h *handler) handleVersion(w http.ResponseWriter, req *http.Request) { + var protoRequest pb.Empty + err := httpRequestToProto(req, &protoRequest) + if err != nil { + writeErrorToHttpResponse(w, err) + return + } + + rsp, err := h.grpcServer.Version(req.Context(), &protoRequest) + if err != nil { + writeErrorToHttpResponse(w, err) + return + } + + err = writeProtoToHttpResponse(w, rsp) + if err != nil { + writeErrorToHttpResponse(w, err) + return + } +} + +func (h *handler) handleSelfCheck(w http.ResponseWriter, req *http.Request) { + var protoRequest healthcheckPb.SelfCheckRequest + err := httpRequestToProto(req, &protoRequest) + if err != nil { + writeErrorToHttpResponse(w, err) + return + } + + rsp, err := h.grpcServer.SelfCheck(req.Context(), &protoRequest) + if err != nil { + writeErrorToHttpResponse(w, err) + return + } + + err = writeProtoToHttpResponse(w, rsp) + if err != nil { + writeErrorToHttpResponse(w, err) + return + } +} + +func (h *handler) handleListPods(w http.ResponseWriter, req *http.Request) { + var protoRequest pb.Empty + err := httpRequestToProto(req, &protoRequest) + if err != nil { + writeErrorToHttpResponse(w, err) + return + } + + rsp, err := h.grpcServer.ListPods(req.Context(), &protoRequest) + if err != nil { + writeErrorToHttpResponse(w, err) + return + } + + err = writeProtoToHttpResponse(w, rsp) + if err != nil { + writeErrorToHttpResponse(w, err) + return + } +} + +func (h *handler) handleTap(w http.ResponseWriter, req *http.Request) { + flushableWriter, err := newStreamingWriter(w) + if err != nil { + writeErrorToHttpResponse(w, err) + return + } + + var protoRequest pb.TapRequest + err = httpRequestToProto(req, &protoRequest) + if err != nil { + writeErrorToHttpResponse(w, err) + return + } + + server := tapServer{w: flushableWriter, req: req} + err = h.grpcServer.Tap(&protoRequest, server) + if err != nil { + writeErrorToHttpResponse(w, err) + return + } +} + +type tapServer struct { + w flushableResponseWriter + req *http.Request +} + +func (s tapServer) Send(msg *common.TapEvent) error { + err := writeProtoToHttpResponse(s.w, msg) + if err != nil { + writeErrorToHttpResponse(s.w, err) + return err + } + + s.w.Flush() + return nil +} + +// satisfy the pb.Api_TapServer interface +func (s tapServer) SetHeader(metadata.MD) error { return nil } +func (s tapServer) SendHeader(metadata.MD) error { return nil } +func (s tapServer) SetTrailer(metadata.MD) { return } +func (s tapServer) Context() context.Context { return s.req.Context() } +func (s tapServer) SendMsg(interface{}) error { return nil } +func (s tapServer) RecvMsg(interface{}) error { return nil } + +func fullUrlPathFor(method string) string { + return ApiRoot + ApiPrefix + method +} + +func withTelemetry(baseHandler *handler) http.HandlerFunc { + counter := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "http_requests_total", + Help: "A counter for requests to the wrapped handler.", + }, + []string{"code", "method"}, + ) + prometheus.MustRegister(counter) + return promhttp.InstrumentHandlerCounter(counter, baseHandler) +} + +func NewServer(addr string, telemetryClient telemPb.TelemetryClient, tapClient tapPb.TapClient) *http.Server { + baseHandler := &handler{ + grpcServer: newGrpcServer(telemetryClient, tapClient), + } + + instrumentedHandler := withTelemetry(baseHandler) + + return &http.Server{ + Addr: addr, + Handler: instrumentedHandler, + } +} diff --git a/controller/api/public/http_server_test.go b/controller/api/public/http_server_test.go new file mode 100644 index 0000000000000..f0567a8b9d827 --- /dev/null +++ b/controller/api/public/http_server_test.go @@ -0,0 +1,207 @@ +package public + +import ( + "context" + "errors" + "fmt" + "net/http" + "reflect" + "testing" + + "github.com/gogo/protobuf/proto" + common "github.com/runconduit/conduit/controller/gen/common" + healcheckPb "github.com/runconduit/conduit/controller/gen/common/healthcheck" + pb "github.com/runconduit/conduit/controller/gen/public" +) + +type mockGrpcServer struct { + LastRequestReceived proto.Message + ResponseToReturn proto.Message + TapStreamsToReturn []*common.TapEvent + ErrorToReturn error +} + +func (m *mockGrpcServer) Stat(ctx context.Context, req *pb.MetricRequest) (*pb.MetricResponse, error) { + m.LastRequestReceived = req + return m.ResponseToReturn.(*pb.MetricResponse), m.ErrorToReturn +} + +func (m *mockGrpcServer) Version(ctx context.Context, req *pb.Empty) (*pb.VersionInfo, error) { + m.LastRequestReceived = req + return m.ResponseToReturn.(*pb.VersionInfo), m.ErrorToReturn +} + +func (m *mockGrpcServer) ListPods(ctx context.Context, req *pb.Empty) (*pb.ListPodsResponse, error) { + m.LastRequestReceived = req + return m.ResponseToReturn.(*pb.ListPodsResponse), m.ErrorToReturn +} + +func (m *mockGrpcServer) SelfCheck(ctx context.Context, req *healcheckPb.SelfCheckRequest) (*healcheckPb.SelfCheckResponse, error) { + m.LastRequestReceived = req + return m.ResponseToReturn.(*healcheckPb.SelfCheckResponse), m.ErrorToReturn +} + +func (m *mockGrpcServer) Tap(req *pb.TapRequest, tapServer pb.Api_TapServer) error { + m.LastRequestReceived = req + if m.ErrorToReturn == nil { + for _, msg := range m.TapStreamsToReturn { + tapServer.Send(msg) + } + } + + return m.ErrorToReturn +} + +type grpcCallTestCase struct { + expectedRequest proto.Message + expectedResponse proto.Message + functionCall func() (proto.Message, error) +} + +func TestServer(t *testing.T) { + + mockGrpcServer := &mockGrpcServer{} + handler := &handler{ + grpcServer: mockGrpcServer, + } + + httpServer := &http.Server{ + Addr: "localhost:8889", + Handler: handler, + } + + go func() { + err := httpServer.ListenAndServe() + if err != nil { + t.Fatalf("Could not start server: %v", err) + } + }() + + client, err := NewInternalClient("localhost:8889") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + t.Run("Delegates all non-streaming RPC messages to the underlying grpc server", func(t *testing.T) { + listPodsReq := &pb.Empty{} + testListPods := grpcCallTestCase{ + expectedRequest: listPodsReq, + expectedResponse: &pb.ListPodsResponse{ + Pods: []*pb.Pod{ + {Status: "ok-ish"}, + }, + }, + functionCall: func() (proto.Message, error) { return client.ListPods(context.TODO(), listPodsReq) }, + } + + statReq := &pb.MetricRequest{ + Summarize: false, + } + seriesToReturn := make([]*pb.MetricSeries, 0) + for i := 0; i < 100; i++ { + seriesToReturn = append(seriesToReturn, &pb.MetricSeries{Name: pb.MetricName_LATENCY, Metadata: &pb.MetricMetadata{Path: fmt.Sprintf("/%d", i)}}) + } + testStat := grpcCallTestCase{ + expectedRequest: statReq, + expectedResponse: &pb.MetricResponse{ + Metrics: seriesToReturn, + }, + functionCall: func() (proto.Message, error) { return client.Stat(context.TODO(), statReq) }, + } + + versionReq := &pb.Empty{} + testVersion := grpcCallTestCase{ + expectedRequest: versionReq, + expectedResponse: &pb.VersionInfo{ + BuildDate: "02/21/1983", + }, + functionCall: func() (proto.Message, error) { return client.Version(context.TODO(), versionReq) }, + } + + selfCheckReq := &healcheckPb.SelfCheckRequest{} + testSelfCheck := grpcCallTestCase{ + expectedRequest: selfCheckReq, + expectedResponse: &healcheckPb.SelfCheckResponse{ + Results: []*healcheckPb.CheckResult{ + { + SubsystemName: "banana", + }, + }, + }, + functionCall: func() (proto.Message, error) { return client.SelfCheck(context.TODO(), selfCheckReq) }, + } + + for _, testCase := range []grpcCallTestCase{testListPods, testStat, testVersion, testSelfCheck} { + assertCallWasForwarded(t, mockGrpcServer, testCase.expectedRequest, testCase.expectedResponse, testCase.functionCall) + } + }) + + t.Run("Delegates all streaming tap RPC messages to the underlying grpc server", func(t *testing.T) { + expectedTapResponses := []*common.TapEvent{ + { + Target: &common.TcpAddress{ + Port: 9999, + }, + Source: &common.TcpAddress{ + Port: 6666, + }, + }, { + Target: &common.TcpAddress{ + Port: 2102, + }, + Source: &common.TcpAddress{ + Port: 1983, + }, + }, + } + mockGrpcServer.TapStreamsToReturn = expectedTapResponses + mockGrpcServer.ErrorToReturn = nil + + tapClient, err := client.Tap(context.TODO(), &pb.TapRequest{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + for _, expectedTapEvent := range expectedTapResponses { + actualTapEvent, err := tapClient.Recv() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if !reflect.DeepEqual(actualTapEvent, expectedTapEvent) { + t.Fatalf("Expecting tap event to be [%v], but was [%v]", expectedTapEvent, actualTapEvent) + } + } + }) + + t.Run("Handles errors before opening keep-alive response", func(t *testing.T) { + mockGrpcServer.ErrorToReturn = errors.New("expected error") + + _, err := client.Tap(context.TODO(), &pb.TapRequest{}) + if err == nil { + t.Fatalf("Expecting error, got nothing") + } + }) +} + +func assertCallWasForwarded(t *testing.T, mockGrpcServer *mockGrpcServer, expectedRequest proto.Message, expectedResponse proto.Message, functionCall func() (proto.Message, error)) { + mockGrpcServer.ErrorToReturn = nil + mockGrpcServer.ResponseToReturn = expectedResponse + actualResponse, err := functionCall() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + actualRequest := mockGrpcServer.LastRequestReceived + if !reflect.DeepEqual(actualRequest, expectedRequest) { + t.Fatalf("Expecting server call to return [%v], but got [%v]", expectedRequest, actualRequest) + } + if !reflect.DeepEqual(actualResponse, expectedResponse) { + t.Fatalf("Expecting server call to return [%v], but got [%v]", expectedResponse, actualResponse) + } + + mockGrpcServer.ErrorToReturn = errors.New("expected") + actualResponse, err = functionCall() + if err == nil { + t.Fatalf("Expecting error, got nothing") + } +} diff --git a/controller/api/public/proto_over_http.go b/controller/api/public/proto_over_http.go new file mode 100644 index 0000000000000..77ce5d1d59a64 --- /dev/null +++ b/controller/api/public/proto_over_http.go @@ -0,0 +1,150 @@ +package public + +import ( + "bufio" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + + "github.com/golang/protobuf/proto" + pb "github.com/runconduit/conduit/controller/gen/public" + log "github.com/sirupsen/logrus" + "google.golang.org/grpc/status" +) + +const ( + errorHeader = "conduit-error" + defaultHttpErrorStatusCode = http.StatusInternalServerError + contentTypeHeader = "Content-Type" + protobufContentType = "application/octet-stream" + numBytesForMessageLength = 4 +) + +type httpError struct { + Code int + WrappedError error +} + +type flushableResponseWriter interface { + http.ResponseWriter + http.Flusher +} + +func (e httpError) Error() string { + return fmt.Sprintf("HTTP error, status Code [%d], wrapped error is: %v", e.Code, e.WrappedError) +} + +func httpRequestToProto(req *http.Request, protoRequestOut proto.Message) error { + bytes, err := ioutil.ReadAll(req.Body) + if err != nil { + return httpError{ + Code: http.StatusBadRequest, + WrappedError: err, + } + } + + err = proto.Unmarshal(bytes, protoRequestOut) + if err != nil { + return httpError{ + Code: http.StatusBadRequest, + WrappedError: err, + } + } + + return nil +} + +func writeErrorToHttpResponse(w http.ResponseWriter, errorObtained error) { + statusCode := defaultHttpErrorStatusCode + errorToReturn := errorObtained + + if httpErr, ok := errorObtained.(httpError); ok { + statusCode = httpErr.Code + errorToReturn = httpErr.WrappedError + } + + w.Header().Set(errorHeader, http.StatusText(statusCode)) + + errorMessageToReturn := errorToReturn.Error() + if grpcError, ok := status.FromError(errorObtained); ok { + errorMessageToReturn = grpcError.Message() + } + + errorAsProto := &pb.ApiError{Error: errorMessageToReturn} + + err := writeProtoToHttpResponse(w, errorAsProto) + if err != nil { + log.Errorf("Error writing error to http response: %v", err) + w.Header().Set(errorHeader, err.Error()) + } +} + +func writeProtoToHttpResponse(w http.ResponseWriter, msg proto.Message) error { + w.Header().Set(contentTypeHeader, protobufContentType) + marshalledProtobufMessage, err := proto.Marshal(msg) + if err != nil { + return err + } + + fullPayload, err := serializeAsPayload(marshalledProtobufMessage) + if err != nil { + return err + } + _, err = w.Write(fullPayload) + return err +} + +func newStreamingWriter(w http.ResponseWriter) (flushableResponseWriter, error) { + flushableWriter, ok := w.(flushableResponseWriter) + if !ok { + return nil, fmt.Errorf("streaming not supported by this writer") + } + + flushableWriter.Header().Set("Connection", "keep-alive") + flushableWriter.Header().Set("Transfer-Encoding", "chunked") + return flushableWriter, nil +} + +func serializeAsPayload(messageContentsInBytes []byte) ([]byte, error) { + lengthOfThePayload := uint32(len(messageContentsInBytes)) + + messageLengthInBytes := make([]byte, numBytesForMessageLength) + binary.LittleEndian.PutUint32(messageLengthInBytes, lengthOfThePayload) + + return append(messageLengthInBytes, messageContentsInBytes...), nil +} + +func deserializePayloadFromReader(reader *bufio.Reader) ([]byte, error) { + messageLengthAsBytes := make([]byte, numBytesForMessageLength) + reader.Read(messageLengthAsBytes) + messageLength := int(binary.LittleEndian.Uint32(messageLengthAsBytes)) + + messageContentsAsBytes := make([]byte, messageLength) + _, err := io.ReadFull(reader, messageContentsAsBytes) + if err != nil { + return nil, fmt.Errorf("error while reading bytes from message: %v", err) + } + + return messageContentsAsBytes, nil +} + +func checkIfResponseHasConduitError(rsp *http.Response) error { + errorMsg := rsp.Header.Get(errorHeader) + + if errorMsg != "" { + reader := bufio.NewReader(rsp.Body) + var apiError pb.ApiError + + err := fromByteStreamToProtocolBuffers(reader, &apiError) + if err != nil { + return fmt.Errorf("response is Conduit error header [%s], but response body didn't contain protobuf error: %v", errorMsg, err) + } + + return errors.New(apiError.Error) + } + + return nil +} diff --git a/controller/api/public/proto_over_http_test.go b/controller/api/public/proto_over_http_test.go new file mode 100644 index 0000000000000..df2187a7d7f44 --- /dev/null +++ b/controller/api/public/proto_over_http_test.go @@ -0,0 +1,498 @@ +package public + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io/ioutil" + "net/http" + "reflect" + "strings" + "testing" + + "github.com/gogo/protobuf/proto" + pb "github.com/runconduit/conduit/controller/gen/public" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type stubResponseWriter struct { + body *bytes.Buffer + headers http.Header +} + +func (w *stubResponseWriter) Header() http.Header { + return w.headers +} + +func (w *stubResponseWriter) Write(p []byte) (int, error) { + n, err := w.body.Write(p) + fmt.Print(n) + return n, err +} + +func (w *stubResponseWriter) WriteHeader(int) {} + +func (w *stubResponseWriter) Flush() {} + +type nonStreamingResponseWriter struct { +} + +func (w *nonStreamingResponseWriter) Header() http.Header { return nil } + +func (w *nonStreamingResponseWriter) Write(p []byte) (int, error) { return -1, nil } + +func (w *nonStreamingResponseWriter) WriteHeader(int) {} + +func newStubResponseWriter() *stubResponseWriter { + return &stubResponseWriter{ + headers: make(http.Header), + body: bytes.NewBufferString(""), + } +} + +func TestHttpRequestToProto(t *testing.T) { + someUrl := "https://www.example.org/something" + someMethod := http.MethodPost + + t.Run("Given a valid request, serializes its contents into protobuf object", func(t *testing.T) { + expectedProtoMessage := pb.Pod{ + Name: "some-name", + PodIP: "some-name", + Deployment: "some-name", + Status: "some-name", + Added: false, + ControllerNamespace: "some-name", + ControlPlane: false, + } + payload, err := proto.Marshal(&expectedProtoMessage) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + req, err := http.NewRequest(someMethod, someUrl, bytes.NewReader(payload)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + var actualProtoMessage pb.Pod + err = httpRequestToProto(req, &actualProtoMessage) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if actualProtoMessage != expectedProtoMessage { + t.Fatalf("Expected request to be [%v], but got [%v]", actualProtoMessage, expectedProtoMessage) + } + }) + + t.Run("Given a broken request, returns http error", func(t *testing.T) { + var actualProtoMessage pb.Pod + + req, err := http.NewRequest(someMethod, someUrl, strings.NewReader("not really protobuf")) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + err = httpRequestToProto(req, &actualProtoMessage) + if err == nil { + t.Fatalf("Expecting error, got nothing") + } + + if httpErr, ok := err.(httpError); ok { + expectedStatusCode := http.StatusBadRequest + if httpErr.Code != expectedStatusCode || httpErr.WrappedError == nil { + t.Fatalf("Expected error status to be [%d] and contain wrapper error, got status [%d] and error [%v]", expectedStatusCode, httpErr.Code, httpErr.WrappedError) + } + } else { + t.Fatalf("Expected error to be httpError, got: %v", err) + } + }) +} + +func TestWriteErrorToHttpResponse(t *testing.T) { + t.Run("Writes generic error correctly to response", func(t *testing.T) { + expectedErrorStatusCode := defaultHttpErrorStatusCode + + responseWriter := newStubResponseWriter() + genericError := errors.New("expected generic error") + + writeErrorToHttpResponse(responseWriter, genericError) + + assertResponseHasProtobufContentType(t, responseWriter) + + actualErrorStatusCode := responseWriter.headers.Get(errorHeader) + if actualErrorStatusCode != http.StatusText(expectedErrorStatusCode) { + t.Fatalf("Expecting response to have status code [%d], got [%s]", expectedErrorStatusCode, actualErrorStatusCode) + } + + payloadRead, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(responseWriter.body.Bytes()))) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + expectedErrorPayload := pb.ApiError{Error: genericError.Error()} + var actualErrorPayload pb.ApiError + err = proto.Unmarshal(payloadRead, &actualErrorPayload) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if actualErrorPayload != expectedErrorPayload { + t.Fatalf("Expecting error to be serialized as [%v], but got [%v]", expectedErrorPayload, actualErrorPayload) + } + }) + + t.Run("Writes http specific error correctly to response", func(t *testing.T) { + expectedErrorStatusCode := http.StatusBadGateway + responseWriter := newStubResponseWriter() + httpError := httpError{ + WrappedError: errors.New("expected to be wrapped"), + Code: http.StatusBadGateway, + } + + writeErrorToHttpResponse(responseWriter, httpError) + + assertResponseHasProtobufContentType(t, responseWriter) + + actualErrorStatusCode := responseWriter.headers.Get(errorHeader) + if actualErrorStatusCode != http.StatusText(expectedErrorStatusCode) { + t.Fatalf("Expecting response to have status code [%d], got [%s]", expectedErrorStatusCode, actualErrorStatusCode) + } + + payloadRead, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(responseWriter.body.Bytes()))) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + expectedErrorPayload := pb.ApiError{Error: httpError.WrappedError.Error()} + var actualErrorPayload pb.ApiError + err = proto.Unmarshal(payloadRead, &actualErrorPayload) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if actualErrorPayload != expectedErrorPayload { + t.Fatalf("Expecting error to be serialized as [%v], but got [%v]", expectedErrorPayload, actualErrorPayload) + } + }) + + t.Run("Writes gRPC specific error correctly to response", func(t *testing.T) { + expectedErrorStatusCode := defaultHttpErrorStatusCode + + responseWriter := newStubResponseWriter() + expectedErrorMessage := "error message" + grpcError := status.Errorf(codes.AlreadyExists, expectedErrorMessage) + + writeErrorToHttpResponse(responseWriter, grpcError) + + assertResponseHasProtobufContentType(t, responseWriter) + + actualErrorStatusCode := responseWriter.headers.Get(errorHeader) + if actualErrorStatusCode != http.StatusText(expectedErrorStatusCode) { + t.Fatalf("Expecting response to have status code [%d], got [%s]", expectedErrorStatusCode, actualErrorStatusCode) + } + + payloadRead, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(responseWriter.body.Bytes()))) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + expectedErrorPayload := pb.ApiError{Error: expectedErrorMessage} + var actualErrorPayload pb.ApiError + err = proto.Unmarshal(payloadRead, &actualErrorPayload) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if !reflect.DeepEqual(actualErrorPayload, expectedErrorPayload) { + t.Fatalf("Expecting error to be serialized as [%v], but got [%v]", expectedErrorPayload, actualErrorPayload) + } + }) +} + +func TestWriteProtoToHttpResponse(t *testing.T) { + t.Run("Writes valid payload", func(t *testing.T) { + expectedMessage := pb.VersionInfo{ + ReleaseVersion: "0.0.1", + BuildDate: "02/21/1983", + GoVersion: "10.2.45", + } + + responseWriter := newStubResponseWriter() + err := writeProtoToHttpResponse(responseWriter, &expectedMessage) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + assertResponseHasProtobufContentType(t, responseWriter) + + payloadRead, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(responseWriter.body.Bytes()))) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + var actualMessage pb.VersionInfo + err = proto.Unmarshal(payloadRead, &actualMessage) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if expectedMessage != actualMessage { + t.Fatalf("Expected response body to contain message [%v], but got [%v]", expectedMessage, actualMessage) + } + }) +} + +func TestDeserializePayloadFromReader(t *testing.T) { + t.Run("Can read message correctly based on payload size correct payload size to message", func(t *testing.T) { + expectedMessage := "this is the message" + + messageWithSize, err := serializeAsPayload([]byte(expectedMessage)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + messageWithSomeNoise := append(messageWithSize, []byte("this is noise and should not be read")...) + + actualMessage, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(messageWithSomeNoise))) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if string(actualMessage) != expectedMessage { + t.Fatalf("Expecting payload to contain message [%s], but it had [%s]", expectedMessage, actualMessage) + } + }) + + t.Run("Can multiple messages in the same stream", func(t *testing.T) { + expectedMessage1 := "Hit the road, Jack and don't you come back\n" + for i := 0; i < 450; i++ { + expectedMessage1 = expectedMessage1 + fmt.Sprintf("no more (%d), ", i) + } + + expectedMessage2 := "back street back, alright\n" + for i := 0; i < 450; i++ { + expectedMessage2 = expectedMessage2 + fmt.Sprintf("tum (%d), ", i) + } + + messageWithSize1, err := serializeAsPayload([]byte(expectedMessage1)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + messageWithSize2, err := serializeAsPayload([]byte(expectedMessage2)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + streamWithManyMessages := append(messageWithSize1, messageWithSize2...) + reader := bufio.NewReader(bytes.NewReader(streamWithManyMessages)) + + actualMessage1, err := deserializePayloadFromReader(reader) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + actualMessage2, err := deserializePayloadFromReader(reader) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if string(actualMessage1) != expectedMessage1 { + t.Fatalf("Expecting payload to contain message:\n%s\nbut it had\n%s", expectedMessage1, actualMessage1) + } + + if string(actualMessage2) != expectedMessage2 { + t.Fatalf("Expecting payload to contain message:\n%s\nbut it had\n%s", expectedMessage2, actualMessage2) + } + }) + + t.Run("Can write and read marshalled protobuf messages", func(t *testing.T) { + seriesToReturn := make([]*pb.MetricSeries, 0) + for i := 0; i < 351; i++ { + seriesToReturn = append(seriesToReturn, &pb.MetricSeries{Name: pb.MetricName_LATENCY}) + } + + expectedMessage := &pb.MetricResponse{ + Metrics: seriesToReturn, + } + + expectedReadArray, err := proto.Marshal(expectedMessage) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + serialized, err := serializeAsPayload(expectedReadArray) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + reader := bufio.NewReader(bytes.NewReader(serialized)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + actualReadArray, err := deserializePayloadFromReader(reader) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if !reflect.DeepEqual(actualReadArray, expectedReadArray) { + n := len(actualReadArray) + xor := make([]byte, n) + for i := 0; i < n; i++ { + xor[i] = actualReadArray[i] ^ expectedReadArray[i] + } + t.Fatalf("Expecting read byte array to be equal to written byte array, but they were different. xor: [%v]", xor) + } + + actualMessage := &pb.MetricResponse{} + err = proto.Unmarshal(actualReadArray, actualMessage) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if !reflect.DeepEqual(actualMessage, expectedMessage) { + t.Fatalf("Expecting payload to contain message [%s], but it had [%s]", expectedMessage, actualMessage) + } + }) + + t.Run("Can read byte streams larger than Go's default buffer chunk size", func(t *testing.T) { + goDefaultChunkSize := 4000 + expectedMessage := "Hit the road, Jack and don't you come back\n" + for i := 0; i < 450; i++ { + expectedMessage = expectedMessage + fmt.Sprintf("no more (%d), ", i) + } + + expectedMessageAsBytes := []byte(expectedMessage) + lengthOfInputData := len(expectedMessageAsBytes) + + if lengthOfInputData < goDefaultChunkSize { + t.Fatalf("Test needs data larger than [%d] bytes, currently only [%d] bytes", goDefaultChunkSize, lengthOfInputData) + } + + payload, err := serializeAsPayload(expectedMessageAsBytes) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + actualMessage, err := deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(payload))) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if string(actualMessage) != expectedMessage { + t.Fatalf("Expecting payload to contain message:\n%s\n, but it had\n%s", expectedMessageAsBytes, actualMessage) + } + }) + + t.Run("Returns error when message has fewer bytes than declared message size", func(t *testing.T) { + expectedMessage := "this is the message" + + messageWithSize, err := serializeAsPayload([]byte(expectedMessage)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + messageMissingOneCharacter := messageWithSize[:len(expectedMessage)-1] + _, err = deserializePayloadFromReader(bufio.NewReader(bytes.NewReader(messageMissingOneCharacter))) + if err == nil { + t.Fatalf("Expecting error, got nothing") + } + }) +} + +func TestNewStreamingWriter(t *testing.T) { + t.Run("Returns a streaming writer if the ResponseWriter is compatible with streaming", func(t *testing.T) { + rawWriter := newStubResponseWriter() + flushableWriter, err := newStreamingWriter(rawWriter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if flushableWriter != rawWriter { + t.Fatalf("Expected to return same instance of writer") + } + + header := "Connection" + expectedValue := "keep-alive" + actualValue := rawWriter.Header().Get(header) + if actualValue != expectedValue { + t.Fatalf("Expected header [%s] to be set to [%s], but was [%s]", header, expectedValue, actualValue) + } + + header = "Transfer-Encoding" + expectedValue = "chunked" + actualValue = rawWriter.Header().Get(header) + if actualValue != expectedValue { + t.Fatalf("Expected header [%s] to be set to [%s], but was [%s]", header, expectedValue, actualValue) + } + }) + + t.Run("Returns an error if writer doesnt support streaming", func(t *testing.T) { + _, err := newStreamingWriter(&nonStreamingResponseWriter{}) + if err == nil { + t.Fatalf("Expecting error, got nothing") + } + }) +} + +func TestCheckIfResponseHasError(t *testing.T) { + t.Run("returns nil if response doesnt contain Conduit error", func(t *testing.T) { + response := &http.Response{ + Header: make(http.Header), + } + err := checkIfResponseHasConduitError(response) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + }) + + t.Run("returns error in body if response contains Conduit error", func(t *testing.T) { + expectedErrorMessage := "expected error message" + protoInBytes, err := proto.Marshal(&pb.ApiError{Error: expectedErrorMessage}) + message, err := serializeAsPayload(protoInBytes) + response := &http.Response{ + Header: make(http.Header), + Body: ioutil.NopCloser(bytes.NewReader(message)), + } + response.Header.Set(errorHeader, "error") + + err = checkIfResponseHasConduitError(response) + if err == nil { + t.Fatalf("Expecting error, got nothing") + } + + actualErrorMessage := err.Error() + if actualErrorMessage != expectedErrorMessage { + t.Fatalf("Expected error message to be [%s], but it was [%s]", expectedErrorMessage, actualErrorMessage) + } + }) + + t.Run("returns error if response contains Conduit error but body isn't error message", func(t *testing.T) { + protoInBytes, err := proto.Marshal(&pb.MetricMetadata{Path: "a"}) + message, err := serializeAsPayload(protoInBytes) + response := &http.Response{ + Header: make(http.Header), + Body: ioutil.NopCloser(bytes.NewReader(message)), + } + response.Header.Set(errorHeader, "error") + + err = checkIfResponseHasConduitError(response) + if err == nil { + t.Fatalf("Expecting error, got nothing") + } + }) +} + +func assertResponseHasProtobufContentType(t *testing.T, responseWriter *stubResponseWriter) { + actualContentType := responseWriter.headers.Get(contentTypeHeader) + expectedContentType := protobufContentType + if actualContentType != expectedContentType { + t.Fatalf("Expected content-type to be [%s], but got [%s]", expectedContentType, actualContentType) + } +} diff --git a/controller/api/public/server.go b/controller/api/public/server.go deleted file mode 100644 index 4a36ce0579202..0000000000000 --- a/controller/api/public/server.go +++ /dev/null @@ -1,273 +0,0 @@ -package public - -import ( - "encoding/binary" - "fmt" - "io/ioutil" - "net/http" - - "github.com/golang/protobuf/jsonpb" - "github.com/golang/protobuf/proto" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promhttp" - common "github.com/runconduit/conduit/controller/gen/common" - healthcheckPb "github.com/runconduit/conduit/controller/gen/common/healthcheck" - tapPb "github.com/runconduit/conduit/controller/gen/controller/tap" - telemPb "github.com/runconduit/conduit/controller/gen/controller/telemetry" - pb "github.com/runconduit/conduit/controller/gen/public" - "golang.org/x/net/context" - "google.golang.org/grpc/metadata" -) - -type ( - handler struct { - grpcServer pb.ApiServer - } - - tapServer struct { - w http.ResponseWriter - req *http.Request - } -) - -var ( - jsonMarshaler = jsonpb.Marshaler{EmitDefaults: true} - jsonUnmarshaler = jsonpb.Unmarshaler{} - statPath = fullUrlPathFor("Stat") - versionPath = fullUrlPathFor("Version") - listPodsPath = fullUrlPathFor("ListPods") - tapPath = fullUrlPathFor("Tap") - selfCheckPath = fullUrlPathFor("SelfCheck") -) - -func NewServer(addr string, telemetryClient telemPb.TelemetryClient, tapClient tapPb.TapClient) *http.Server { - var baseHandler http.Handler - counter := prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "http_requests_total", - Help: "A counter for requests to the wrapped handler.", - }, - []string{"code", "method"}, - ) - prometheus.MustRegister(counter) - - baseHandler = &handler{ - grpcServer: newGrpcServer(telemetryClient, tapClient), - } - instrumentedHandler := promhttp.InstrumentHandlerCounter(counter, baseHandler) - - return &http.Server{ - Addr: addr, - Handler: instrumentedHandler, - } -} - -func (h *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - // Validate request method - if req.Method != http.MethodPost { - serverMarshalError(w, req, fmt.Errorf("POST required"), http.StatusMethodNotAllowed) - return - } - - // Validate request content type - switch req.Header.Get("Content-Type") { - case "", ProtobufContentType, JsonContentType: - default: - serverMarshalError(w, req, fmt.Errorf("unsupported Content-Type"), http.StatusUnsupportedMediaType) - return - } - - // Serve request - switch req.URL.Path { - case statPath: - h.handleStat(w, req) - case versionPath: - h.handleVersion(w, req) - case listPodsPath: - h.handleListPods(w, req) - case tapPath: - h.handleTap(w, req) - case selfCheckPath: - h.handleSelfCheck(w, req) - default: - http.NotFound(w, req) - } -} - -func (h *handler) handleStat(w http.ResponseWriter, req *http.Request) { - var metricRequest pb.MetricRequest - err := serverUnmarshal(req, &metricRequest) - if err != nil { - serverMarshalError(w, req, err, http.StatusBadRequest) - return - } - - rsp, err := h.grpcServer.Stat(req.Context(), &metricRequest) - if err != nil { - serverMarshalError(w, req, err, http.StatusInternalServerError) - return - } - - err = serverMarshal(w, req, rsp) - if err != nil { - serverMarshalError(w, req, err, http.StatusInternalServerError) - return - } -} - -func (h *handler) handleVersion(w http.ResponseWriter, req *http.Request) { - var emptyRequest pb.Empty - err := serverUnmarshal(req, &emptyRequest) - if err != nil { - serverMarshalError(w, req, err, http.StatusBadRequest) - return - } - - rsp, err := h.grpcServer.Version(req.Context(), &emptyRequest) - if err != nil { - serverMarshalError(w, req, err, http.StatusInternalServerError) - return - } - - err = serverMarshal(w, req, rsp) - if err != nil { - serverMarshalError(w, req, err, http.StatusInternalServerError) - return - } -} - -func (h *handler) handleSelfCheck(w http.ResponseWriter, req *http.Request) { - var selfCheckRequest healthcheckPb.SelfCheckRequest - err := serverUnmarshal(req, &selfCheckRequest) - if err != nil { - serverMarshalError(w, req, err, http.StatusBadRequest) - return - } - - rsp, err := h.grpcServer.SelfCheck(req.Context(), &selfCheckRequest) - if err != nil { - serverMarshalError(w, req, err, http.StatusInternalServerError) - return - } - - err = serverMarshal(w, req, rsp) - if err != nil { - serverMarshalError(w, req, err, http.StatusInternalServerError) - return - } -} - -func (h *handler) handleListPods(w http.ResponseWriter, req *http.Request) { - var emptyRequest pb.Empty - err := serverUnmarshal(req, &emptyRequest) - if err != nil { - serverMarshalError(w, req, err, http.StatusBadRequest) - return - } - - rsp, err := h.grpcServer.ListPods(req.Context(), &emptyRequest) - if err != nil { - serverMarshalError(w, req, err, http.StatusInternalServerError) - return - } - - err = serverMarshal(w, req, rsp) - if err != nil { - serverMarshalError(w, req, err, http.StatusInternalServerError) - return - } -} - -func (h *handler) handleTap(w http.ResponseWriter, req *http.Request) { - var tapRequest pb.TapRequest - err := serverUnmarshal(req, &tapRequest) - if err != nil { - serverMarshalError(w, req, err, http.StatusBadRequest) - return - } - - if _, ok := w.(http.Flusher); !ok { - serverMarshalError(w, req, fmt.Errorf("streaming not supported"), http.StatusBadRequest) - return - } - - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Transfer-Encoding", "chunked") - - server := tapServer{w: w, req: req} - err = h.grpcServer.Tap(&tapRequest, server) - if err != nil { - serverMarshalError(w, req, err, http.StatusInternalServerError) - return - } -} - -func serverUnmarshal(req *http.Request, msg proto.Message) error { - switch req.Header.Get("Content-Type") { - case "", ProtobufContentType: - bytes, err := ioutil.ReadAll(req.Body) - if err != nil { - return err - } - return proto.Unmarshal(bytes, msg) - case JsonContentType: - return jsonUnmarshaler.Unmarshal(req.Body, msg) - } - return nil -} - -func serverMarshal(w http.ResponseWriter, req *http.Request, msg proto.Message) error { - switch req.Header.Get("Content-Type") { - case "", ProtobufContentType: - bytes, err := proto.Marshal(msg) - if err != nil { - return err - } - byteSize := make([]byte, 4) - binary.LittleEndian.PutUint32(byteSize, uint32(len(bytes))) - _, err = w.Write(append(byteSize, bytes...)) - return err - - case JsonContentType: - str, err := jsonMarshaler.MarshalToString(msg) - if err != nil { - return err - } - _, err = w.Write(append([]byte(str), '\n')) - return err - } - - return nil -} - -func serverMarshalError(w http.ResponseWriter, req *http.Request, err error, code int) error { - switch req.Header.Get("Content-Type") { - case "", ProtobufContentType: - w.Header().Set(ErrorHeader, http.StatusText(code)) - case JsonContentType: - w.WriteHeader(code) - } - - return serverMarshal(w, req, &pb.ApiError{Error: err.Error()}) -} - -func (s tapServer) Send(msg *common.TapEvent) error { - err := serverMarshal(s.w, s.req, msg) - if err != nil { - return err - } - s.w.(http.Flusher).Flush() - return nil -} - -// satisfy the pb.Api_TapServer interface -func (s tapServer) SetHeader(metadata.MD) error { return nil } -func (s tapServer) SendHeader(metadata.MD) error { return nil } -func (s tapServer) SetTrailer(metadata.MD) { return } -func (s tapServer) Context() context.Context { return s.req.Context() } -func (s tapServer) SendMsg(interface{}) error { return nil } -func (s tapServer) RecvMsg(interface{}) error { return nil } - -func fullUrlPathFor(method string) string { - return ApiRoot + ApiPrefix + method -} diff --git a/controller/k8s/pods.go b/controller/k8s/pods.go index c0ca2d19eec7e..3dcef926d2caa 100644 --- a/controller/k8s/pods.go +++ b/controller/k8s/pods.go @@ -59,7 +59,7 @@ func (p *PodIndex) GetPod(key string) (*v1.Pod, error) { return nil, err } if !exists { - return nil, fmt.Errorf("No pod exists for key %s", key) + return nil, fmt.Errorf("no pod exists for key %s", key) } pod, ok := item.(*v1.Pod) if !ok { diff --git a/controller/k8s/replicasets.go b/controller/k8s/replicasets.go index a6cf97c4aa302..e2417aec0c375 100644 --- a/controller/k8s/replicasets.go +++ b/controller/k8s/replicasets.go @@ -61,7 +61,7 @@ func (p *ReplicaSetStore) GetReplicaSet(key string) (*v1beta1.ReplicaSet, error) return nil, err } if !exists { - return nil, fmt.Errorf("No ReplicaSet exists for name %s", key) + return nil, fmt.Errorf("no ReplicaSet exists for name %s", key) } rs, ok := item.(*v1beta1.ReplicaSet) if !ok { diff --git a/controller/tap/server.go b/controller/tap/server.go index e28b0114e850a..fe6f340b00948 100644 --- a/controller/tap/server.go +++ b/controller/tap/server.go @@ -16,6 +16,8 @@ import ( "github.com/runconduit/conduit/controller/util" log "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "k8s.io/api/core/v1" ) @@ -41,7 +43,7 @@ func (s *server) Tap(req *public.TapRequest, stream pb.Tap_TapServer) error { targetName = target.Pod pod, err := s.pods.GetPod(target.Pod) if err != nil { - return err + return status.Errorf(codes.NotFound, err.Error()) } pods = []*v1.Pod{pod} case *public.TapRequest_Deployment: @@ -334,7 +336,7 @@ func NewServer(addr string, tapPort uint, kubeconfig string) (*grpc.Server, net. deploymentIndex := func(obj interface{}) ([]string, error) { pod, ok := obj.(*v1.Pod) if !ok { - return nil, fmt.Errorf("Object is not a Pod") + return nil, fmt.Errorf("object is not a Pod") } deployment, err := replicaSets.GetDeploymentForPod(pod) if err != nil {