Skip to content

Commit

Permalink
Add ability to cancel device code flow auth (#474)
Browse files Browse the repository at this point in the history
* Add ability to cancel device code flow auth

Added functions for device code flow that take a context.
Enforce Go modules mode in CI.

* remove vanity import to fix CI

* revert enabling of modules in CI (will handle in separate PR)
  • Loading branch information
jhendrixMSFT authored Oct 14, 2019
1 parent 7fcf7bf commit 740293c
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 6 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# CHANGELOG

## v13.2.0

### New Features

- Added the following functions to replace their versions that don't take a context.
- `adal.InitiateDeviceAuthWithContext()`
- `adal.CheckForUserCompletionWithContext()`
- `adal.WaitForUserCompletionWithContext()`

## v13.1.0

### New Features
Expand Down
35 changes: 31 additions & 4 deletions autorest/adal/devicetoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package adal
*/

import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -101,7 +102,14 @@ type deviceToken struct {

// InitiateDeviceAuth initiates a device auth flow. It returns a DeviceCode
// that can be used with CheckForUserCompletion or WaitForUserCompletion.
// Deprecated: use InitiateDeviceAuthWithContext() instead.
func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
return InitiateDeviceAuthWithContext(context.Background(), sender, oauthConfig, clientID, resource)
}

// InitiateDeviceAuthWithContext initiates a device auth flow. It returns a DeviceCode
// that can be used with CheckForUserCompletion or WaitForUserCompletion.
func InitiateDeviceAuthWithContext(ctx context.Context, sender Sender, oauthConfig OAuthConfig, clientID, resource string) (*DeviceCode, error) {
v := url.Values{
"client_id": []string{clientID},
"resource": []string{resource},
Expand All @@ -117,7 +125,7 @@ func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resour

req.ContentLength = int64(len(s))
req.Header.Set(contentType, mimeTypeFormPost)
resp, err := sender.Do(req)
resp, err := sender.Do(req.WithContext(ctx))
if err != nil {
return nil, fmt.Errorf("%s %s: %s", logPrefix, errCodeSendingFails, err.Error())
}
Expand Down Expand Up @@ -151,7 +159,14 @@ func InitiateDeviceAuth(sender Sender, oauthConfig OAuthConfig, clientID, resour

// CheckForUserCompletion takes a DeviceCode and checks with the Azure AD OAuth endpoint
// to see if the device flow has: been completed, timed out, or otherwise failed
// Deprecated: use CheckForUserCompletionWithContext() instead.
func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
return CheckForUserCompletionWithContext(context.Background(), sender, code)
}

// CheckForUserCompletionWithContext takes a DeviceCode and checks with the Azure AD OAuth endpoint
// to see if the device flow has: been completed, timed out, or otherwise failed
func CheckForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) {
v := url.Values{
"client_id": []string{code.ClientID},
"code": []string{*code.DeviceCode},
Expand All @@ -169,7 +184,7 @@ func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {

req.ContentLength = int64(len(s))
req.Header.Set(contentType, mimeTypeFormPost)
resp, err := sender.Do(req)
resp, err := sender.Do(req.WithContext(ctx))
if err != nil {
return nil, fmt.Errorf("%s %s: %s", logPrefix, errTokenSendingFails, err.Error())
}
Expand Down Expand Up @@ -213,12 +228,19 @@ func CheckForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {

// WaitForUserCompletion calls CheckForUserCompletion repeatedly until a token is granted or an error state occurs.
// This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
// Deprecated: use WaitForUserCompletionWithContext() instead.
func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
return WaitForUserCompletionWithContext(context.Background(), sender, code)
}

// WaitForUserCompletionWithContext calls CheckForUserCompletion repeatedly until a token is granted or an error
// state occurs. This prevents the user from looping and checking against 'ErrDeviceAuthorizationPending'.
func WaitForUserCompletionWithContext(ctx context.Context, sender Sender, code *DeviceCode) (*Token, error) {
intervalDuration := time.Duration(*code.Interval) * time.Second
waitDuration := intervalDuration

for {
token, err := CheckForUserCompletion(sender, code)
token, err := CheckForUserCompletionWithContext(ctx, sender, code)

if err == nil {
return token, nil
Expand All @@ -237,6 +259,11 @@ func WaitForUserCompletion(sender Sender, code *DeviceCode) (*Token, error) {
return nil, fmt.Errorf("%s Error waiting for user to complete device flow. Server told us to slow_down too much", logPrefix)
}

time.Sleep(waitDuration)
select {
case <-time.After(waitDuration):
// noop
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
14 changes: 14 additions & 0 deletions autorest/adal/devicetoken_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ package adal
// limitations under the License.

import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"testing"
"time"

"github.com/Azure/go-autorest/autorest/mocks"
)
Expand Down Expand Up @@ -328,3 +330,15 @@ func TestDeviceTokenReturnsErrorIfTokenEmptyAndStatusOK(t *testing.T) {
t.Fatalf("response body was left open!")
}
}

func TestWaitForUserCompletionWithContext(t *testing.T) {
sender := SenderFunc(func(*http.Request) (*http.Response, error) {
return mocks.NewResponseWithContent(`{"error":"authorization_pending"}`), nil
})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
_, err := WaitForUserCompletionWithContext(ctx, sender, deviceCode())
if err != context.DeadlineExceeded {
t.Fatalf("adal: got wrong error expected(%s) actual(%s)", context.DeadlineExceeded.Error(), err.Error())
}
}
2 changes: 1 addition & 1 deletion autorest/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"runtime"
)

const number = "v13.1.0"
const number = "v13.2.0"

var (
userAgent = fmt.Sprintf("Go/%s (%s-%s) go-autorest/%s",
Expand Down
2 changes: 1 addition & 1 deletion azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ steps:
go get github.com/jstemmer/go-junit-report
go get github.com/axw/gocov/gocov
go get github.com/AlekSi/gocov-xml
go get -u gopkg.in/matm/v1/gocov-html
go get -u github.com/matm/gocov-html
workingDirectory: '$(sdkPath)'
displayName: 'Install Dependencies'
- script: |
Expand Down

0 comments on commit 740293c

Please sign in to comment.