From 9cb268a8b8b82aeead294b892cb0bd9568969c51 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 22 Jul 2020 08:31:02 +0200 Subject: [PATCH] change how tracing query comments are handled Signed-off-by: Andres Taylor --- go/trace/opentracing.go | 25 ++++++++++++------------- go/trace/opentracing_test.go | 18 +++++++++++++----- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/go/trace/opentracing.go b/go/trace/opentracing.go index 6e6c6c5bc4e..3a38b2647c2 100644 --- a/go/trace/opentracing.go +++ b/go/trace/opentracing.go @@ -17,13 +17,13 @@ limitations under the License. package trace import ( - "strings" + "encoding/base64" + "encoding/json" otgrpc "github.com/opentracing-contrib/go-grpc" "github.com/opentracing/opentracing-go" "golang.org/x/net/context" "google.golang.org/grpc" - "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" ) @@ -86,19 +86,18 @@ func (jf openTracingService) New(parent Span, label string) Span { } func extractMapFromString(in string) (opentracing.TextMapCarrier, error) { - m := make(opentracing.TextMapCarrier) - items := strings.Split(in, ":") - if len(items) < 2 { - return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "expected transmitted context to contain at least span id and trace id") + decodedBytes, err := base64.StdEncoding.DecodeString(in) + if err != nil { + return nil, err } - for _, v := range items { - idx := strings.Index(v, "=") - if idx < 1 { - return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "every element in the context string has to be in the form key=value") - } - m[v[0:idx]] = v[idx+1:] + + var dat opentracing.TextMapCarrier + err = json.Unmarshal(decodedBytes, &dat) + if err != nil { + return nil, err } - return m, nil + + return dat, nil } func (jf openTracingService) NewFromString(parent, label string) (Span, error) { diff --git a/go/trace/opentracing_test.go b/go/trace/opentracing_test.go index 19bdbce9019..104545fe657 100644 --- a/go/trace/opentracing_test.go +++ b/go/trace/opentracing_test.go @@ -17,6 +17,8 @@ limitations under the License. package trace import ( + "encoding/base64" + "encoding/json" "testing" "github.com/opentracing/opentracing-go" @@ -25,17 +27,23 @@ import ( func TestExtractMapFromString(t *testing.T) { expected := make(opentracing.TextMapCarrier) - expected["apa"] = "12" - expected["banan"] = "x-tracing-backend-12" - result, err := extractMapFromString("apa=12:banan=x-tracing-backend-12") + expected["uber-trace-id"] = "123:456:789:1" + expected["other data with weird symbols:!#;"] = ":1!\"" + jsonBytes, err := json.Marshal(expected) + assert.NoError(t, err) + + encodedString := base64.StdEncoding.EncodeToString(jsonBytes) + + result, err := extractMapFromString(encodedString) assert.NoError(t, err) assert.Equal(t, expected, result) } func TestErrorConditions(t *testing.T) { - _, err := extractMapFromString("") + encodedString := base64.StdEncoding.EncodeToString([]byte(`{"key":42}`)) + _, err := extractMapFromString(encodedString) // malformed json {"key":42} assert.Error(t, err) - _, err = extractMapFromString("key=value:keywithnovalue") + _, err = extractMapFromString("this is not base64") // malformed base64 assert.Error(t, err) }