diff --git a/README.md b/README.md index 28cacc2..140e116 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,11 @@ You can also create tracer with sampling rate. tracer, err := go2sky.NewTracer("example", go2sky.WithReporter(r), go2sky.WithSampler(0.5)) ``` +Also could customize correlation context config. +```go +tracer, err := go2sky.NewTracer("example", go2sky.WithReporter(r), go2sky.WithSampler(0.5), go2sky.WithCorrelation(3, 128)) +``` + ## Create span To create a span in a trace, we used the `Tracer` to start a new span. We indicate this as the root span because of @@ -67,6 +72,22 @@ A sub span created as the children of root span links to its parent with `Contex subSpan, newCtx, err := tracer.CreateLocalSpan(ctx) ``` +## Get correlation + +Get custom data from tracing context. + +```go +value := go2sky.GetCorrelation(ctx, key) +``` + +## Put correlation + +Put custom data to tracing context. + +```go +success := go2sky.PutCorrelation(ctx, key, value) +``` + ## End span We must end the spans so they becomes available for sending to the backend by a reporter. @@ -109,16 +130,16 @@ upstream service. ```go //Extract context from HTTP request header `sw8` -span, ctx, err := tracer.CreateEntrySpan(r.Context(), "/api/login", func() (string, error) { - return r.Header.Get("sw8"), nil +span, ctx, err := tracer.CreateEntrySpan(r.Context(), "/api/login", func(key string) (string, error) { + return r.Header.Get(key), nil }) // Some operation ... // Inject context into HTTP request header `sw8` -span, err := tracer.CreateExitSpan(req.Context(), "/service/validate", "tomcat-service:8080", func(header string) error { - req.Header.Set(propagation.Header, header) +span, err := tracer.CreateExitSpan(req.Context(), "/service/validate", "tomcat-service:8080", func(key, value string) error { + req.Header.Set(key, value) return nil }) ``` diff --git a/correlation.go b/correlation.go new file mode 100644 index 0000000..40169e6 --- /dev/null +++ b/correlation.go @@ -0,0 +1,84 @@ +// Licensed to SkyAPM org under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. SkyAPM org licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package go2sky + +import "context" + +type CorrelationConfig struct { + MaxKeyCount int + MaxValueSize int +} + +func WithCorrelation(keyCount, valueSize int) TracerOption { + return func(t *Tracer) { + t.correlation = &CorrelationConfig{ + MaxKeyCount: keyCount, + MaxValueSize: valueSize, + } + } +} + +func PutCorrelation(ctx context.Context, key, value string) bool { + if key == "" { + return false + } + + activeSpan := ctx.Value(ctxKeyInstance) + if activeSpan == nil { + return false + } + + span, ok := activeSpan.(segmentSpan) + if !ok { + return false + } + correlationContext := span.context().CorrelationContext + // remove key + if value == "" { + delete(correlationContext, key) + return true + } + // out of max value size + if len(value) > span.tracer().correlation.MaxValueSize { + return false + } + // already exists key + if _, ok := correlationContext[key]; ok { + correlationContext[key] = value + return true + } + // out of max key count + if len(correlationContext) >= span.tracer().correlation.MaxKeyCount { + return false + } + span.context().CorrelationContext[key] = value + return true +} + +func GetCorrelation(ctx context.Context, key string) string { + activeSpan := ctx.Value(ctxKeyInstance) + if activeSpan == nil { + return "" + } + + span, ok := activeSpan.(segmentSpan) + if !ok { + return "" + } + return span.context().CorrelationContext[key] +} diff --git a/correlation_test.go b/correlation_test.go new file mode 100644 index 0000000..ee94633 --- /dev/null +++ b/correlation_test.go @@ -0,0 +1,206 @@ +// Licensed to SkyAPM org under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. SkyAPM org licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package go2sky_test + +import ( + "context" + "log" + "reflect" + "testing" + + "github.com/SkyAPM/go2sky" + "github.com/SkyAPM/go2sky/propagation" + "github.com/SkyAPM/go2sky/reporter" +) + +const ( + correlationTestKey = "test-key" + correlationTestValue = "test-value" +) + +func TestGetCorrelation_WithTracingContest(t *testing.T) { + verifyPutResult := func(ctx context.Context, key, value string, result bool, t *testing.T) { + if success := go2sky.PutCorrelation(ctx, key, value); success != result { + t.Errorf("put correlation result is not right: %t", success) + } + } + tests := []struct { + name string + // extract from context + extractor propagation.Extractor + // extract correlation context + extracted map[string]string + // put correlation + customCase func(ctx context.Context, t *testing.T) + // after exported correaltion context + want map[string]string + }{ + { + name: "no context", + extractor: func(headerKey string) (string, error) { + return "", nil + }, + extracted: make(map[string]string), + customCase: func(ctx context.Context, t *testing.T) { + verifyPutResult(ctx, correlationTestKey, correlationTestValue, true, t) + }, + want: func() map[string]string { + m := make(map[string]string) + m[correlationTestKey] = correlationTestValue + return m + }(), + }, + { + name: "existing context with correlation", + extractor: func(headerKey string) (string, error) { + if headerKey == propagation.HeaderCorrelation { + // test1 = t1 + return "dGVzdDE=:dDE=", nil + } + if headerKey == propagation.Header { + return "1-MWYyZDRiZjQ3YmY3MTFlYWI3OTRhY2RlNDgwMDExMjI=-MWU3YzIwNGE3YmY3MTFlYWI4NThhY2RlNDgwMDExMjI=" + + "-0-c2VydmljZQ==-aW5zdGFuY2U=-cHJvcGFnYXRpb24=-cHJvcGFnYXRpb246NTU2Ng==", nil + } + return "", nil + }, + extracted: func() map[string]string { + m := make(map[string]string) + m["test1"] = "t1" + return m + }(), + customCase: func(ctx context.Context, t *testing.T) { + verifyPutResult(ctx, correlationTestKey, correlationTestValue, true, t) + }, + want: func() map[string]string { + m := make(map[string]string) + m[correlationTestKey] = correlationTestValue + m["test1"] = "t1" + return m + }(), + }, + { + name: "empty context with put bound judge", + extractor: func(headerKey string) (string, error) { + return "", nil + }, + customCase: func(ctx context.Context, t *testing.T) { + // empty key + verifyPutResult(ctx, "", "123", false, t) + + // remove key + verifyPutResult(ctx, correlationTestKey, correlationTestValue, true, t) + verifyPutResult(ctx, correlationTestKey, "", true, t) + if go2sky.GetCorrelation(ctx, correlationTestKey) != "" { + t.Errorf("correlation test key should be null") + } + + // out of max value size + verifyPutResult(ctx, "test-key", "1234567890123456", false, t) + + // out of key count + verifyPutResult(ctx, "test-key1", "123", true, t) + verifyPutResult(ctx, "test-key2", "123", true, t) + verifyPutResult(ctx, "test-key3", "123", true, t) + verifyPutResult(ctx, "test-key4", "123", false, t) + + // exists key + verifyPutResult(ctx, "test-key1", "123456", true, t) + }, + want: func() map[string]string { + m := make(map[string]string) + m["test-key1"] = "123456" + m["test-key2"] = "123" + m["test-key3"] = "123" + return m + }(), + }, + } + + r, err := reporter.NewLogReporter() + if err != nil { + log.Fatalf("new reporter error %v \n", err) + } + defer r.Close() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + tracer, _ := go2sky.NewTracer("correlationTest", go2sky.WithReporter(r), go2sky.WithSampler(1), go2sky.WithCorrelation(3, 10)) + + // create entry span from extractor + span, ctx, _ := tracer.CreateEntrySpan(ctx, "test-entry", tt.extractor) + defer span.End() + + // verify extracted context is same + if tt.extracted != nil { + for key, value := range tt.extracted { + if go2sky.GetCorrelation(ctx, key) != value { + t.Errorf("error get previous correlation value, current is: %s", go2sky.GetCorrelation(ctx, key)) + } + } + } + + // custom case + tt.customCase(ctx, t) + + // put sample local span + span, ctx, _ = tracer.CreateLocalSpan(ctx) + defer span.End() + + // validate correlation context + // verify extracted context is same + for key, value := range tt.want { + if go2sky.GetCorrelation(ctx, key) != value { + t.Errorf("error validate correlation value, current is: %s", go2sky.GetCorrelation(ctx, key)) + } + } + + // export context + scx := propagation.SpanContext{} + _, err := tracer.CreateExitSpan(ctx, "test-exit", "127.0.0.1:8080", func(headerKey, headerValue string) error { + if headerKey == propagation.HeaderCorrelation { + err = scx.DecodeSW8Correlation(headerValue) + if err != nil { + t.Fail() + } + } + return nil + }) + if err != nil { + t.Fail() + } + reflect.DeepEqual(scx, tt.want) + }) + } +} + +func TestGetCorrelation_WithEmptyContext(t *testing.T) { + emptyValue := go2sky.GetCorrelation(context.Background(), "empty-key") + if emptyValue != "" { + t.Errorf("should be empty value") + } + + success := go2sky.PutCorrelation(context.Background(), "empty-key", "empty-value") + if success { + t.Errorf("put correlation key should be failed") + } + + emptyValue = go2sky.GetCorrelation(context.Background(), "empty-key") + if emptyValue != "" { + t.Errorf("should be empty value") + } +} diff --git a/noop_test.go b/noop_test.go index 25ea834..a457174 100644 --- a/noop_test.go +++ b/noop_test.go @@ -36,7 +36,7 @@ func TestCreateNoopSpan(t *testing.T) { { "Entry", func() (Span, context.Context, error) { - return tracer.CreateEntrySpan(context.Background(), "entry", func() (s string, e error) { + return tracer.CreateEntrySpan(context.Background(), "entry", func(key string) (s string, e error) { return "", nil }) }, @@ -44,7 +44,7 @@ func TestCreateNoopSpan(t *testing.T) { { "Exit", func() (s Span, c context.Context, err error) { - s, err = tracer.CreateExitSpan(context.Background(), "exit", "localhost:8080", func(header string) error { + s, err = tracer.CreateExitSpan(context.Background(), "exit", "localhost:8080", func(key, value string) error { return nil }) return @@ -72,13 +72,13 @@ func TestCreateNoopSpan(t *testing.T) { func TestNoopSpanFromBegin(t *testing.T) { tracer, _ := NewTracer("service") - span, ctx, _ := tracer.CreateEntrySpan(context.Background(), "entry", func() (s string, e error) { + span, ctx, _ := tracer.CreateEntrySpan(context.Background(), "entry", func(key string) (s string, e error) { return "", nil }) if _, ok := span.(*NoopSpan); !ok { t.Error("Should create noop span") } - exitSpan, _ := tracer.CreateExitSpan(ctx, "exit", "localhost:8080", func(header string) error { + exitSpan, _ := tracer.CreateExitSpan(ctx, "exit", "localhost:8080", func(key, value string) error { return nil }) if _, ok := exitSpan.(*NoopSpan); !ok { diff --git a/plugins/http/client.go b/plugins/http/client.go index d1f7712..7208e42 100644 --- a/plugins/http/client.go +++ b/plugins/http/client.go @@ -23,7 +23,6 @@ import ( "time" "github.com/SkyAPM/go2sky" - "github.com/SkyAPM/go2sky/propagation" v3 "github.com/SkyAPM/go2sky/reporter/grpc/language-agent" ) @@ -92,8 +91,8 @@ type transport struct { } func (t *transport) RoundTrip(req *http.Request) (res *http.Response, err error) { - span, err := t.tracer.CreateExitSpan(req.Context(), getOperationName(t.name, req), req.Host, func(header string) error { - req.Header.Set(propagation.Header, header) + span, err := t.tracer.CreateExitSpan(req.Context(), getOperationName(t.name, req), req.Host, func(key, value string) error { + req.Header.Set(key, value) return nil }) if err != nil { diff --git a/plugins/http/server.go b/plugins/http/server.go index cef346c..2307787 100644 --- a/plugins/http/server.go +++ b/plugins/http/server.go @@ -25,7 +25,6 @@ import ( "github.com/SkyAPM/go2sky" "github.com/SkyAPM/go2sky/internal/tool" - "github.com/SkyAPM/go2sky/propagation" v3 "github.com/SkyAPM/go2sky/reporter/grpc/language-agent" ) @@ -81,8 +80,8 @@ func NewServerMiddleware(tracer *go2sky.Tracer, options ...ServerOption) (func(h // ServeHTTP implements http.Handler. func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - span, ctx, err := h.tracer.CreateEntrySpan(r.Context(), getOperationName(h.name, r), func() (string, error) { - return r.Header.Get(propagation.Header), nil + span, ctx, err := h.tracer.CreateEntrySpan(r.Context(), getOperationName(h.name, r), func(key string) (string, error) { + return r.Header.Get(key), nil }) if err != nil { if h.next != nil { diff --git a/propagation/propagation.go b/propagation/propagation.go index 9b4a499..793eb42 100644 --- a/propagation/propagation.go +++ b/propagation/propagation.go @@ -31,9 +31,12 @@ import ( ) const ( - Header string = "sw8" - headerLen int = 8 - splitToken string = "-" + Header string = "sw8" + HeaderCorrelation string = "sw8-correlation" + headerLen int = 8 + splitToken string = "-" + correlationSplitToken string = "," + correlationKeyValueSplitToken string = ":" ) var ( @@ -43,22 +46,56 @@ var ( // Extractor is a tool specification which define how to // extract trace parent context from propagation context -type Extractor func() (string, error) +type Extractor func(headerKey string) (string, error) // Injector is a tool specification which define how to // inject trace context into propagation context -type Injector func(header string) error +type Injector func(headerKey, headerValue string) error // SpanContext defines propagation specification of SkyWalking type SpanContext struct { - TraceID string `json:"trace_id"` - ParentSegmentID string `json:"parent_segment_id"` - ParentService string `json:"parent_service"` - ParentServiceInstance string `json:"parent_service_instance"` - ParentEndpoint string `json:"parent_endpoint"` - AddressUsedAtClient string `json:"address_used_at_client"` - ParentSpanID int32 `json:"parent_span_id"` - Sample int8 `json:"sample"` + TraceID string `json:"trace_id"` + ParentSegmentID string `json:"parent_segment_id"` + ParentService string `json:"parent_service"` + ParentServiceInstance string `json:"parent_service_instance"` + ParentEndpoint string `json:"parent_endpoint"` + AddressUsedAtClient string `json:"address_used_at_client"` + ParentSpanID int32 `json:"parent_span_id"` + Sample int8 `json:"sample"` + Valid bool `json:"valid"` + CorrelationContext map[string]string `json:"correlation_context"` +} + +// Decode all SpanContext data from Extractor +func (tc *SpanContext) Decode(extractor Extractor) error { + tc.Valid = false + // sw8 + err := tc.decode(extractor, Header, tc.DecodeSW8) + if err != nil { + return err + } + + // correlation + err = tc.decode(extractor, HeaderCorrelation, tc.DecodeSW8Correlation) + if err != nil { + return err + } + return nil +} + +// Encode all SpanContext data to Injector +func (tc *SpanContext) Encode(injector Injector) error { + // sw8 + err := injector(Header, tc.EncodeSW8()) + if err != nil { + return err + } + // correlation + err = injector(HeaderCorrelation, tc.EncodeSW8Correlation()) + if err != nil { + return err + } + return nil } // DecodeSW6 converts string header to SpanContext @@ -103,6 +140,7 @@ func (tc *SpanContext) DecodeSW8(header string) error { if err != nil { return errors.Wrap(err, "network address parse error") } + tc.Valid = true return nil } @@ -120,6 +158,46 @@ func (tc *SpanContext) EncodeSW8() string { }, "-") } +// DecodeSW8Correlation converts correlation string header to SpanContext +func (tc *SpanContext) DecodeSW8Correlation(header string) error { + tc.CorrelationContext = make(map[string]string) + if header == "" { + return nil + } + + hh := strings.Split(header, correlationSplitToken) + for inx := range hh { + keyValues := strings.Split(hh[inx], correlationKeyValueSplitToken) + if len(keyValues) != 2 { + continue + } + decodedKey, err := decodeBase64(keyValues[0]) + if err != nil { + continue + } + decodedValue, err := decodeBase64(keyValues[1]) + if err != nil { + continue + } + + tc.CorrelationContext[decodedKey] = decodedValue + } + return nil +} + +// EncodeSW8Correlation converts correlation to string header +func (tc *SpanContext) EncodeSW8Correlation() string { + if len(tc.CorrelationContext) == 0 { + return "" + } + + content := make([]string, 0, len(tc.CorrelationContext)) + for k, v := range tc.CorrelationContext { + content = append(content, fmt.Sprintf("%s%s%s", encodeBase64(k), correlationKeyValueSplitToken, encodeBase64(v))) + } + return strings.Join(content, correlationSplitToken) +} + func stringConvertInt32(str string) (int32, error) { i, err := strconv.ParseInt(str, 0, 32) return int32(i), err @@ -136,3 +214,18 @@ func decodeBase64(str string) (string, error) { func encodeBase64(str string) string { return base64.StdEncoding.EncodeToString([]byte(str)) } + +func (tc *SpanContext) decode(extractor Extractor, headerKey string, decoder func(header string) error) error { + val, err := extractor(headerKey) + if err != nil { + return err + } + if val == "" { + return nil + } + err = decoder(val) + if err != nil { + return err + } + return nil +} diff --git a/propagation/propagation_test.go b/propagation/propagation_test.go index f551684..b293c09 100644 --- a/propagation/propagation_test.go +++ b/propagation/propagation_test.go @@ -17,7 +17,10 @@ package propagation -import "testing" +import ( + "reflect" + "testing" +) type fields struct { TraceID string @@ -152,3 +155,84 @@ func TestSpanContext_EncodeSW8(t *testing.T) { }) } } + +func TestSpanContext_DecodeSw8Correlation(t *testing.T) { + tests := []struct { + name string + args args + data map[string]string + }{ + { + name: "Empty Header", + args: args{header: ""}, + data: make(map[string]string), + }, + { + name: "Insufficient Header Entities", + args: args{header: "dGVzdC1rZXk="}, + data: make(map[string]string), + }, + { + name: "normal", + args: args{header: "dGVzdC1rZXk=:dGVzdC12YWx1ZQ=="}, + data: func() map[string]string { + m := make(map[string]string) + m["test-key"] = "test-value" + return m + }(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := &SpanContext{} + if err := tc.DecodeSW8Correlation(tt.args.header); err != nil { + t.Errorf("DecodeSW8() error = %v, wantErr %v", err, err) + } + if !reflect.DeepEqual(tc.CorrelationContext, tt.data) { + t.Fail() + } + }) + } +} + +func TestSpanContext_EncodeSW8Correlation(t *testing.T) { + tests := []struct { + name string + data map[string]string + want string + }{ + { + name: "empty", + data: make(map[string]string), + want: "", + }, + { + name: "empty value", + data: func() map[string]string { + m := make(map[string]string) + m["test-key"] = "" + return m + }(), + want: "dGVzdC1rZXk=:", + }, + { + name: "normal", + data: func() map[string]string { + m := make(map[string]string) + m["test-key"] = "test-value" + return m + }(), + want: "dGVzdC1rZXk=:dGVzdC12YWx1ZQ==", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := &SpanContext{ + CorrelationContext: tt.data, + } + if got := tc.EncodeSW8Correlation(); got != tt.want { + t.Errorf("EncodeSW8Correlation() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/reporter/grpc_test.go b/reporter/grpc_test.go index 46b13d8..6877c6b 100644 --- a/reporter/grpc_test.go +++ b/reporter/grpc_test.go @@ -72,17 +72,19 @@ func Test_e2e(t *testing.T) { if err != nil { t.Error(err) } - entrySpan, ctx, err := tracer.CreateEntrySpan(context.Background(), "/rest/api", func() (string, error) { + entrySpan, ctx, err := tracer.CreateEntrySpan(context.Background(), "/rest/api", func(key string) (string, error) { return header, nil }) if err != nil { t.Error(err) } - exitSpan, err := tracer.CreateExitSpan(ctx, "/foo/bar", "foo.svc:8787", func(head string) error { + exitSpan, err := tracer.CreateExitSpan(ctx, "/foo/bar", "foo.svc:8787", func(key, value string) error { scx := propagation.SpanContext{} - err = scx.DecodeSW8(head) - if err != nil { - t.Fatal(err) + if key == propagation.Header { + err = scx.DecodeSW8(value) + if err != nil { + t.Fatal(err) + } } return nil }) @@ -115,7 +117,7 @@ func TestGRPCReporter_Close(t *testing.T) { if err != nil { t.Error(err) } - entry, _, err := tracer.CreateEntrySpan(context.Background(), "/close", func() (s string, err error) { + entry, _, err := tracer.CreateEntrySpan(context.Background(), "/close", func(key string) (s string, err error) { return header, nil }) if err != nil { diff --git a/segment.go b/segment.go index 36a2c01..3eece7e 100644 --- a/segment.go +++ b/segment.go @@ -50,15 +50,16 @@ func newSegmentSpan(defaultSpan *defaultSpan, parentSpan segmentSpan) (s segment // SegmentContext is the context in a segment type SegmentContext struct { - TraceID string - SegmentID string - SpanID int32 - ParentSpanID int32 - ParentSegmentID string - collect chan<- ReportedSpan - refNum *int32 - spanIDGenerator *int32 - FirstSpan Span `json:"-"` + TraceID string + SegmentID string + SpanID int32 + ParentSpanID int32 + ParentSegmentID string + collect chan<- ReportedSpan + refNum *int32 + spanIDGenerator *int32 + FirstSpan Span `json:"-"` + CorrelationContext map[string]string } // ReportedSpan is accessed by Reporter to load reported data @@ -80,6 +81,7 @@ type ReportedSpan interface { type segmentSpan interface { Span context() SegmentContext + tracer() *Tracer segmentRegister() bool } @@ -150,6 +152,10 @@ func (s *segmentSpanImpl) context() SegmentContext { return s.SegmentContext } +func (s *segmentSpanImpl) tracer() *Tracer { + return s.defaultSpan.tracer +} + func (s *segmentSpanImpl) segmentRegister() bool { for { o := atomic.LoadInt32(s.Context().refNum) @@ -167,17 +173,20 @@ func (s *segmentSpanImpl) createSegmentContext(parent segmentSpan) (err error) { s.SegmentContext = SegmentContext{} if len(s.defaultSpan.Refs) > 0 { s.TraceID = s.defaultSpan.Refs[0].TraceID + s.CorrelationContext = s.defaultSpan.Refs[0].CorrelationContext } else { s.TraceID, err = idgen.GenerateGlobalID() if err != nil { return err } + s.CorrelationContext = make(map[string]string) } } else { s.SegmentContext = parent.context() s.ParentSegmentID = s.SegmentID s.ParentSpanID = s.SpanID s.SpanID = atomic.AddInt32(s.Context().spanIDGenerator, 1) + s.CorrelationContext = parent.context().CorrelationContext } if s.SegmentContext.FirstSpan == nil { s.SegmentContext.FirstSpan = s @@ -237,7 +246,7 @@ func newSegmentRoot(segmentSpan *segmentSpanImpl) *rootSegmentSpan { break } } - s.tracer.reporter.Send(append(s.segment, s)) + s.tracer().reporter.Send(append(s.segment, s)) }() return s } diff --git a/segment_test.go b/segment_test.go index 0b8a2a6..abb3220 100644 --- a/segment_test.go +++ b/segment_test.go @@ -156,11 +156,11 @@ func TestReportedSpan(t *testing.T) { } } -func MockExtractor() (c string, e error) { +func MockExtractor(key string) (c string, e error) { return } -func MockInjector(string) (e error) { +func MockInjector(key, value string) (e error) { return } diff --git a/trace.go b/trace.go index 826e722..51c4c9f 100644 --- a/trace.go +++ b/trace.go @@ -39,8 +39,9 @@ type Tracer struct { instance string reporter Reporter // 0 not init 1 init - initFlag int32 - sampler Sampler + initFlag int32 + sampler Sampler + correlation *CorrelationConfig } // TracerOption allows for functional options to adjust behaviour @@ -56,6 +57,8 @@ func NewTracer(service string, opts ...TracerOption) (tracer *Tracer, err error) service: service, initFlag: 0, } + // default correlation config + t.correlation = &CorrelationConfig{MaxKeyCount: 3, MaxValueSize: 128} for _, opt := range opts { opt(t) } @@ -86,17 +89,13 @@ func (t *Tracer) CreateEntrySpan(ctx context.Context, operationName string, extr if s, nCtx = t.createNoop(ctx); s != nil { return } - header, err := extractor() + var refSc = &propagation.SpanContext{} + err = refSc.Decode(extractor) if err != nil { return } - var refSc *propagation.SpanContext - if header != "" { - refSc = &propagation.SpanContext{} - err = refSc.DecodeSW8(header) - if err != nil { - return - } + if !refSc.Valid { + refSc = nil } s, nCtx, err = t.CreateLocalSpan(ctx, WithContext(refSc), WithSpanType(SpanTypeEntry), WithOperationName(operationName)) if err != nil { @@ -172,8 +171,9 @@ func (t *Tracer) CreateExitSpan(ctx context.Context, operationName string, peer spanContext.ParentServiceInstance = t.instance spanContext.ParentEndpoint = firstSpan.GetOperationName() spanContext.AddressUsedAtClient = peer + spanContext.CorrelationContext = span.Context().CorrelationContext - err = injector(spanContext.EncodeSW8()) + err = spanContext.Encode(injector) if err != nil { return nil, err } diff --git a/trace_propagation_test.go b/trace_propagation_test.go index 216aa91..aa7a486 100644 --- a/trace_propagation_test.go +++ b/trace_propagation_test.go @@ -59,17 +59,26 @@ func TestTracer_EntryAndExit(t *testing.T) { if err != nil { t.Error(err) } - entrySpan, ctx, err := tracer.CreateEntrySpan(context.Background(), "/rest/api", func() (string, error) { + entrySpan, ctx, err := tracer.CreateEntrySpan(context.Background(), "/rest/api", func(key string) (string, error) { return "", nil }) if err != nil { t.Error(err) } - exitSpan, err := tracer.CreateExitSpan(ctx, "/foo/bar", "foo.svc:8787", func(head string) error { + exitSpan, err := tracer.CreateExitSpan(ctx, "/foo/bar", "foo.svc:8787", func(key, value string) error { scx := propagation.SpanContext{} - err = scx.DecodeSW8(head) - if err != nil { - t.Fail() + if key == propagation.Header { + err = scx.DecodeSW8(value) + if err != nil { + t.Fail() + } + } + + if key == propagation.HeaderCorrelation { + err = scx.DecodeSW8Correlation(value) + if err != nil { + t.Fail() + } } return nil }) @@ -89,7 +98,7 @@ func TestTracer_Entry(t *testing.T) { if err != nil { t.Error(err) } - entrySpan, _, err := tracer.CreateEntrySpan(context.Background(), "/rest/api", func() (string, error) { + entrySpan, _, err := tracer.CreateEntrySpan(context.Background(), "/rest/api", func(key string) (string, error) { return header, nil }) if err != nil { @@ -113,50 +122,59 @@ func TestTracer_EntryAndExitInTrace(t *testing.T) { if err != nil { t.Error(err) } - entrySpan, ctx, err := tracer.CreateEntrySpan(context.Background(), "/rest/api", func() (string, error) { + entrySpan, ctx, err := tracer.CreateEntrySpan(context.Background(), "/rest/api", func(key string) (string, error) { return header, nil }) if err != nil { t.Error(err) } - exitSpan, err := tracer.CreateExitSpan(ctx, "/foo/bar", "foo.svc:8786", func(head string) error { - sc := propagation.SpanContext{} - err = sc.DecodeSW8(head) - if err != nil { - t.Fail() + sc := propagation.SpanContext{} + exitSpan, err := tracer.CreateExitSpan(ctx, "/foo/bar", "foo.svc:8786", func(key, value string) error { + if key == propagation.Header { + err = sc.DecodeSW8(value) + if err != nil { + t.Fail() + } } - if sc.Sample != sample { - t.Fail() + if key == propagation.HeaderCorrelation { + err = sc.DecodeSW8Correlation(value) + if err != nil { + t.Fail() + } } + return nil + }) + if err != nil { + t.Error(err) + } - if sc.TraceID != traceID { - t.Fail() - } + if sc.Sample != sample { + t.Fail() + } - if sc.ParentSpanID != 1 { - t.Fail() - } + if sc.TraceID != traceID { + t.Fail() + } - if sc.ParentService != "service" { - t.Fail() - } + if sc.ParentSpanID != 1 { + t.Fail() + } - if sc.ParentServiceInstance != "instance" { - t.Fail() - } + if sc.ParentService != "service" { + t.Fail() + } - if sc.ParentEndpoint != "/rest/api" { - t.Fail() - } + if sc.ParentServiceInstance != "instance" { + t.Fail() + } - if sc.AddressUsedAtClient != "foo.svc:8786" { - t.Fail() - } - return nil - }) - if err != nil { - t.Error(err) + if sc.ParentEndpoint != "/rest/api" { + t.Fail() + } + + if sc.AddressUsedAtClient != "foo.svc:8786" { + t.Fail() } exitSpan.End() entrySpan.End() diff --git a/trace_test.go b/trace_test.go index 0ddc737..2abb406 100644 --- a/trace_test.go +++ b/trace_test.go @@ -181,7 +181,10 @@ func TestNewTracer(t *testing.T) { service string opts []TracerOption }{service: "test", opts: nil}, - &Tracer{service: "test", sampler: NewConstSampler(true)}, + &Tracer{service: "test", sampler: NewConstSampler(true), correlation: &CorrelationConfig{ + MaxKeyCount: 3, + MaxValueSize: 128, + }}, false, }, } @@ -216,7 +219,7 @@ func TestTracer_CreateEntrySpan_Parameter(t *testing.T) { ctx context.Context operationName string extractor propagation.Extractor - }{ctx: nil, operationName: "query type", extractor: func() (s string, e error) { + }{ctx: nil, operationName: "query type", extractor: func(key string) (s string, e error) { return "", nil }}, true, @@ -227,7 +230,7 @@ func TestTracer_CreateEntrySpan_Parameter(t *testing.T) { ctx context.Context operationName string extractor propagation.Extractor - }{ctx: context.Background(), operationName: "", extractor: func() (s string, e error) { + }{ctx: context.Background(), operationName: "", extractor: func(key string) (s string, e error) { return "", nil }}, true, @@ -247,7 +250,7 @@ func TestTracer_CreateEntrySpan_Parameter(t *testing.T) { ctx context.Context operationName string extractor propagation.Extractor - }{ctx: context.Background(), operationName: "query type", extractor: func() (s string, e error) { + }{ctx: context.Background(), operationName: "query type", extractor: func(key string) (s string, e error) { return "", nil }}, false, @@ -320,7 +323,7 @@ func TestTracer_CreateExitSpan_Parameter(t *testing.T) { operationName string peer string injector propagation.Injector - }{ctx: nil, operationName: "query type", peer: "localhost:8080", injector: func(header string) error { + }{ctx: nil, operationName: "query type", peer: "localhost:8080", injector: func(key, value string) error { return nil }}, true, @@ -332,7 +335,7 @@ func TestTracer_CreateExitSpan_Parameter(t *testing.T) { operationName string peer string injector propagation.Injector - }{ctx: context.Background(), operationName: "", peer: "localhost:8080", injector: func(header string) error { + }{ctx: context.Background(), operationName: "", peer: "localhost:8080", injector: func(key, value string) error { return nil }}, true, @@ -344,7 +347,7 @@ func TestTracer_CreateExitSpan_Parameter(t *testing.T) { operationName string peer string injector propagation.Injector - }{ctx: context.Background(), operationName: "query type", peer: "", injector: func(header string) error { + }{ctx: context.Background(), operationName: "query type", peer: "", injector: func(key, value string) error { return nil }}, true, @@ -366,7 +369,7 @@ func TestTracer_CreateExitSpan_Parameter(t *testing.T) { operationName string peer string injector propagation.Injector - }{ctx: context.Background(), operationName: "query type", peer: "localhost:8080", injector: func(header string) error { + }{ctx: context.Background(), operationName: "query type", peer: "localhost:8080", injector: func(key, value string) error { return nil }}, false,