Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion oauth2_provider/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@
class AllowForm(forms.Form):
allow = forms.BooleanField(required=False)
redirect_uri = forms.CharField(widget=forms.HiddenInput())
scopes = forms.CharField(required=False, widget=forms.HiddenInput())
scope = forms.CharField(required=False, widget=forms.HiddenInput())
client_id = forms.CharField(widget=forms.HiddenInput())
state = forms.CharField(required=False, widget=forms.HiddenInput())
response_type = forms.CharField(widget=forms.HiddenInput())

def __init__(self, *args, **kwargs):
data = kwargs.get('data')
# backwards compatible support for plural `scopes` query parameter
if data and 'scopes' in data:
data['scope'] = data['scopes']
return super(AllowForm, self).__init__(*args, **kwargs)


class RegistrationForm(forms.ModelForm):
"""
Expand Down
14 changes: 7 additions & 7 deletions oauth2_provider/tests/test_authorization_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_pre_auth_valid_client(self):
form = response.context["form"]
self.assertEqual(form['redirect_uri'].value(), "http://example.it")
self.assertEqual(form['state'].value(), "random_state_string")
self.assertEqual(form['scopes'].value(), "read write")
self.assertEqual(form['scope'].value(), "read write")
self.assertEqual(form['client_id'].value(), self.application.client_id)

def test_pre_auth_approval_prompt(self):
Expand Down Expand Up @@ -177,7 +177,7 @@ def test_code_post_auth_allow(self):
form_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'read write',
'scope': 'read write',
'redirect_uri': 'http://example.it',
'response_type': 'code',
'allow': True,
Expand All @@ -198,7 +198,7 @@ def test_code_post_auth_deny(self):
form_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'read write',
'scope': 'read write',
'redirect_uri': 'http://example.it',
'response_type': 'code',
'allow': False,
Expand All @@ -217,7 +217,7 @@ def test_code_post_auth_bad_responsetype(self):
form_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'read write',
'scope': 'read write',
'redirect_uri': 'http://example.it',
'response_type': 'UNKNOWN',
'allow': True,
Expand All @@ -236,7 +236,7 @@ def test_code_post_auth_forbidden_redirect_uri(self):
form_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'read write',
'scope': 'read write',
'redirect_uri': 'http://forbidden.it',
'response_type': 'code',
'allow': True,
Expand All @@ -254,7 +254,7 @@ def get_auth(self):
authcode_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'read write',
'scope': 'read write',
'redirect_uri': 'http://example.it',
'response_type': 'code',
'allow': True,
Expand Down Expand Up @@ -559,7 +559,7 @@ def test_resource_access_allowed(self):
authcode_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'read write',
'scope': 'read write',
'redirect_uri': 'http://example.it',
'response_type': 'code',
'allow': True,
Expand Down
8 changes: 4 additions & 4 deletions oauth2_provider/tests/test_implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_pre_auth_valid_client(self):
form = response.context["form"]
self.assertEqual(form['redirect_uri'].value(), "http://example.it")
self.assertEqual(form['state'].value(), "random_state_string")
self.assertEqual(form['scopes'].value(), "read write")
self.assertEqual(form['scope'].value(), "read write")
self.assertEqual(form['client_id'].value(), self.application.client_id)

def test_pre_auth_invalid_client(self):
Expand Down Expand Up @@ -128,7 +128,7 @@ def test_post_auth_allow(self):
form_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'read write',
'scope': 'read write',
'redirect_uri': 'http://example.it',
'response_type': 'token',
'allow': True,
Expand All @@ -149,7 +149,7 @@ def test_token_post_auth_deny(self):
form_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'read write',
'scope': 'read write',
'redirect_uri': 'http://example.it',
'response_type': 'token',
'allow': False,
Expand All @@ -168,7 +168,7 @@ def test_resource_access_allowed(self):
authcode_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'read write',
'scope': 'read write',
'redirect_uri': 'http://example.it',
'response_type': 'token',
'allow': True,
Expand Down
70 changes: 62 additions & 8 deletions oauth2_provider/tests/test_scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from django.core.urlresolvers import reverse

from .test_utils import TestCaseUtils
from ..compat import urlparse, parse_qs, get_user_model
from ..compat import urlparse, parse_qs, get_user_model, urlencode
from ..models import get_application_model, Grant, AccessToken
from ..settings import oauth2_settings
from ..views import ScopedProtectedResourceView, ReadWriteScopedResourceView
Expand Down Expand Up @@ -64,6 +64,60 @@ def tearDown(self):
self.dev_user.delete()


class TestScopesQueryParameterBackwardsCompatibility(BaseTest):
def setUp(self):
super(TestScopesQueryParameterBackwardsCompatibility, self).setUp()
oauth2_settings._SCOPES = ['read', 'write']

def test_scopes_query_parameter_is_supported_on_post(self):
"""
Tests support for plural `scopes` query parameter on POST requests.

