Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test throttle rates #1679

Draft
wants to merge 3 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions backend/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from addcorpus.models import Corpus
from addcorpus.serializers import CorpusJSONDefinitionSerializer
from es.models import Server
from rest_framework.test import APIClient


@pytest.fixture(autouse=True)
def media_dir(tmpdir, settings):
Expand Down Expand Up @@ -236,3 +238,19 @@ def json_mock_corpus(db, json_corpus_definition) -> Corpus:
corpus.configuration.save()

return corpus


@pytest.fixture
def drf_client():
return APIClient()


@pytest.fixture
def throttle_settings(settings):
settings.REST_FRAMEWORK.update({
'DEFAULT_THROTTLE_RATES': {
'password': '2/minute',
'registration': '2/minute',
}
})
return settings
47 changes: 47 additions & 0 deletions backend/users/tests/test_registration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from django.core import mail
import random
import re
import string

from django.core.cache import cache
from django.urls import reverse
from allauth.account.models import EmailAddress
from rest_framework import status


def test_register_verification(client, db, django_user_model):
Expand Down Expand Up @@ -46,3 +52,44 @@ def test_register_verification(client, db, django_user_model):
assert allauth_email.email == creds.get('email')
assert allauth_email.verified is True
assert allauth_email.primary is True


def test_register_throttling(client, throttle_settings):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The throttle rate from the fixture is not actually applied, although its value is read correctly in the test function

Perhaps an obvious suggestion, but did you try this?

Suggested change
def test_register_throttling(client, throttle_settings):
def test_register_throttling(throttle_settings, client):

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've now tried this, but the result is the same.

"""
Test that the ThrottledRegisterView returns a 429 error
after exceeding the allowed number of registration attempts.
"""
cache.clear() # Clear cache to reset rest_registration count
# client = drf_client
# Check throttle rate settings are applied
registration_rate = throttle_settings.REST_FRAMEWORK.get(
'DEFAULT_THROTTLE_RATES', {}).get('registration')
assert registration_rate == '2/minute', \
f"Expected registration throttle rate to be '2/minute', but got '{registration_rate}'."
assert throttle_settings.CACHES['default']['BACKEND'] == 'django.core.cache.backends.locmem.LocMemCache', \
f"Expected 'django.core.cache.backends.locmem.LocMemCache' for default cache backend, got {throttle_settings.CACHES['default']['BACKEND']}"

url = reverse('rest_register')

def generate_user_data():
"""Generate unique user data."""
random_str = ''.join(random.choices(string.digits, k=4))
return {
'username': f'testuser{random_str}',
'password1': 'Testpass123!',
'password2': 'Testpass123!',
'email': f'testuser{random_str}@example.com'
}

# This should use registration_rate + 1, but the rate we get from the fixture
# is not being applied for the actual throttling, it uses the rate from common_settings.
for i in range(1,7):
data = generate_user_data()
response = client.post(url, data, format='json')
# print(f"Request {i} status: {response.status_code}")

if i == 6:
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS, \
f"Expected 429, got {response.status_code}"
response_data = response.json()
assert 'detail' in response_data, "Response does not contain 'detail' key"