diff --git a/http/server/response_writer.go b/http/server/response_writer.go index f11f2fa..99c6aa7 100644 --- a/http/server/response_writer.go +++ b/http/server/response_writer.go @@ -8,11 +8,12 @@ import ( ) type TrackingResponseWriter struct { - track *tracking - recordHeaders bool - rw http.ResponseWriter - flusher http.Flusher - hijacker http.Hijacker + track *tracking + recordHeaders bool + rw http.ResponseWriter + flusher http.Flusher + hijacker http.Hijacker + hijackCallback func() } func (w *TrackingResponseWriter) gatherHeaders() { @@ -51,10 +52,18 @@ func (w *TrackingResponseWriter) WriteHeader(statusCode int) { } func (w *TrackingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + var c net.Conn + var rw *bufio.ReadWriter if w.hijacker != nil { - return w.hijacker.Hijack() + c, rw, w.track.hijackedErr = w.hijacker.Hijack() + } else { + w.track.hijackedErr = fmt.Errorf("not implements Hijacker interface") } - return nil, nil, fmt.Errorf("not implemented") + w.track.isHijacked = true + if w.hijackCallback != nil { + w.hijackCallback() + } + return c, rw, w.track.hijackedErr } func (w *TrackingResponseWriter) Flush() { @@ -63,12 +72,14 @@ func (w *TrackingResponseWriter) Flush() { } } -func newTrackingResponseWriter(rw http.ResponseWriter, t *tracking, recordHeaders bool) *TrackingResponseWriter { +func newTrackingResponseWriter(rw http.ResponseWriter, t *tracking, recordHeaders bool, + hijackCallback func()) *TrackingResponseWriter { return &TrackingResponseWriter{ - track: t, - recordHeaders: recordHeaders, - rw: rw, - flusher: rw.(http.Flusher), - hijacker: rw.(http.Hijacker), + track: t, + recordHeaders: recordHeaders, + rw: rw, + flusher: rw.(http.Flusher), + hijacker: rw.(http.Hijacker), + hijackCallback: hijackCallback, } } diff --git a/http/server/server.go b/http/server/server.go index 4e4d1c5..6ca2b1d 100644 --- a/http/server/server.go +++ b/http/server/server.go @@ -34,7 +34,11 @@ func (h *trackingHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { r = r.WithContext(t.ctx) if h.metrics != nil || h.traces != nil { - rw = newTrackingResponseWriter(rw, t, h.reportHeaders) + rw = newTrackingResponseWriter(rw, t, h.reportHeaders, func() { + t.Finish() + h.traces.end(t) + h.metrics.report(t, r) + }) } t.Start() diff --git a/http/server/traces.go b/http/server/traces.go index e0fc5b2..523e7d5 100644 --- a/http/server/traces.go +++ b/http/server/traces.go @@ -60,6 +60,13 @@ func (t *tracesHTTP) end(tr *tracking) { return } + if tr.isHijacked { + tr.span.SetAttributes(attribute.Bool("http.connection.hijacked", true)) + if tr.hijackedErr != nil { + tr.span.SetAttributes(attribute.String("http.connection.error", tr.hijackedErr.Error())) + } + } + tr.span.SetAttributes( semconv.HTTPRoute(tr.EndpointPattern()), semconv.HTTPResponseStatusCode(tr.responseStatus), diff --git a/http/server/tracking.go b/http/server/tracking.go index c1325b4..429ae77 100644 --- a/http/server/tracking.go +++ b/http/server/tracking.go @@ -27,10 +27,15 @@ type tracking struct { responseHeaders map[string][]string writeErrs []error endpointPattern string + isHijacked bool + hijackedErr error } func (t *tracking) EndpointPattern() string { if len(t.endpointPattern) == 0 { + if t.isHijacked { + return "Upgraded Connection" + } return "404 Not Found" } return t.endpointPattern