Skip to content

Commit

Permalink
contrib/{aws/*,net/http}: fix encoding of http.url (#1662)
Browse files Browse the repository at this point in the history
This PR ensures auth strings are not present in URLs before attaching
them to spans in the `http.url` tag
  • Loading branch information
knusbaum committed Jan 17, 2023
1 parent 7fd5e0d commit 8753841
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 3 deletions.
5 changes: 4 additions & 1 deletion contrib/aws/aws-sdk-go-v2/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,11 @@ func (mw *traceMiddleware) deserializeTraceMiddleware(stack *middleware.Stack) e

// Get values out of the request.
if req, ok := in.Request.(*smithyhttp.Request); ok {
// Make a copy of the URL so we don't modify the outgoing request
url := *req.URL
url.User = nil // Do not include userinfo in the HTTPURL tag.
span.SetTag(ext.HTTPMethod, req.Method)
span.SetTag(ext.HTTPURL, req.URL.String())
span.SetTag(ext.HTTPURL, url.String())
span.SetTag(tagAWSAgent, req.Header.Get("User-Agent"))
}

Expand Down
59 changes: 59 additions & 0 deletions contrib/aws/aws-sdk-go-v2/aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ package aws

import (
"context"
"encoding/base64"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
Expand All @@ -17,6 +20,7 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestAppendMiddleware(t *testing.T) {
Expand Down Expand Up @@ -199,3 +203,58 @@ func TestAppendMiddleware_WithOpts(t *testing.T) {
})
}
}

func TestHTTPCredentials(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

var auth string

server := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if enc, ok := r.Header["Authorization"]; ok {
encoded := strings.TrimPrefix(enc[0], "Basic ")
if b64, err := base64.StdEncoding.DecodeString(encoded); err == nil {
auth = string(b64)
}
}

w.Header().Set("X-Amz-RequestId", "test_req")
w.WriteHeader(200)
w.Write([]byte(`{}`))
}))
defer server.Close()

u, err := url.Parse(server.URL)
require.NoError(t, err)
u.User = url.UserPassword("myuser", "mypassword")

resolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) {
return aws.Endpoint{
PartitionID: "aws",
URL: u.String(),
SigningRegion: "eu-west-1",
}, nil
})

awsCfg := aws.Config{
Region: "eu-west-1",
Credentials: aws.AnonymousCredentials{},
EndpointResolver: resolver,
}

AppendMiddleware(&awsCfg)

sqsClient := sqs.NewFromConfig(awsCfg)
sqsClient.ListQueues(context.Background(), &sqs.ListQueuesInput{})

spans := mt.FinishedSpans()

s := spans[0]
assert.Equal(t, server.URL+"/", s.Tag(ext.HTTPURL))
assert.NotContains(t, s.Tag(ext.HTTPURL), "mypassword")
assert.NotContains(t, s.Tag(ext.HTTPURL), "myuser")
// Make sure we haven't modified the outgoing request, and the server still
// receives the auth request.
assert.Equal(t, auth, "myuser:mypassword")
}
5 changes: 4 additions & 1 deletion contrib/aws/aws-sdk-go/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ func (h *handlers) Send(req *request.Request) {
if req.RetryCount != 0 {
return
}
// Make a copy of the URL so we don't modify the outgoing request
url := *req.HTTPRequest.URL
url.User = nil // Do not include userinfo in the HTTPURL tag.
opts := []ddtrace.StartSpanOption{
tracer.SpanType(ext.SpanTypeHTTP),
tracer.ServiceName(h.serviceName(req)),
Expand All @@ -68,7 +71,7 @@ func (h *handlers) Send(req *request.Request) {
tracer.Tag(tagAWSOperation, h.awsOperation(req)),
tracer.Tag(tagAWSRegion, h.awsRegion(req)),
tracer.Tag(ext.HTTPMethod, req.Operation.HTTPMethod),
tracer.Tag(ext.HTTPURL, req.HTTPRequest.URL.String()),
tracer.Tag(ext.HTTPURL, url.String()),
tracer.Tag(ext.Component, "aws/aws-sdk-go/aws"),
tracer.Tag(ext.SpanKind, ext.SpanKindClient),
}
Expand Down
69 changes: 69 additions & 0 deletions contrib/aws/aws-sdk-go/aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,23 @@ package aws

import (
"context"
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer"
Expand Down Expand Up @@ -182,3 +189,65 @@ func TestRetries(t *testing.T) {
assert.Len(t, mt.FinishedSpans(), 1)
assert.Equal(t, mt.FinishedSpans()[0].Tag(tagAWSRetryCount), 3)
}

func TestHTTPCredentials(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

var auth string

server := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if enc, ok := r.Header["Authorization"]; ok {
encoded := strings.TrimPrefix(enc[0], "Basic ")
if b64, err := base64.StdEncoding.DecodeString(encoded); err == nil {
auth = string(b64)
}
}

w.Header().Set("X-Amz-RequestId", "test_req")
w.WriteHeader(200)
w.Write([]byte(`{}`))
}))
defer server.Close()

u, err := url.Parse(server.URL)
require.NoError(t, err)
u.User = url.UserPassword("myuser", "mypassword")

resolver := endpoints.ResolverFunc(func(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
return endpoints.ResolvedEndpoint{
PartitionID: "aws",
URL: u.String(),
SigningRegion: "eu-west-1",
}, nil
})

region := "eu-west-1"
awsCfg := aws.Config{
Region: &region,
Credentials: credentials.AnonymousCredentials,
EndpointResolver: resolver,
}
session := WrapSession(session.Must(session.NewSession(&awsCfg)))

ctx := context.Background()
s3api := s3.New(session)
req, _ := s3api.GetObjectRequest(&s3.GetObjectInput{
Bucket: aws.String("BUCKET"),
Key: aws.String("KEY"),
})
req.SetContext(ctx)
err = req.Send()
require.NoError(t, err)

spans := mt.FinishedSpans()

s := spans[0]
assert.Equal(t, server.URL+"/BUCKET/KEY", s.Tag(ext.HTTPURL))
assert.NotContains(t, s.Tag(ext.HTTPURL), "mypassword")
assert.NotContains(t, s.Tag(ext.HTTPURL), "myuser")
// Make sure we haven't modified the outgoing request, and the server still
// receives the auth request.
assert.Equal(t, auth, "myuser:mypassword")
}
5 changes: 4 additions & 1 deletion contrib/net/http/roundtripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (res *http.Response, err er
return rt.base.RoundTrip(req)
}
resourceName := rt.cfg.resourceNamer(req)
// Make a copy of the URL so we don't modify the outgoing request
url := *req.URL
url.User = nil // Do not include userinfo in the HTTPURL tag.
opts := []ddtrace.StartSpanOption{
tracer.SpanType(ext.SpanTypeHTTP),
tracer.ResourceName(resourceName),
tracer.Tag(ext.HTTPMethod, req.Method),
tracer.Tag(ext.HTTPURL, req.URL.String()),
tracer.Tag(ext.HTTPURL, url.String()),
tracer.Tag(ext.Component, "net/http"),
tracer.Tag(ext.SpanKind, ext.SpanKindClient),
}
Expand Down
51 changes: 51 additions & 0 deletions contrib/net/http/roundtripper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
package http

import (
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"gopkg.in/DataDog/dd-trace-go.v1/ddtrace"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext"
Expand Down Expand Up @@ -179,6 +183,53 @@ func TestRoundTripperNetworkError(t *testing.T) {
assert.Equal(t, "net/http", s0.Tag(ext.Component))
}

func TestRoundTripperCredentials(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

var auth string
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if enc, ok := r.Header["Authorization"]; ok {
encoded := strings.TrimPrefix(enc[0], "Basic ")
if b64, err := base64.StdEncoding.DecodeString(encoded); err == nil {
auth = string(b64)
}
}

}))
defer s.Close()

rt := WrapRoundTripper(http.DefaultTransport,
WithBefore(func(req *http.Request, span ddtrace.Span) {
span.SetTag("CalledBefore", true)
}),
WithAfter(func(res *http.Response, span ddtrace.Span) {
span.SetTag("CalledAfter", true)
}))

client := &http.Client{
Transport: rt,
}

u, err := url.Parse(s.URL)
require.NoError(t, err)
u.User = url.UserPassword("myuser", "mypassword")

client.Get(u.String() + "/hello/world")

spans := mt.FinishedSpans()
require.Len(t, spans, 1)

s1 := spans[0]

assert.Equal(t, s.URL+"/hello/world", s1.Tag(ext.HTTPURL))
assert.NotContains(t, s1.Tag(ext.HTTPURL), "mypassword")
assert.NotContains(t, s1.Tag(ext.HTTPURL), "myuser")
// Make sure we haven't modified the outgoing request, and the server still
// receives the auth request.
assert.Equal(t, auth, "myuser:mypassword")
}

func TestWrapClient(t *testing.T) {
c := WrapClient(http.DefaultClient)
assert.Equal(t, c, http.DefaultClient)
Expand Down

0 comments on commit 8753841

Please sign in to comment.