Skip to content

Add account limit exception #43

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 34 additions & 56 deletions src/shipchain_common/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from datetime import timedelta
from django.conf import settings
from django.core.cache import cache
from django.utils.functional import cached_property
from rest_framework.exceptions import AuthenticationFailed
from rest_framework.permissions import BasePermission
from rest_framework_simplejwt.authentication import JWTTokenUserAuthentication
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.models import TokenUser
from rest_framework_simplejwt.utils import aware_utcnow

from .utils import parse_dn

Expand Down Expand Up @@ -89,43 +91,18 @@ def set_password(self, raw_password):
def check_password(self, raw_password):
raise NotImplementedError('Token users have no DB representation')

def _get_permission_cache_key(self):
"""
Build a unique cache key for this specific JWT.
If no `jti`, `at_hash`, or `sub` and `exp`, then return None
"""
unique_key = self.token.get('jti')

if not unique_key:
unique_key = self.token.get('at_hash')

if not unique_key:
sub = self.token.get("sub")
exp = self.token.get("exp")
if sub and exp:
unique_key = f'{sub}.{exp}'

return unique_key

def _get_permission_cache_life(self):
"""
Determine cache life from JWT. If exp or iat are not present, or
if calculation results in an invalid life, return the fallback_life
"""
fallback_life = 300

exp = self.token.get("exp")
iat = self.token.get("iat")
@cached_property
def _permissions(self):
features = self.token.get('features')
if not features:
return []

if not exp or not iat:
return fallback_life
permissions = []
for feature in features:
for permission in features[feature]:
permissions.append(f'{feature}.{permission}')

life = exp - iat

if not life or life <= 0:
return fallback_life

return life
return permissions

def get_all_permissions(self, obj=None):
"""
Expand All @@ -134,26 +111,16 @@ def get_all_permissions(self, obj=None):
This prevents re-parsing the permissions over the lifetime of this token
as they will not change until a new token is received
"""
permissions = None
unique_key = self._get_permission_cache_key()

if unique_key:
permissions = cache.get(unique_key)

if not permissions:
features = self.token.get('features')
if not features:
return []
try:
# Check token expiration, invalidate cache if expired
self.token.check_exp(current_time=aware_utcnow() + timedelta(seconds=30))
except TokenError:
try:
del self._permissions
except AttributeError:
pass

permissions = []
for feature in features:
for permission in features[feature]:
permissions.append(f'{feature}.{permission}')

if unique_key:
cache.set(unique_key, permissions, self._get_permission_cache_life())

return permissions
return self._permissions

def has_perm(self, perm, obj=None):
"""
Expand All @@ -166,3 +133,14 @@ def has_perms(self, perm_list, obj=None):
Validate perm_list is in token feature permissions
"""
return all(self.has_perm(perm, obj) for perm in perm_list)

@property
def limits(self):
return self.token.get('limits', {})

def get_limit(self, entity, name):
limit = None
entity = self.limits.get(entity)
if entity:
limit = entity.get(name)
return limit
6 changes: 6 additions & 0 deletions src/shipchain_common/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,9 @@ class URLShortenerError(Custom500Error):
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
default_detail = 'URL Shortener Error.'
default_code = 'server_error'


class AccountLimitReached(APIException):
status_code = status.HTTP_402_PAYMENT_REQUIRED
default_detail = 'Request denied due to the restrictions of your current billing tier.'
default_code = 'account_limit_reached'
5 changes: 5 additions & 0 deletions src/shipchain_common/test_utils/json_asserter.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,10 @@ def assert_401(response, error='Authentication credentials were not provided', v
assert response.status_code == status.HTTP_401_UNAUTHORIZED, f'status_code {response.status_code} != 401'
response_has_error(response, error, vnd=vnd)

def assert_402(response, error='Request denied due to the restrictions of your current billing tier.', vnd=True):
assert response is not None
assert response.status_code == status.HTTP_402_PAYMENT_REQUIRED, f'status_code {response.status_code} != 402'
response_has_error(response, error, vnd=vnd)

def assert_403(response, error='You do not have permission to perform this action', vnd=True):
assert response is not None
Expand Down Expand Up @@ -459,6 +463,7 @@ class AssertionHelper:

HTTP_400 = assert_400
HTTP_401 = assert_401
HTTP_402 = assert_402
HTTP_403 = assert_403
HTTP_404 = assert_404
HTTP_405 = assert_405
Expand Down
48 changes: 0 additions & 48 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,54 +152,6 @@ def test_lambda_auth_requires_header(lambda_request):
assert lambda_request.has_permission(request, {})


def test_token_user_jti_cache_key():
"""By default, the jti is included in get_jwt and is used as cache key"""
jwt = get_jwt()
token = UntypedToken(jwt)
token_user = PermissionedTokenUser(token)
assert token_user._get_permission_cache_key() == token_user.token.get('jti')


def test_token_user_at_hash_cache_key():
"""If no jti is included in get_jwt then use at_hash as cache key if exists"""
jwt = get_jwt(jti=0, at_hash=uuid4().hex)
token = UntypedToken(jwt)
token_user = PermissionedTokenUser(token)
assert token_user._get_permission_cache_key() == token_user.token.get('at_hash')


def test_token_user_sub_exp_cache_key():
"""If no jti or at_hash is included in get_jwt then use {sub}.{exp} as cache key"""
jwt = get_jwt(jti=0, sub=uuid4().hex)
token = UntypedToken(jwt)
token_user = PermissionedTokenUser(token)
assert token_user._get_permission_cache_key() == f'{token_user.token.get("sub")}.{token_user.token.get("exp")}'


def test_token_user_cache_life():
jwt = get_jwt()
token = UntypedToken(jwt)
token_user = PermissionedTokenUser(token)
assert token_user._get_permission_cache_life() == 300


def test_token_user_cache_calculated_life():
iat = datetime_to_epoch(aware_utcnow())
jwt = get_jwt(exp=iat+15, iat=iat)
token = UntypedToken(jwt)
token_user = PermissionedTokenUser(token)
assert token_user._get_permission_cache_life() == 15


def test_token_user_cache_fallback_life():
iat = datetime_to_epoch(aware_utcnow())
jwt = get_jwt(exp=iat+15, iat=iat)
token = UntypedToken(jwt)
token.payload['iat'] = None
token_user = PermissionedTokenUser(token)
assert token_user._get_permission_cache_life() == 300


@pytest.fixture
def one_feature():
"""Returns feature object response in token, and list of feature permissions"""
Expand Down