diff --git a/netbox/dcim/filtersets.py b/netbox/dcim/filtersets.py index 6517aadb45b..5a101e739b7 100644 --- a/netbox/dcim/filtersets.py +++ b/netbox/dcim/filtersets.py @@ -271,7 +271,7 @@ class LocationFilterSet(TenancyFilterSet, ContactModelFilterSet, OrganizationalM class Meta: model = Location - fields = ('id', 'name', 'slug', 'status', 'facility', 'description') + fields = ('id', 'name', 'slug', 'facility', 'description') def search(self, queryset, name, value): if not value.strip(): diff --git a/netbox/netbox/graphql/filter_mixins.py b/netbox/netbox/graphql/filter_mixins.py index 65c7ffcef5a..2044a1ddeeb 100644 --- a/netbox/netbox/graphql/filter_mixins.py +++ b/netbox/netbox/graphql/filter_mixins.py @@ -1,11 +1,12 @@ -from functools import partial, partialmethod, wraps +from functools import partialmethod from typing import List import django_filters import strawberry import strawberry_django -from django.core.exceptions import FieldDoesNotExist, ValidationError +from django.core.exceptions import FieldDoesNotExist from strawberry import auto + from ipam.fields import ASNField from netbox.graphql.scalars import BigInt from utilities.fields import ColorField, CounterCacheField @@ -108,8 +109,7 @@ def map_strawberry_type(field): elif issubclass(type(field), django_filters.TypedMultipleChoiceFilter): pass elif issubclass(type(field), django_filters.MultipleChoiceFilter): - should_create_function = True - attr_type = List[str] | None + attr_type = str | None elif issubclass(type(field), django_filters.TypedChoiceFilter): pass elif issubclass(type(field), django_filters.ChoiceFilter): diff --git a/netbox/netbox/tests/test_graphql.py b/netbox/netbox/tests/test_graphql.py index 34ea3ad6a20..b04d42d2447 100644 --- a/netbox/netbox/tests/test_graphql.py +++ b/netbox/netbox/tests/test_graphql.py @@ -5,8 +5,8 @@ from rest_framework import status from core.models import ObjectType +from dcim.choices import LocationStatusChoices from dcim.models import Site, Location -from ipam.models import ASN, RIR from users.models import ObjectPermission from utilities.testing import disable_warnings, APITestCase, TestCase @@ -53,10 +53,27 @@ def test_graphql_filter_objects(self): sites = ( Site(name='Site 1', slug='site-1'), Site(name='Site 2', slug='site-2'), + Site(name='Site 3', slug='site-3'), ) Site.objects.bulk_create(sites) - Location.objects.create(site=sites[0], name='Location 1', slug='location-1'), - Location.objects.create(site=sites[1], name='Location 2', slug='location-2'), + Location.objects.create( + site=sites[0], + name='Location 1', + slug='location-1', + status=LocationStatusChoices.STATUS_PLANNED + ), + Location.objects.create( + site=sites[1], + name='Location 2', + slug='location-2', + status=LocationStatusChoices.STATUS_STAGING + ), + Location.objects.create( + site=sites[1], + name='Location 3', + slug='location-3', + status=LocationStatusChoices.STATUS_ACTIVE + ), # Add object-level permission obj_perm = ObjectPermission( @@ -68,8 +85,9 @@ def test_graphql_filter_objects(self): obj_perm.object_types.add(ObjectType.objects.get_for_model(Location)) obj_perm.object_types.add(ObjectType.objects.get_for_model(Site)) - # A valid request should return the filtered list url = reverse('graphql') + + # A valid request should return the filtered list query = '{location_list(filters: {site_id: "' + str(sites[0].pk) + '"}) {id site {id}}}' response = self.client.post(url, data={'query': query}, format="json", **self.header) self.assertHttpStatus(response, status.HTTP_200_OK) @@ -78,6 +96,21 @@ def test_graphql_filter_objects(self): self.assertEqual(len(data['data']['location_list']), 1) self.assertIsNotNone(data['data']['location_list'][0]['site']) + # Test OR logic + query = """{ + location_list( filters: { + status: \"""" + LocationStatusChoices.STATUS_PLANNED + """\", + OR: {status: \"""" + LocationStatusChoices.STATUS_STAGING + """\"} + }) { + id site {id} + } + }""" + response = self.client.post(url, data={'query': query}, format="json", **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = json.loads(response.content) + self.assertNotIn('errors', data) + self.assertEqual(len(data['data']['location_list']), 2) + # An invalid request should return an empty list query = '{location_list(filters: {site_id: "99999"}) {id site {id}}}' # Invalid site ID response = self.client.post(url, data={'query': query}, format="json", **self.header)