Skip to content

Commit

Permalink
add 2fa auth
Browse files Browse the repository at this point in the history
  • Loading branch information
eugapx committed Sep 6, 2023
1 parent b36ccc1 commit 072a731
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 8 deletions.
22 changes: 19 additions & 3 deletions df_auth/drf/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from django.conf import settings
from django.contrib.auth import authenticate, get_user_model
from django.contrib.auth.models import AbstractUser, update_last_login
from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.models import update_last_login
from django.db.models import Model
from django.utils.module_loading import import_string
from django_otp.models import Device
Expand All @@ -18,7 +19,11 @@

from ..settings import api_settings
from ..strategy import DRFStrategy
from ..utils import get_otp_device_choices, get_otp_device_models
from ..utils import (
get_otp_device_choices,
get_otp_device_models,
get_otp_devices,
)

User = get_user_model()

Expand Down Expand Up @@ -58,7 +63,7 @@ class TokenSerializer(serializers.Serializer):

class TokenCreateSerializer(TokenSerializer):
@classmethod
def get_token(cls, user: AbstractUser) -> None:
def get_token(cls, user: AbstractBaseUser) -> None:
return cls.token_class.for_user(user)

def validate(self, attrs: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -82,6 +87,17 @@ def validate(self, attrs: Dict[str, Any]) -> Dict[str, Any]:
"""
attrs = {k: v for k, v in attrs.items() if v}
self.user = authenticate(**attrs, **self.context) # type: ignore

if self.user and getattr(self.user, "is_2fa_enabled", False):
devices = [d for d in get_otp_devices(self.user) if d.confirmed]
otp = attrs.get("otp")

if not any(d.verify_token(otp) for d in devices):
raise exceptions.AuthenticationFailed(
"2FA is required for this user.",
code="2fa_required",
)

return super().validate(attrs)

def get_fields(self) -> Dict[str, serializers.Field]:
Expand Down
7 changes: 2 additions & 5 deletions df_auth/drf/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..exceptions import DfAuthValidationError, WrongOTPError
from ..permissions import IsUnauthenticated
from ..settings import api_settings
from ..utils import get_otp_device_models
from ..utils import get_otp_device_models, get_otp_devices
from .serializers import (
OTPDeviceConfirmSerializer,
OTPDeviceSerializer,
Expand Down Expand Up @@ -103,10 +103,7 @@ class OtpDeviceViewSet(
permission_classes = (permissions.IsAuthenticated,)

def get_queryset(self) -> List[Device]:
devices = []
for DeviceModel in get_otp_device_models().values():
devices.extend(DeviceModel.objects.filter(user=self.request.user))
return devices
return get_otp_devices(self.request.user)

def get_device_model(self) -> Type[Device]:
device_type = self.request.GET.get("type")
Expand Down
8 changes: 8 additions & 0 deletions df_auth/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Dict, List, Tuple, Type

from django.contrib.auth.base_user import AbstractBaseUser
from django.utils.module_loading import import_string
from django_otp.models import Device

Expand All @@ -15,3 +16,10 @@ def get_otp_device_models() -> Dict[str, Type[Device]]:

def get_otp_device_choices() -> List[Tuple[str, str]]:
return [(type_, type_) for type_ in api_settings.OTP_DEVICE_MODELS]


def get_otp_devices(user: AbstractBaseUser) -> List[Device]:
devices = []
for DeviceModel in get_otp_device_models().values():
devices.extend(DeviceModel.objects.filter(user=user))
return devices
38 changes: 38 additions & 0 deletions tests/test_app/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,41 @@ def test_obtain_token_by_username_and_password(self) -> None:
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotEqual(response.data.get("token", ""), "")


class TokenViewSet2FAAPITest(APITestCase):
def setUp(self) -> None:
self.client = APIClient()
self.user = User.objects.create_user(
username="testuser",
password="testpass",
email="test@te.st",
is_2fa_enabled=True,
)
self.device = EmailDevice.objects.create(
user=self.user, name=self.user.email, confirmed=True, email=self.user.email
)

def test_user_with_2fa_cannot_authorize_without_otp(self) -> None:
response = self.client.post(
reverse("df_api_drf:v1:auth:token-list"),
{
"username": self.user.username,
"password": "testpass",
},
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.data["errors"][0]["code"], "2fa_required")

def test_user_with_2fa_can_authorize_with_otp(self) -> None:
self.device.generate_challenge()
response = self.client.post(
reverse("df_api_drf:v1:auth:token-list"),
{
"username": self.user.username,
"password": "testpass",
"otp": self.device.token,
},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotEqual(response.data.get("token", ""), "")

0 comments on commit 072a731

Please sign in to comment.