Skip to content

Commit

Permalink
Merge pull request #5179 from mozilla/check-api-queries-mpp-3927
Browse files Browse the repository at this point in the history
MPP-3927: Pre-fetch profile for domain address
  • Loading branch information
jwhitlock authored Nov 12, 2024
2 parents 64e44e1 + d281b0e commit a0d7ffd
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 2 deletions.
74 changes: 73 additions & 1 deletion api/tests/emails_views_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pytest
from model_bakery import baker
from pytest_django.fixtures import SettingsWrapper
from pytest_django.fixtures import DjangoAssertNumQueries, SettingsWrapper
from rest_framework.exceptions import MethodNotAllowed, NotAuthenticated
from rest_framework.test import APIClient
from waffle.testutils import override_flag
Expand All @@ -21,6 +21,27 @@
)


@pytest.fixture
def settings_without_sqlcommenter(settings: SettingsWrapper) -> SettingsWrapper:
"""
Remove the sqlcommenter from the middleware.
For sqlite, it injects two queries into the recorded queries. The
first is a plain string, the second is the expected dictionary format.
This breaks the tests using django_assert_num_queries
First query: "SELECT id, ...
Second query: {"sql": "SELECT id, ..."}A
"""
try:
settings.MIDDLEWARE.remove(
"google.cloud.sqlcommenter.django.middleware.SqlCommenter"
)
except ValueError:
# sqlcommenter not available for Python 3.12 and later
pass
return settings


def test_post_domainaddress_success(
prem_api_client: APIClient, premium_user: User, caplog: pytest.LogCaptureFixture
) -> None:
Expand Down Expand Up @@ -387,6 +408,33 @@ def test_patch_domainaddress_addr_with_id_fails(
assert get_glean_event(caplog) is None


@pytest.mark.parametrize("address_count", (0, 1, 2, 5))
def test_get_domainaddress(
prem_api_client: APIClient,
premium_user: User,
django_assert_num_queries: DjangoAssertNumQueries,
settings_without_sqlcommenter: SettingsWrapper,
address_count: int,
) -> None:
"""
A GET request makes 1 request for no results, and 3 requests for any results.
"""
address_qs = DomainAddress.objects.filter(user=premium_user)
count = address_qs.count()
assert count <= address_count
while count < address_count:
DomainAddress.objects.create(user=premium_user, address=f"address-{count}")
count = address_qs.count()

url = reverse("domainaddress-list")
expected_queries = 3 if address_count else 1
with django_assert_num_queries(expected_queries):
response = prem_api_client.get(url)
data = response.json()
assert response.status_code == 200
assert len(data) == address_count


def test_delete_domainaddress(
prem_api_client: APIClient, premium_user: User, caplog: pytest.LogCaptureFixture
) -> None:
Expand Down Expand Up @@ -749,6 +797,30 @@ def test_delete_randomaddress(
assert event == expected_event


@pytest.mark.parametrize("address_count", (0, 1, 2))
def test_get_relayaddress(
free_api_client: APIClient,
free_user: User,
django_assert_num_queries: DjangoAssertNumQueries,
settings_without_sqlcommenter: SettingsWrapper,
address_count: int,
) -> None:
"""A GET request should make 1 query, no matter the address count."""
address_qs = RelayAddress.objects.filter(user=free_user)
count = address_qs.count()
assert count <= address_count
while count < address_count:
RelayAddress.objects.create(user=free_user)
count = address_qs.count()

url = reverse("relayaddress-list")
with django_assert_num_queries(1):
response = free_api_client.get(url)
data = response.json()
assert response.status_code == 200
assert len(data) == address_count


def test_first_forwarded_email_unauth(client: Client) -> None:
response = client.post("/api/v1/first-forwarded-email/")
assert response.status_code == 401
Expand Down
4 changes: 3 additions & 1 deletion api/views/emails.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ class DomainAddressViewSet(AddressViewSet[DomainAddress]):

def get_queryset(self) -> QuerySet[DomainAddress]:
if isinstance(self.request.user, User):
return DomainAddress.objects.filter(user=self.request.user)
return DomainAddress.objects.filter(
user=self.request.user
).prefetch_related("user", "user__profile")
return DomainAddress.objects.none()


Expand Down

0 comments on commit a0d7ffd

Please sign in to comment.