Skip to content

Commit

Permalink
refactor: add missing test cases
Browse files Browse the repository at this point in the history
* Validate that the trace middleware adds the expected traces
* Validate that certificate hashes with both '/' and '+'
  characters are converted correctly
  • Loading branch information
subnova committed Feb 19, 2024
1 parent 1be65fb commit 1d82add
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .talismanrc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ fileignoreconfig:
checksum: bcdef78dfc66e140acd32d12df1806b95c3541f51cc0208a43abff49552fdcd8
- filename: gateway/registry/remote_test.go
checksum: f5aa4dbb5e14d772613612eeb02df83ae4458875487e3c408eff2950b460c298
- filename: gateway/server/tracing.go
checksum: 10c205849723d591f5c90fbf0068fa5cf77b8e545e351bc538610b950bc18c3f
- filename: gateway/server/ws.go
checksum: 4cbde936242380603e07cf8bd049dbca9d1c3108843d10e58f588540176c6d23
- filename: gateway/server/ws_test.go
Expand Down
42 changes: 42 additions & 0 deletions gateway/registry/remote_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,48 @@ func TestLookupCertificate(t *testing.T) {
assert.Equal(t, want.Raw, got.Raw)
}

func TestLookupCertificateWithSlashesAndPlusesInHash(t *testing.T) {
var want *x509.Certificate
var certHash [32]byte
var b64CertHash string

count := 0
for {
count++
want = generateCertificate(t)
certHash = sha256.Sum256(want.Raw)
b64CertHash = base64.StdEncoding.EncodeToString(certHash[:])

if strings.Contains(b64CertHash, "/") && strings.Contains(b64CertHash, "+") {
break
}
}
t.Logf("Generated %d certificates before finding one with slashes and pluses in the hash", count)

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hashToMatch := strings.Replace(b64CertHash, "/", "_", -1)
hashToMatch = strings.Replace(hashToMatch, "+", "-", -1)
if r.URL.Path != fmt.Sprintf("/api/v0/certificate/%s", hashToMatch) {
http.NotFound(w, r)
return
}
block := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: want.Raw})
blockWithNewlinesReplaced := strings.Replace(string(block), "\n", "\\n", -1)
_, _ = w.Write([]byte(fmt.Sprintf(`{"certificate":"%s"}`, blockWithNewlinesReplaced)))
}))
defer server.Close()

reg := registry.RemoteRegistry{
ManagerApiAddr: server.URL,
}

got, err := reg.LookupCertificate(b64CertHash)
require.NoError(t, err)
require.NotNil(t, got)

assert.Equal(t, want.Raw, got.Raw)
}

func generateCertificate(t *testing.T) *x509.Certificate {
keyPair, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
Expand Down
11 changes: 6 additions & 5 deletions gateway/server/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@ import (
func TraceRequest(tracer trace.Tracer) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
slog.Info("websocket connection received", "path", r.URL.Path, "method", r.Method)
slog.Info("processing connection", "uri", r.RequestURI)

newCtx, span := tracer.Start(r.Context(), fmt.Sprintf("%s %s", r.Method, r.URL.String()), trace.WithSpanKind(trace.SpanKindServer),
newCtx, span := tracer.Start(r.Context(),
fmt.Sprintf("%s %s", r.Method, r.URL.String()), trace.WithSpanKind(trace.SpanKindServer),
trace.WithAttributes(
semconv.HTTPScheme(getScheme(r)),
semconv.HTTPMethod(r.Method),
Expand All @@ -35,11 +33,11 @@ func TraceRequest(tracer trace.Tracer) func(http.Handler) http.Handler {
routePattern := chi.RouteContext(r.Context()).RoutePattern()
if routePattern != "" {
span.SetName(fmt.Sprintf("%s %s", r.Method, routePattern))
span.SetAttributes(semconv.HTTPRoute(chi.RouteContext(r.Context()).RoutePattern()))
} else {
span.SetStatus(codes.Error, "not found")
span.SetAttributes(semconv.HTTPStatusCode(http.StatusNotFound))
}
span.SetAttributes(semconv.HTTPRoute(chi.RouteContext(r.Context()).RoutePattern()))
})
}
}
Expand Down Expand Up @@ -77,6 +75,9 @@ func TLSOffload(registry registry.DeviceRegistry) func(http.Handler) http.Handle
span.SetAttributes(attribute.String("cert.lookup.error", "NotFound"))
slog.Warn("certificate not found", "clientCertHashHeader", clientCertHashHeader)
}
} else {
clientCertErrHeader := r.Header.Get("X-Client-Cert-Error")
span.SetAttributes(attribute.String("cert.valid.error", clientCertErrHeader))
}
}
}
Expand Down
42 changes: 42 additions & 0 deletions gateway/server/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,48 @@ import (
"testing"
)