"""
self.client.login(username="test_user", password="123456")

# retrieve a valid authorization code
authcode_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'read write', # using plural `scopes`
'redirect_uri': 'http://example.it',
'response_type': 'code',
'allow': True,
}
response = self.client.post(reverse('oauth2_provider:authorize'), data=authcode_data)
query_dict = parse_qs(urlparse(response['Location']).query)
authorization_code = query_dict['code'].pop()

grant = Grant.objects.get(code=authorization_code)
self.assertEqual(grant.scope, "read write")

def test_scopes_query_parameter_is_supported_on_get(self):
"""
Tests support for plural `scopes` query parameter on GET requests.

"""
self.client.login(username="test_user", password="123456")

query_string = urlencode({
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'read write', # using plural `scopes`
'redirect_uri': 'http://example.it',
'response_type': 'code',
})
url = "{url}?{qs}".format(url=reverse('oauth2_provider:authorize'), qs=query_string)

response = self.client.get(url)
self.assertEqual(response.status_code, 200)

# check form is in context
self.assertIn("form", response.context)

form = response.context["form"]
self.assertEqual(form['scope'].value(), "read write")


class TestScopesSave(BaseTest):
def test_scopes_saved_in_grant(self):
"""
Expand All @@ -75,7 +129,7 @@ def test_scopes_saved_in_grant(self):
authcode_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'scope1 scope2',
'scope': 'scope1 scope2',
'redirect_uri': 'http://example.it',
'response_type': 'code',
'allow': True,
Expand All @@ -97,7 +151,7 @@ def test_scopes_save_in_access_token(self):
authcode_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'scope1 scope2',
'scope': 'scope1 scope2',
'redirect_uri': 'http://example.it',
'response_type': 'code',
'allow': True,
Expand Down Expand Up @@ -133,7 +187,7 @@ def test_scopes_protection_valid(self):
authcode_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'scope1 scope2',
'scope': 'scope1 scope2',
'redirect_uri': 'http://example.it',
'response_type': 'code',
'allow': True,
Expand Down Expand Up @@ -175,7 +229,7 @@ def test_scopes_protection_fail(self):
authcode_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'scope2',
'scope': 'scope2',
'redirect_uri': 'http://example.it',
'response_type': 'code',
'allow': True,
Expand Down Expand Up @@ -217,7 +271,7 @@ def test_multi_scope_fail(self):
authcode_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'scope1 scope3',
'scope': 'scope1 scope3',
'redirect_uri': 'http://example.it',
'response_type': 'code',
'allow': True,
Expand Down Expand Up @@ -259,7 +313,7 @@ def test_multi_scope_valid(self):
authcode_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': 'scope1 scope2',
'scope': 'scope1 scope2',
'redirect_uri': 'http://example.it',
'response_type': 'code',
'allow': True,
Expand Down Expand Up @@ -300,7 +354,7 @@ def get_access_token(self, scopes):
authcode_data = {
'client_id': self.application.client_id,
'state': 'random_state_string',
'scopes': scopes,
'scope': scopes,
'redirect_uri': 'http://example.it',
'response_type': 'code',
'allow': True,
Expand Down
6 changes: 3 additions & 3 deletions oauth2_provider/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ class AuthorizationView(BaseAuthorizationView, FormView):

def get_initial(self):
# TODO: move this scopes conversion from and to string into a utils function
scopes = self.oauth2_data.get('scopes', [])
scopes = self.oauth2_data.get('scope', self.oauth2_data.get('scopes', []))
initial_data = {
'redirect_uri': self.oauth2_data.get('redirect_uri', None),
'scopes': ' '.join(scopes),
'scope': ' '.join(scopes),
'client_id': self.oauth2_data.get('client_id', None),
'state': self.oauth2_data.get('state', None),
'response_type': self.oauth2_data.get('response_type', None),
Expand All @@ -90,7 +90,7 @@ def form_valid(self, form):
'state': form.cleaned_data.get('state', None),
}

scopes = form.cleaned_data.get('scopes')
scopes = form.cleaned_data.get('scope')
allow = form.cleaned_data.get('allow')
uri, headers, body, status = self.create_authorization_response(
request=self.request, scopes=scopes, credentials=credentials, allow=allow)
Expand Down