diff --git a/httpgrpc/server/server.go b/httpgrpc/server/server.go index 0ccad0a6..8734afbf 100644 --- a/httpgrpc/server/server.go +++ b/httpgrpc/server/server.go @@ -14,6 +14,7 @@ import ( "github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc" "github.com/mwitkow/go-grpc-middleware" "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/ext" "github.com/sercand/kuberesolver" "golang.org/x/net/context" "google.golang.org/grpc" @@ -41,8 +42,14 @@ func (s Server) Handle(ctx context.Context, r *httpgrpc.HTTPRequest) (*httpgrpc. if err != nil { return nil, err } - req = req.WithContext(ctx) toHeader(r.Headers, req.Header) + if tracer := opentracing.GlobalTracer(); tracer != nil { + if clientContext, err := tracer.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header)); err == nil { + span := tracer.StartSpan("httpgrpc", ext.RPCServerOption(clientContext)) + ctx = opentracing.ContextWithSpan(ctx, span) + } + } + req = req.WithContext(ctx) req.RequestURI = r.Url recorder := httptest.NewRecorder() s.handler.ServeHTTP(recorder, req) @@ -138,6 +145,11 @@ func (c *Client) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } + if tracer := opentracing.GlobalTracer(); tracer != nil { + if span := opentracing.SpanFromContext(r.Context()); span != nil { + _ = tracer.Inject(span.Context(), opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(r.Header)) + } + } req := &httpgrpc.HTTPRequest{ Method: r.Method, Url: r.RequestURI, diff --git a/httpgrpc/server/server_test.go b/httpgrpc/server/server_test.go index 4cf1c867..e1fe4361 100644 --- a/httpgrpc/server/server_test.go +++ b/httpgrpc/server/server_test.go @@ -10,8 +10,10 @@ import ( "reflect" "testing" + opentracing "github.com/opentracing/opentracing-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + jaegercfg "github.com/uber/jaeger-client-go/config" "github.com/weaveworks/common/httpgrpc" "github.com/weaveworks/common/user" "google.golang.org/grpc" @@ -101,3 +103,34 @@ func TestParseURL(t *testing.T) { assert.Equal(t, tc.expected, got) } } + +func TestTracePropagation(t *testing.T) { + jaeger := jaegercfg.Configuration{} + closer, err := jaeger.InitGlobalTracer("test") + defer closer.Close() + require.NoError(t, err) + + server, err := newTestServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + span := opentracing.SpanFromContext(r.Context()) + fmt.Fprint(w, span.BaggageItem("name")) + })) + + require.NoError(t, err) + defer server.grpcServer.GracefulStop() + + client, err := NewClient(server.URL) + require.NoError(t, err) + + req, err := http.NewRequest("GET", "/hello", &bytes.Buffer{}) + require.NoError(t, err) + + sp, ctx := opentracing.StartSpanFromContext(context.Background(), "Test") + sp.SetBaggageItem("name", "world") + + req = req.WithContext(user.InjectOrgID(ctx, "1")) + recorder := httptest.NewRecorder() + client.ServeHTTP(recorder, req) + + assert.Equal(t, "world", string(recorder.Body.Bytes())) + assert.Equal(t, 200, recorder.Code) +}