diff --git a/client.go b/client.go index 455c205..1354ba6 100644 --- a/client.go +++ b/client.go @@ -232,7 +232,7 @@ func errDeadlineOrCancel(err error) bool { func doSpan(ctx context.Context, req *http.Request) string { trace := newTraceFromCtx(ctx) span := trace.span() - if trace.received || log.IsLevelEnabled(log.TraceLevel) { + if trace.received || isTraced { span.setHeader(req.Header) } return span.string() diff --git a/ctx.go b/ctx.go index 3c71c61..40f6a34 100644 --- a/ctx.go +++ b/ctx.go @@ -33,7 +33,7 @@ type Lambda struct { } func newLambda(w http.ResponseWriter, r *http.Request, vars map[string]string) *Lambda { - return &Lambda{w: w, r: r, trace: newTrace(r), vars: vars} + return &Lambda{w: w, r: r, trace: newTraceFromHeader(r), vars: vars} } // NewRequestCtx adds request related data to r.Context(). @@ -137,3 +137,9 @@ func (l *Lambda) ResponseHeaderAddAs(header, value string) { h := l.w.Header() h[header] = append(h[header], value) } + +// TraceID returns trace ID of Lambda context. +// That trace ID is either received in request or generated when Lambda context is created. +func (l *Lambda) TraceID() string { + return l.trace.traceID() +} diff --git a/trace.go b/trace.go index 9b38a0a..26fbe0d 100644 --- a/trace.go +++ b/trace.go @@ -13,41 +13,48 @@ import ( log "github.com/sirupsen/logrus" ) +var isTraced bool = true + +// SetTrace can enable/disable tracing in restful. By default tracing is enabled +func SetTrace(b bool) { + isTraced = b +} + type trace struct { parent *traceParent b3 *traceB3 received bool } -// newTrace creates new trace object. Never returns nil. -func newTrace(r *http.Request) *trace { +// newTraceFromHeader creates new trace object. If no trace data, then create random. Never returns nil. +func newTraceFromHeader(r *http.Request) *trace { t := trace{parent: newTraceParent(r), b3: newTraceB3(r)} t.received = t.valid() - if !t.received { // Create fake one. Saved to r (ctx.r), so that any client to be able to find it. Note: Logger may have created one already. - t.parent = newTraceParentFromFake(r) - if !t.valid() { - return newTraceRandom() - } + if !t.received { + return newTraceRandom() } return &t } func newTraceRandom() *trace { - if log.IsLevelEnabled(log.TraceLevel) { - traceID := randStr32() - return &trace{parent: newTraceParentWithID(traceID), b3: newTraceB3WithID(traceID, true)} - } - return &trace{b3: &traceB3{spanID: randStr16()}} + debug := log.IsLevelEnabled(log.TraceLevel) + traceID := randStr32() + return &trace{b3: newTraceB3WithID(traceID, debug)} } // newTraceFromCtx creates new trace object, preferably from context. Never returns nil. func newTraceFromCtx(ctx context.Context) *trace { l := L(ctx) - if l == nil || l.trace == nil { + if l == nil { return newTraceRandom() } + + if !l.trace.valid() { + l.trace = newTraceRandom() // Updates trace in ctx via l.trace pointer. + } + return l.trace } diff --git a/trace_test.go b/trace_test.go index a889d13..aab7ef1 100644 --- a/trace_test.go +++ b/trace_test.go @@ -18,7 +18,7 @@ func TestRecvdParent(t *testing.T) { assert := assert.New(t) r, _ := http.NewRequest("POST", "", nil) r.Header.Set("traceparent", "00-0af7651916cd43dd8448eb211c80319c-b9c7c989f97918e1-01") - trace := newTrace(r) + trace := newTraceFromHeader(r) assert.True(trace.received) assert.Equal("00-0af7651916cd43dd8448eb211c80319c-b9c7c989f97918e1-01", trace.string()) span := trace.span().string() @@ -26,23 +26,11 @@ func TestRecvdParent(t *testing.T) { assert.Equal(55, len(span)) } -func TestFakeParent(t *testing.T) { - assert := assert.New(t) - r, _ := http.NewRequest("POST", "", nil) - r.Header.Set("x-fake-traceparent", "00-0af7651916cd43dd8448eb211c80319c-b9c7c989f97918e1-01") - trace := newTrace(r) - assert.False(trace.received) - assert.Equal("00-0af7651916cd43dd8448eb211c80319c-b9c7c989f97918e1-01", trace.string()) - span := trace.span().string() - assert.Contains(span, "00-0af7651916cd43dd8448eb211c80319c-") - assert.Equal(55, len(span)) -} - func TestRecvdBadParent(t *testing.T) { assert := assert.New(t) r, _ := http.NewRequest("POST", "", nil) r.Header.Set("traceparent", "FF-0af7651916cd43dd8448eb211c80319c-b9c7c989f97918e1-01") - trace := newTrace(r) + trace := newTraceFromHeader(r) assert.False(trace.received) assert.NotContains(trace.string(), "-0af7651916cd43dd8448eb211c80319c-") } @@ -63,7 +51,7 @@ func TestB3SingleLine(t *testing.T) { r, _ := http.NewRequest("POST", "", nil) traceStr := "0af7651916cd43dd8448eb211c80319c-b9c7c989f97918e1-1-deadbeef87654321" r.Header.Set("b3", traceStr) - trace := newTrace(r) + trace := newTraceFromHeader(r) assert.True(trace.received) assert.Contains(trace.string(), "0af7651916cd43dd8448eb211c80319c") headers := http.Header{} @@ -73,11 +61,10 @@ func TestB3SingleLine(t *testing.T) { func TestTracePropagation(t *testing.T) { assert := assert.New(t) - log.SetLevel(log.TraceLevel) // That switches on trace generation and propagation // Server srvURL := "" - traceid := "" + traceID := "" prevSpanID := "" parents := make(map[string]bool) depth := 0 @@ -87,17 +74,16 @@ func TestTracePropagation(t *testing.T) { t := newTraceFromCtx(ctx) assert.True(t.received) if depth == 0 { - traceid = t.traceID() + traceID = t.traceID() prevSpanID = t.spanID() parents[t.string()] = true t.b3.sampled = "1" t.b3.requestID = "req" t.b3.spanCtx = "ctx" } else { - assert.Equal(traceid, t.parent.traceID()) + assert.Equal(traceID, t.traceID()) + assert.Equal(traceID, L(ctx).TraceID()) assert.NotContains(parents, t.string()) - assert.Equal(t.parent.traceID(), t.b3.traceID) - assert.Equal(t.parent.spanID(), t.b3.spanID) assert.Equal(prevSpanID, t.b3.parentSpanID) assert.Equal("1", t.b3.sampled) assert.Equal("req", t.b3.requestID) diff --git a/traceb3.go b/traceb3.go index 078c7af..3392aba 100644 --- a/traceb3.go +++ b/traceb3.go @@ -89,9 +89,9 @@ func newTraceB3(r *http.Request) *traceB3 { return b3 } -func newTraceB3WithID(traceID string, trace bool) *traceB3 { +func newTraceB3WithID(traceID string, debug bool) *traceB3 { b3 := traceB3{traceID: traceID, singleLine: true} - if trace { + if debug { b3.sampled = "d" } return &b3 diff --git a/traceparent.go b/traceparent.go index b4e5968..8a008f7 100644 --- a/traceparent.go +++ b/traceparent.go @@ -5,7 +5,6 @@ package restful import ( - "fmt" "net/http" "strings" ) @@ -15,9 +14,8 @@ See https://www.w3.org/TR/trace-context */ const ( - headerTraceParent = "traceparent" - headerTraceState = "tracestate" - headerFakeTraceParent = "x-fake-traceparent" + headerTraceParent = "traceparent" + headerTraceState = "tracestate" ) type traceParent struct { @@ -29,14 +27,6 @@ func newTraceParent(r *http.Request) *traceParent { // May return nil. return newTraceParentFromHeaderValue(r.Header.Get(headerTraceParent), r.Header.Get(headerTraceState)) } -func newTraceParentFromFake(r *http.Request) *traceParent { - return newTraceParentFromHeaderValue(r.Header.Get(headerFakeTraceParent), "") // Our server logger may have faked one already. -} - -func newTraceParentWithID(traceID string) *traceParent { - return &traceParent{parent: []string{"00", traceID, fmt.Sprintf("%016x", 0) /*invalid, span resolves that*/, "00"}} -} - func newTraceParentFromHeaderValue(traceparent, tracestate string) *traceParent { parent := strings.Split(traceparent, "-") if len(parent) != 4 {