Skip to content

Commit

Permalink
Add check in serailizer instead of blacklist mixin.
Browse files Browse the repository at this point in the history
  • Loading branch information
ajay09 committed Jun 12, 2024
1 parent 293040e commit d698cd0
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
14 changes: 13 additions & 1 deletion rest_framework_simplejwt/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from django.contrib.auth.models import AbstractBaseUser, update_last_login
from django.utils.translation import gettext_lazy as _
from rest_framework import exceptions, serializers
from rest_framework.exceptions import ValidationError
from rest_framework.exceptions import ValidationError, AuthenticationFailed

from .models import TokenUser
from .settings import api_settings
Expand Down Expand Up @@ -104,9 +104,21 @@ class TokenRefreshSerializer(serializers.Serializer):
access = serializers.CharField(read_only=True)
token_class = RefreshToken

default_error_messages = {
"no_active_account": _("No active account found with the given credentials")
}

def validate(self, attrs: Dict[str, Any]) -> Dict[str, str]:
refresh = self.token_class(attrs["refresh"])

user_id = refresh.payload.get(api_settings.USER_ID_CLAIM, None)
if user_id and (user := get_user_model().objects.get(**{api_settings.USER_ID_FIELD: user_id})):
if not api_settings.USER_AUTHENTICATION_RULE(user):
raise AuthenticationFailed(
self.error_messages["no_active_account"],
"no_active_account",
)

data = {"access": str(refresh.access_token)}

if api_settings.ROTATE_REFRESH_TOKENS:
Expand Down
12 changes: 0 additions & 12 deletions rest_framework_simplejwt/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,20 +247,8 @@ class BlacklistMixin(Generic[T]):
def verify(self, *args, **kwargs) -> None:
self.check_blacklist()

self.check_user_active()

super().verify(*args, **kwargs) # type: ignore

def check_user_active(self):
user_id = self.payload.get(api_settings.USER_ID_CLAIM, None)
if (
user_id
and not get_user_model()
.objects.get(**{api_settings.USER_ID_FIELD: user_id})
.is_active
):
raise TokenError(_("User is inactive"))

def check_blacklist(self) -> None:
"""
Checks if this token is present in the token blacklist. Raises
Expand Down

0 comments on commit d698cd0

Please sign in to comment.