Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement rate limiting #1

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions bookmarks/views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from rest_framework import viewsets
from rest_framework.throttling import UserRateThrottle
from rest_framework.response import Response
from rest_framework import status
from django.db.models.query import prefetch_related_objects
from my_bookmarks.settings import get_throttle_time
from bookmarks.models import Bookmark
from bookmarks.serializers import BookmarkSerializer

Expand All @@ -8,6 +13,7 @@ class BookmarkViewSet(viewsets.ModelViewSet):
A simple ViewSet for viewing and editing the
bookmarks associated with the user.
"""
throttle_classes = [UserRateThrottle]
serializer_class = BookmarkSerializer

def perform_create(self, serializer):
Expand All @@ -16,3 +22,49 @@ def perform_create(self, serializer):
def get_queryset(self):
user = self.request.user
return Bookmark.objects.filter(user=user)

def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
headers['X-Rate-Limit-Limit'] = get_throttle_time()
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)

def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())

page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)

serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data, headers={'X-Rate-Limit-Limit': get_throttle_time()})

def retrieve(self, request, *args, **kwargs):
instance = self.get_object()
serializer = self.get_serializer(instance)
return Response(serializer.data, headers={'X-Rate-Limit-Limit': get_throttle_time()})

def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)

queryset = self.filter_queryset(self.get_queryset())
if queryset._prefetch_related_lookups:
# If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance,
# and then re-prefetch related objects
instance._prefetched_objects_cache = {}
prefetch_related_objects([instance], *queryset._prefetch_related_lookups)

return Response(serializer.data, headers={'X-Rate-Limit-Limit': get_throttle_time()})

def destroy(self, request, *args, **kwargs):
instance = self.get_object()
self.perform_destroy(instance)
return Response(status=status.HTTP_204_NO_CONTENT, headers={'X-Rate-Limit-Limit': get_throttle_time()})
76 changes: 64 additions & 12 deletions folders/views.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
from rest_framework import viewsets
from rest_framework.decorators import api_view
from rest_framework.decorators import api_view, throttle_classes
from rest_framework.throttling import UserRateThrottle
from rest_framework.response import Response
from rest_framework import status
from django.db.models.query import prefetch_related_objects
from my_bookmarks.settings import get_throttle_time
from folders.models import Folder
from bookmarks.models import Bookmark
from folders.serializers import (
FolderSerializer,
ValidateInputFolderSerializer
)

# Create your views here.
class FolderViewSet(viewsets.ModelViewSet):
"""
A simple ViewSet for viewing and editing the
folders associated with the user.
"""
throttle_classes = [UserRateThrottle]
serializer_class = FolderSerializer

def perform_create(self, serializer):
Expand All @@ -24,8 +28,55 @@ def get_queryset(self):
user = self.request.user
return Folder.objects.filter(user=user)

def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
headers['X-Rate-Limit-Limit'] = get_throttle_time()
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)

def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())

page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)

serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data, headers={'X-Rate-Limit-Limit': get_throttle_time()})

def retrieve(self, request, *args, **kwargs):
instance = self.get_object()
serializer = self.get_serializer(instance)
return Response(serializer.data, headers={'X-Rate-Limit-Limit': get_throttle_time()})

def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)

queryset = self.filter_queryset(self.get_queryset())
if queryset._prefetch_related_lookups:
# If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance,
# and then re-prefetch related objects
instance._prefetched_objects_cache = {}
prefetch_related_objects([instance], *queryset._prefetch_related_lookups)

return Response(serializer.data, headers={'X-Rate-Limit-Limit': get_throttle_time()})

def destroy(self, request, *args, **kwargs):
instance = self.get_object()
self.perform_destroy(instance)
return Response(status=status.HTTP_204_NO_CONTENT, headers={'X-Rate-Limit-Limit': get_throttle_time()})


@api_view(['POST'])
@throttle_classes([UserRateThrottle])
def add_bookmark_to_folder(request):
context = {'request':request}
serializer = ValidateInputFolderSerializer(data=request.data, context=context)
Expand All @@ -36,7 +87,7 @@ def add_bookmark_to_folder(request):

bookmark = Bookmark.objects.filter(user=request.user,id=bookmark_id).first()
if not bookmark:
return Response({"detail": "Bookmark not found."}, status=status.HTTP_404_NOT_FOUND)
return Response({"detail": "Bookmark not found."}, status=status.HTTP_404_NOT_FOUND, headers={'X-Rate-Limit-Limit': get_throttle_time()})

folder = Folder.objects.filter(user=request.user,name=folder_name).first()
if folder:
Expand All @@ -45,15 +96,16 @@ def add_bookmark_to_folder(request):
else:
folder.bookmark.add(bookmark)
serializer = FolderSerializer(folder, many=False)
return Response(serializer.data, status=status.HTTP_200_OK)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers={'X-Rate-Limit-Limit': get_throttle_time()})
else:
return Response({"detail": "Folder not found."}, status=status.HTTP_404_NOT_FOUND)
return Response({"detail": "Folder not found."}, status=status.HTTP_404_NOT_FOUND, headers={'X-Rate-Limit-Limit': get_throttle_time()})
else:
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers={'X-Rate-Limit-Limit': get_throttle_time()})
except Exception as e:
return Response({"detail": e}, status=status.HTTP_400_BAD_REQUEST)
return Response({"detail": e}, status=status.HTTP_400_BAD_REQUEST, headers={'X-Rate-Limit-Limit': get_throttle_time()})

@api_view(['DELETE'])
@throttle_classes([UserRateThrottle])
def remove_bookmark_from_folder(request):
context = {'request':request}
serializer = ValidateInputFolderSerializer(data=request.data, context=context)
Expand All @@ -64,15 +116,15 @@ def remove_bookmark_from_folder(request):

bookmark = Bookmark.objects.filter(user=request.user,id=bookmark_id).first()
if not bookmark:
return Response({"detail": "Bookmark not found."}, status=status.HTTP_404_NOT_FOUND)
return Response({"detail": "Bookmark not found."}, status=status.HTTP_404_NOT_FOUND, headers={'X-Rate-Limit-Limit': get_throttle_time()})

folder = Folder.objects.filter(user=request.user,name=folder_name).first()
if folder:
folder.bookmark.remove(bookmark)
return Response(status=status.HTTP_200_OK)
return Response(status=status.HTTP_200_OK, headers={'X-Rate-Limit-Limit': get_throttle_time()})
else:
return Response({"detail": "Folder not found."}, status=status.HTTP_404_NOT_FOUND)
return Response({"detail": "Folder not found."}, status=status.HTTP_404_NOT_FOUND, headers={'X-Rate-Limit-Limit': get_throttle_time()})
else:
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers={'X-Rate-Limit-Limit': get_throttle_time()})
except Exception as e:
return Response({"detail": e}, status=status.HTTP_400_BAD_REQUEST)
return Response({"detail": e}, status=status.HTTP_400_BAD_REQUEST, headers={'X-Rate-Limit-Limit': get_throttle_time()})
25 changes: 24 additions & 1 deletion my_bookmarks/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,31 @@
],
'DEFAULT_AUTHENTICATION_CLASSES': (
'rest_framework_simplejwt.authentication.JWTAuthentication',
)
),
'DEFAULT_THROTTLE_CLASSES': [
'rest_framework.throttling.AnonRateThrottle',
'rest_framework.throttling.UserRateThrottle'
],
'DEFAULT_THROTTLE_RATES': {
'anon': '2/min',
'user': '4/min'
}
}

def get_throttle_time(path=REST_FRAMEWORK['DEFAULT_THROTTLE_RATES']['user']):
"""
get the time for a throttle class and convert it to seconds
"""
path = path.split('/')
if path[1] == 'sec' or path[1] == 's':
return path[0]
elif path[1] == 'min' or path[1] == 'm':
return str(int(path[0])*60)
elif path[1] == 'hour' or path[1] == 'h':
return str(int(path[0])*60*60)
elif path[1] == 'day' or path[1] == 'd':
return str(int(path[0])*60*60*24)

SIMPLE_JWT = {
'AUTH_HEADER_TYPES': ('JWT',),
'ACCESS_TOKEN_LIFETIME': timedelta(minutes=60),
Expand Down
3 changes: 3 additions & 0 deletions users/views.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
from rest_framework import generics
from rest_framework.permissions import AllowAny
from rest_framework.throttling import AnonRateThrottle
from djoser import signals
from djoser.conf import settings
from djoser.compat import get_user_email
from templated_mail.mail import BaseEmailMessage
from users.models import User
from users.serializers import UserSerializer
from my_bookmarks.settings import get_throttle_time

# Create your views here.

class ConfirmationEmail(BaseEmailMessage):
template_name = "confirmation_email.html"

class UserCreate(generics.CreateAPIView):
throttle_classes = [AnonRateThrottle]
queryset = User.objects.all()
serializer_class = UserSerializer
permission_classes = [AllowAny]
Expand Down