Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #16024: Change attr_type from list to str for MultipleChoiceFilter #17638

Merged
merged 1 commit into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion netbox/dcim/filtersets.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ class LocationFilterSet(TenancyFilterSet, ContactModelFilterSet, OrganizationalM

class Meta:
model = Location
fields = ('id', 'name', 'slug', 'status', 'facility', 'description')
fields = ('id', 'name', 'slug', 'facility', 'description')

def search(self, queryset, name, value):
if not value.strip():
Expand Down
8 changes: 4 additions & 4 deletions netbox/netbox/graphql/filter_mixins.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from functools import partial, partialmethod, wraps
from functools import partialmethod
from typing import List

import django_filters
import strawberry
import strawberry_django
from django.core.exceptions import FieldDoesNotExist, ValidationError
from django.core.exceptions import FieldDoesNotExist
from strawberry import auto

from ipam.fields import ASNField
from netbox.graphql.scalars import BigInt
from utilities.fields import ColorField, CounterCacheField
Expand Down Expand Up @@ -108,8 +109,7 @@ def map_strawberry_type(field):
elif issubclass(type(field), django_filters.TypedMultipleChoiceFilter):
pass
elif issubclass(type(field), django_filters.MultipleChoiceFilter):
should_create_function = True
attr_type = List[str] | None
attr_type = str | None
elif issubclass(type(field), django_filters.TypedChoiceFilter):
pass
elif issubclass(type(field), django_filters.ChoiceFilter):
Expand Down
41 changes: 37 additions & 4 deletions netbox/netbox/tests/test_graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from rest_framework import status

from core.models import ObjectType
from dcim.choices import LocationStatusChoices
from dcim.models import Site, Location
from ipam.models import ASN, RIR
from users.models import ObjectPermission
from utilities.testing import disable_warnings, APITestCase, TestCase

Expand Down Expand Up @@ -53,10 +53,27 @@ def test_graphql_filter_objects(self):
sites = (
Site(name='Site 1', slug='site-1'),
Site(name='Site 2', slug='site-2'),
Site(name='Site 3', slug='site-3'),
)
Site.objects.bulk_create(sites)
Location.objects.create(site=sites[0], name='Location 1', slug='location-1'),
Location.objects.create(site=sites[1], name='Location 2', slug='location-2'),
Location.objects.create(
site=sites[0],
name='Location 1',
slug='location-1',
status=LocationStatusChoices.STATUS_PLANNED
),
Location.objects.create(
site=sites[1],
name='Location 2',
slug='location-2',
status=LocationStatusChoices.STATUS_STAGING
),
Location.objects.create(
site=sites[1],
name='Location 3',
slug='location-3',
status=LocationStatusChoices.STATUS_ACTIVE
),

# Add object-level permission
obj_perm = ObjectPermission(
Expand All @@ -68,8 +85,9 @@ def test_graphql_filter_objects(self):
obj_perm.object_types.add(ObjectType.objects.get_for_model(Location))
obj_perm.object_types.add(ObjectType.objects.get_for_model(Site))

# A valid request should return the filtered list
url = reverse('graphql')

# A valid request should return the filtered list
query = '{location_list(filters: {site_id: "' + str(sites[0].pk) + '"}) {id site {id}}}'
response = self.client.post(url, data={'query': query}, format="json", **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
Expand All @@ -78,6 +96,21 @@ def test_graphql_filter_objects(self):
self.assertEqual(len(data['data']['location_list']), 1)
self.assertIsNotNone(data['data']['location_list'][0]['site'])

# Test OR logic
query = """{
location_list( filters: {
status: \"""" + LocationStatusChoices.STATUS_PLANNED + """\",
OR: {status: \"""" + LocationStatusChoices.STATUS_STAGING + """\"}
}) {
id site {id}
}
}"""
response = self.client.post(url, data={'query': query}, format="json", **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = json.loads(response.content)
self.assertNotIn('errors', data)
self.assertEqual(len(data['data']['location_list']), 2)

# An invalid request should return an empty list
query = '{location_list(filters: {site_id: "99999"}) {id site {id}}}' # Invalid site ID
response = self.client.post(url, data={'query': query}, format="json", **self.header)
Expand Down