diff --git a/instrumentation/net/http/otelhttp/handler.go b/instrumentation/net/http/otelhttp/handler.go index 8ef06eddc2c..f3871a8e071 100644 --- a/instrumentation/net/http/otelhttp/handler.go +++ b/instrumentation/net/http/otelhttp/handler.go @@ -137,8 +137,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { span.AddEvent("read", trace.WithAttributes(ReadBytesKey.Int64(n))) } } - bw := bodyWrapper{ReadCloser: r.Body, record: readRecordFunc} - r.Body = &bw + + var bw bodyWrapper + // if request body is nil we don't want to mutate the body as it will affect + // the identity of it in a unforeseeable way because we assert ReadCloser + // fullfills a certain interface and it is indeed nil. + if r.Body != nil { + bw.ReadCloser = r.Body + bw.record = readRecordFunc + r.Body = &bw + } writeRecordFunc := func(int64) {} if h.writeEvent { @@ -172,10 +180,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { setAfterServeAttributes(span, bw.read, rww.written, rww.statusCode, bw.err, rww.err) - // Add request metrics - + // Add metrics labels := append(labeler.Get(), semconv.HTTPServerMetricAttributesFromHTTPRequest(h.operation, r)...) - h.counters[RequestContentLength].Add(ctx, bw.read, labels...) h.counters[ResponseContentLength].Add(ctx, rww.written, labels...) diff --git a/instrumentation/net/http/otelhttp/handler_test.go b/instrumentation/net/http/otelhttp/handler_test.go index e35dfb7dcde..f1c3bae80d4 100644 --- a/instrumentation/net/http/otelhttp/handler_test.go +++ b/instrumentation/net/http/otelhttp/handler_test.go @@ -176,3 +176,29 @@ func TestResponseWriterOptionalInterfaces(t *testing.T) { t.Fatal("http.Flusher interface not exposed") } } + +// This use case is important as we make sure the body isn't mutated +// when it is nil. This is a common use case for tests where the request +// is directly passed to the handler. +func TestHandlerReadingNilBodySuccess(t *testing.T) { + rr := httptest.NewRecorder() + + provider := oteltest.NewTracerProvider() + + h := NewHandler( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Body != nil { + _, err := ioutil.ReadAll(r.Body) + assert.NotNil(t, err) + } + }), "test_handler", + WithTracerProvider(provider), + ) + + r, err := http.NewRequest(http.MethodGet, "http://localhost/", nil) + if err != nil { + t.Fatal(err) + } + h.ServeHTTP(rr, r) + assert.Equal(t, 200, rr.Result().StatusCode) +}