Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aws/request: Fix bug in request presign creating invalide URL #1624

Merged
merged 4 commits into from
Nov 1, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 48 additions & 21 deletions aws/request/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ const (
// during body reads.
ErrCodeResponseTimeout = "ResponseTimeout"

// ErrCodeInvalidPresignExpire is returned when the expire time provided to
// presign is invalid
ErrCodeInvalidPresignExpire = "InvalidPresignExpireError"

// CanceledErrorCode is the error code that will be returned by an
// API request that was canceled. Requests given a aws.Context may
// return this error when canceled.
Expand All @@ -42,7 +46,6 @@ type Request struct {

Retryer
Time time.Time
ExpireTime time.Duration
Operation *Operation
HTTPRequest *http.Request
HTTPResponse *http.Response
Expand All @@ -60,6 +63,11 @@ type Request struct {
LastSignedAt time.Time
DisableFollowRedirects bool

// A value greater than 0 instructs the request to be signed as Presigned URL
// You should not set this field directly. Instead use Request's
// Presign or PresignRequest methods.
ExpireTime time.Duration

context aws.Context

built bool
Expand Down Expand Up @@ -250,40 +258,59 @@ func (r *Request) SetReaderBody(reader io.ReadSeeker) {

// Presign returns the request's signed URL. Error will be returned
// if the signing fails.
func (r *Request) Presign(expireTime time.Duration) (string, error) {
r.ExpireTime = expireTime
//
// It is invalid to create a presigned URL with a expire duration 0 or less. An
// error is returned if expire duration is 0 or less.
func (r *Request) Presign(expire time.Duration) (string, error) {
r = r.copy()

// Presign requires all headers be hoisted. There is no way to retrieve
// the signed headers not hoisted without this. Making the presigned URL
// useless.
r.NotHoist = false

if r.Operation.BeforePresignFn != nil {
r = r.copy()
err := r.Operation.BeforePresignFn(r)
if err != nil {
return "", err
}
}

r.Sign()
if r.Error != nil {
return "", r.Error
}
return r.HTTPRequest.URL.String(), nil
u, _, err := getPresignedURL(r, expire)
return u, err
}

// PresignRequest behaves just like presign, with the addition of returning a
// set of headers that were signed.
//
// It is invalid to create a presigned URL with a expire duration 0 or less. An
// error is returned if expire duration is 0 or less.
//
// Returns the URL string for the API operation with signature in the query string,
// and the HTTP headers that were included in the signature. These headers must
// be included in any HTTP request made with the presigned URL.
//
// To prevent hoisting any headers to the query string set NotHoist to true on
// this Request value prior to calling PresignRequest.
func (r *Request) PresignRequest(expireTime time.Duration) (string, http.Header, error) {
r.ExpireTime = expireTime
r.Sign()
if r.Error != nil {
return "", nil, r.Error
func (r *Request) PresignRequest(expire time.Duration) (string, http.Header, error) {
r = r.copy()
return getPresignedURL(r, expire)
}

func getPresignedURL(r *Request, expire time.Duration) (string, http.Header, error) {
if expire <= 0 {
return "", nil, awserr.New(
ErrCodeInvalidPresignExpire,
"presigned URL requires an expire duration greater than 0",
nil,
)
}

r.ExpireTime = expire

if r.Operation.BeforePresignFn != nil {
if err := r.Operation.BeforePresignFn(r); err != nil {
return "", nil, err
}
}

if err := r.Sign(); err != nil {
return "", nil, err
}

return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil
}

Expand Down
117 changes: 117 additions & 0 deletions aws/request/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"runtime"
"strconv"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -842,3 +844,118 @@ func TestRequest_TemporaryRetry(t *testing.T) {
t.Errorf("expect temporary error, was not")
}
}

func TestRequest_Presign(t *testing.T) {
presign := func(r *request.Request, expire time.Duration) (string, http.Header, error) {
u, err := r.Presign(expire)
return u, nil, err
}
presignRequest := func(r *request.Request, expire time.Duration) (string, http.Header, error) {
return r.PresignRequest(expire)
}
mustParseURL := func(v string) *url.URL {
u, err := url.Parse(v)
if err != nil {
panic(err)
}
return u
}

cases := []struct {
Expire time.Duration
PresignFn func(*request.Request, time.Duration) (string, http.Header, error)
SignerFn func(*request.Request)
URL string
Header http.Header
Err string
}{
{
PresignFn: presign,
Err: request.ErrCodeInvalidPresignExpire,
},
{
PresignFn: presignRequest,
Err: request.ErrCodeInvalidPresignExpire,
},
{
Expire: -1,
PresignFn: presign,
Err: request.ErrCodeInvalidPresignExpire,
},
{
// Presign clear NotHoist
Expire: 1 * time.Minute,
PresignFn: func(r *request.Request, dur time.Duration) (string, http.Header, error) {
r.NotHoist = true
return presign(r, dur)
},
SignerFn: func(r *request.Request) {
r.HTTPRequest.URL = mustParseURL("https://endpoint/presignedURL")
fmt.Println("url", r.HTTPRequest.URL.String())
if r.NotHoist {
r.Error = fmt.Errorf("expect NotHoist to be cleared")
}
},
URL: "https://endpoint/presignedURL",
},
{
// PresignRequest does not clear NotHoist
Expire: 1 * time.Minute,
PresignFn: func(r *request.Request, dur time.Duration) (string, http.Header, error) {
r.NotHoist = true
return presignRequest(r, dur)
},
SignerFn: func(r *request.Request) {
r.HTTPRequest.URL = mustParseURL("https://endpoint/presignedURL")
if !r.NotHoist {
r.Error = fmt.Errorf("expect NotHoist not to be cleared")
}
},
URL: "https://endpoint/presignedURL",
},
{
// PresignRequest returns signed headers
Expire: 1 * time.Minute,
PresignFn: presignRequest,
SignerFn: func(r *request.Request) {
r.HTTPRequest.URL = mustParseURL("https://endpoint/presignedURL")
r.HTTPRequest.Header.Set("UnsigndHeader", "abc")
r.SignedHeaderVals = http.Header{
"X-Amzn-Header": []string{"abc", "123"},
"X-Amzn-Header2": []string{"efg", "456"},
}
},
URL: "https://endpoint/presignedURL",
Header: http.Header{
"X-Amzn-Header": []string{"abc", "123"},
"X-Amzn-Header2": []string{"efg", "456"},
},
},
}

svc := awstesting.NewClient()
svc.Handlers.Clear()
for i, c := range cases {
req := svc.NewRequest(&request.Operation{
Name: "name", HTTPMethod: "GET", HTTPPath: "/path",
}, &struct{}{}, &struct{}{})
req.Handlers.Sign.PushBack(c.SignerFn)

u, h, err := c.PresignFn(req, c.Expire)
if len(c.Err) != 0 {
if e, a := c.Err, err.Error(); !strings.Contains(a, e) {
t.Errorf("%d, expect %v to be in %v", i, e, a)
}
continue
} else if err != nil {
t.Errorf("%d, expect no error, got %v", i, err)
continue
}
if e, a := c.URL, u; e != a {
t.Errorf("%d, expect %v URL, got %v", i, e, a)
}
if e, a := c.Header, h; !reflect.DeepEqual(e, a) {
t.Errorf("%d, expect %v header got %v", i, e, a)
}
}
}
10 changes: 5 additions & 5 deletions aws/signer/v4/v4.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ type signingCtx struct {
// "X-Amz-Content-Sha256" header with a precomputed value. The signer will
// only compute the hash if the request header value is empty.
func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
return v4.signWithBody(r, body, service, region, 0, signTime)
return v4.signWithBody(r, body, service, region, 0, false, signTime)
}

// Presign signs AWS v4 requests with the provided body, service name, region
Expand Down Expand Up @@ -302,10 +302,10 @@ func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region strin
// presigned request's signature you can set the "X-Amz-Content-Sha256"
// HTTP header and that will be included in the request's signature.
func (v4 Signer) Presign(r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, signTime time.Time) (http.Header, error) {
return v4.signWithBody(r, body, service, region, exp, signTime)
return v4.signWithBody(r, body, service, region, exp, true, signTime)
}

func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, signTime time.Time) (http.Header, error) {
func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, exp time.Duration, isPresign bool, signTime time.Time) (http.Header, error) {
currentTimeFn := v4.currentTimeFn
if currentTimeFn == nil {
currentTimeFn = time.Now
Expand All @@ -317,7 +317,7 @@ func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, regi
Query: r.URL.Query(),
Time: signTime,
ExpireTime: exp,
isPresign: exp != 0,
isPresign: isPresign,
ServiceName: service,
Region: region,
DisableURIPathEscaping: v4.DisableURIPathEscaping,
Expand Down Expand Up @@ -467,7 +467,7 @@ func signSDKRequestWithCurrTime(req *request.Request, curTimeFn func() time.Time
}

signedHeaders, err := v4.signWithBody(req.HTTPRequest, req.GetBody(),
name, region, req.ExpireTime, signingTime,
name, region, req.ExpireTime, req.ExpireTime > 0, signingTime,
)
if err != nil {
req.Error = err
Expand Down