func TestTraceMatchedRequest(t *testing.T) {
tracer, traceExporter := server.GetTracer()

r := chi.NewRouter()
r.Use(server.TraceRequest(tracer))
r.HandleFunc("/id/{id}", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

req := httptest.NewRequest(http.MethodGet, "/id/1234", nil)
rr := httptest.NewRecorder()
r.ServeHTTP(rr, req)

server.AssertSpan(t, &traceExporter.GetSpans()[0], "GET /id/{id}", map[string]any{
"http.scheme": "ws",
"http.method": "GET",
"http.url": "/id/1234",
"http.route": "/id/{id}",
})
}

func TestTraceUnmatchedRequest(t *testing.T) {
tracer, traceExporter := server.GetTracer()

r := chi.NewRouter()
r.Use(server.TraceRequest(tracer))
r.HandleFunc("/something", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

req := httptest.NewRequest(http.MethodGet, "/other", nil)
rr := httptest.NewRecorder()
r.ServeHTTP(rr, req)

server.AssertSpan(t, &traceExporter.GetSpans()[0], "GET /other", map[string]any{
"http.scheme": "ws",
"http.method": "GET",
"http.url": "/other",
"http.status_code": http.StatusNotFound,
})
}

func TestTLSOffloadWithNoClientCert(t *testing.T) {
r := chi.NewRouter()
r.Use(server.TLSOffload(registry.NewMockRegistry()))
Expand Down
54 changes: 54 additions & 0 deletions gateway/server/tracing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// SPDX-License-Identifier: Apache-2.0

package server

import (
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel"
tracesdk "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/sdk/trace/tracetest"
"go.opentelemetry.io/otel/trace"
"golang.org/x/exp/maps"
"sort"
"testing"
)

func GetTracer() (trace.Tracer, *tracetest.InMemoryExporter) {
traceExporter := tracetest.NewInMemoryExporter()
tracerProvider := tracesdk.NewTracerProvider(
tracesdk.WithSampler(tracesdk.AlwaysSample()),
tracesdk.WithSyncer(traceExporter),
)
otel.SetTracerProvider(tracerProvider)

return tracerProvider.Tracer("test"), traceExporter
}

func AssertSpan(t *testing.T, span *tracetest.SpanStub, name string, attributes map[string]any) {
assert.Equal(t, name, span.Name)
assert.Len(t, span.Attributes, len(attributes))
var gotKeys []string
for _, attr := range span.Attributes {
gotKeys = append(gotKeys, string(attr.Key))
want, ok := attributes[string(attr.Key)]
if !ok {
t.Errorf("unexpected attribute %s", attr.Key)
}
switch want.(type) {
case string:
assert.Equal(t, want, attr.Value.AsString())
case int:
assert.Equal(t, want, int(attr.Value.AsInt64()))
case bool:
assert.Equal(t, want, attr.Value.AsBool())
case float64:
assert.Equal(t, want, attr.Value.AsFloat64())
default:
t.Errorf("unsupported attribute type %T", want)
}
}
sort.Strings(gotKeys)
wantKeys := maps.Keys(attributes)
sort.Strings(wantKeys)
assert.Equal(t, wantKeys, gotKeys)
}

0 comments on commit 1d82add

Please sign in to comment.