From 1d82addb7c888c97a675275be4b6d23470faccfe Mon Sep 17 00:00:00 2001 From: Dale Peakall Date: Mon, 19 Feb 2024 15:49:28 +0000 Subject: [PATCH] refactor: add missing test cases * Validate that the trace middleware adds the expected traces * Validate that certificate hashes with both '/' and '+' characters are converted correctly --- .talismanrc | 2 ++ gateway/registry/remote_test.go | 42 ++++++++++++++++++++++++ gateway/server/middleware.go | 11 ++++--- gateway/server/middleware_test.go | 42 ++++++++++++++++++++++++ gateway/server/tracing.go | 54 +++++++++++++++++++++++++++++++ 5 files changed, 146 insertions(+), 5 deletions(-) create mode 100644 gateway/server/tracing.go diff --git a/.talismanrc b/.talismanrc index df0122e..ad26bf2 100644 --- a/.talismanrc +++ b/.talismanrc @@ -17,6 +17,8 @@ fileignoreconfig: checksum: bcdef78dfc66e140acd32d12df1806b95c3541f51cc0208a43abff49552fdcd8 - filename: gateway/registry/remote_test.go checksum: f5aa4dbb5e14d772613612eeb02df83ae4458875487e3c408eff2950b460c298 + - filename: gateway/server/tracing.go + checksum: 10c205849723d591f5c90fbf0068fa5cf77b8e545e351bc538610b950bc18c3f - filename: gateway/server/ws.go checksum: 4cbde936242380603e07cf8bd049dbca9d1c3108843d10e58f588540176c6d23 - filename: gateway/server/ws_test.go diff --git a/gateway/registry/remote_test.go b/gateway/registry/remote_test.go index 880cd12..1d2d72f 100644 --- a/gateway/registry/remote_test.go +++ b/gateway/registry/remote_test.go @@ -66,6 +66,48 @@ func TestLookupCertificate(t *testing.T) { assert.Equal(t, want.Raw, got.Raw) } +func TestLookupCertificateWithSlashesAndPlusesInHash(t *testing.T) { + var want *x509.Certificate + var certHash [32]byte + var b64CertHash string + + count := 0 + for { + count++ + want = generateCertificate(t) + certHash = sha256.Sum256(want.Raw) + b64CertHash = base64.StdEncoding.EncodeToString(certHash[:]) + + if strings.Contains(b64CertHash, "/") && strings.Contains(b64CertHash, "+") { + break + } + } + t.Logf("Generated %d certificates before finding one with slashes and pluses in the hash", count) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hashToMatch := strings.Replace(b64CertHash, "/", "_", -1) + hashToMatch = strings.Replace(hashToMatch, "+", "-", -1) + if r.URL.Path != fmt.Sprintf("/api/v0/certificate/%s", hashToMatch) { + http.NotFound(w, r) + return + } + block := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: want.Raw}) + blockWithNewlinesReplaced := strings.Replace(string(block), "\n", "\\n", -1) + _, _ = w.Write([]byte(fmt.Sprintf(`{"certificate":"%s"}`, blockWithNewlinesReplaced))) + })) + defer server.Close() + + reg := registry.RemoteRegistry{ + ManagerApiAddr: server.URL, + } + + got, err := reg.LookupCertificate(b64CertHash) + require.NoError(t, err) + require.NotNil(t, got) + + assert.Equal(t, want.Raw, got.Raw) +} + func generateCertificate(t *testing.T) *x509.Certificate { keyPair, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) diff --git a/gateway/server/middleware.go b/gateway/server/middleware.go index 76bc04d..d114d9f 100644 --- a/gateway/server/middleware.go +++ b/gateway/server/middleware.go @@ -20,10 +20,8 @@ import ( func TraceRequest(tracer trace.Tracer) func(http.Handler) http.Handler { return func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - slog.Info("websocket connection received", "path", r.URL.Path, "method", r.Method) - slog.Info("processing connection", "uri", r.RequestURI) - - newCtx, span := tracer.Start(r.Context(), fmt.Sprintf("%s %s", r.Method, r.URL.String()), trace.WithSpanKind(trace.SpanKindServer), + newCtx, span := tracer.Start(r.Context(), + fmt.Sprintf("%s %s", r.Method, r.URL.String()), trace.WithSpanKind(trace.SpanKindServer), trace.WithAttributes( semconv.HTTPScheme(getScheme(r)), semconv.HTTPMethod(r.Method), @@ -35,11 +33,11 @@ func TraceRequest(tracer trace.Tracer) func(http.Handler) http.Handler { routePattern := chi.RouteContext(r.Context()).RoutePattern() if routePattern != "" { span.SetName(fmt.Sprintf("%s %s", r.Method, routePattern)) + span.SetAttributes(semconv.HTTPRoute(chi.RouteContext(r.Context()).RoutePattern())) } else { span.SetStatus(codes.Error, "not found") span.SetAttributes(semconv.HTTPStatusCode(http.StatusNotFound)) } - span.SetAttributes(semconv.HTTPRoute(chi.RouteContext(r.Context()).RoutePattern())) }) } } @@ -77,6 +75,9 @@ func TLSOffload(registry registry.DeviceRegistry) func(http.Handler) http.Handle span.SetAttributes(attribute.String("cert.lookup.error", "NotFound")) slog.Warn("certificate not found", "clientCertHashHeader", clientCertHashHeader) } + } else { + clientCertErrHeader := r.Header.Get("X-Client-Cert-Error") + span.SetAttributes(attribute.String("cert.valid.error", clientCertErrHeader)) } } } diff --git a/gateway/server/middleware_test.go b/gateway/server/middleware_test.go index 67f4c9b..bb96322 100644 --- a/gateway/server/middleware_test.go +++ b/gateway/server/middleware_test.go @@ -12,6 +12,48 @@ import ( "testing" ) +func TestTraceMatchedRequest(t *testing.T) { + tracer, traceExporter := server.GetTracer() + + r := chi.NewRouter() + r.Use(server.TraceRequest(tracer)) + r.HandleFunc("/id/{id}", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/id/1234", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + server.AssertSpan(t, &traceExporter.GetSpans()[0], "GET /id/{id}", map[string]any{ + "http.scheme": "ws", + "http.method": "GET", + "http.url": "/id/1234", + "http.route": "/id/{id}", + }) +} + +func TestTraceUnmatchedRequest(t *testing.T) { + tracer, traceExporter := server.GetTracer() + + r := chi.NewRouter() + r.Use(server.TraceRequest(tracer)) + r.HandleFunc("/something", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/other", nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + server.AssertSpan(t, &traceExporter.GetSpans()[0], "GET /other", map[string]any{ + "http.scheme": "ws", + "http.method": "GET", + "http.url": "/other", + "http.status_code": http.StatusNotFound, + }) +} + func TestTLSOffloadWithNoClientCert(t *testing.T) { r := chi.NewRouter() r.Use(server.TLSOffload(registry.NewMockRegistry())) diff --git a/gateway/server/tracing.go b/gateway/server/tracing.go new file mode 100644 index 0000000..26ac44c --- /dev/null +++ b/gateway/server/tracing.go @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 + +package server + +import ( + "github.com/stretchr/testify/assert" + "go.opentelemetry.io/otel" + tracesdk "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + "go.opentelemetry.io/otel/trace" + "golang.org/x/exp/maps" + "sort" + "testing" +) + +func GetTracer() (trace.Tracer, *tracetest.InMemoryExporter) { + traceExporter := tracetest.NewInMemoryExporter() + tracerProvider := tracesdk.NewTracerProvider( + tracesdk.WithSampler(tracesdk.AlwaysSample()), + tracesdk.WithSyncer(traceExporter), + ) + otel.SetTracerProvider(tracerProvider) + + return tracerProvider.Tracer("test"), traceExporter +} + +func AssertSpan(t *testing.T, span *tracetest.SpanStub, name string, attributes map[string]any) { + assert.Equal(t, name, span.Name) + assert.Len(t, span.Attributes, len(attributes)) + var gotKeys []string + for _, attr := range span.Attributes { + gotKeys = append(gotKeys, string(attr.Key)) + want, ok := attributes[string(attr.Key)] + if !ok { + t.Errorf("unexpected attribute %s", attr.Key) + } + switch want.(type) { + case string: + assert.Equal(t, want, attr.Value.AsString()) + case int: + assert.Equal(t, want, int(attr.Value.AsInt64())) + case bool: + assert.Equal(t, want, attr.Value.AsBool()) + case float64: + assert.Equal(t, want, attr.Value.AsFloat64()) + default: + t.Errorf("unsupported attribute type %T", want) + } + } + sort.Strings(gotKeys) + wantKeys := maps.Keys(attributes) + sort.Strings(wantKeys) + assert.Equal(t, wantKeys, gotKeys) +}