Skip to content

Commit 99467e8

Browse files
Fixes #12731: Support custom validation for many-to-many fields (#14516)
* WIP * Enforce custom validators during bulk edit * Add bulk edit M2M validation test * Clean up tests * Add custom validation test for bulk import * Misc cleanup
1 parent 0d08205 commit 99467e8

File tree

5 files changed

+314
-9
lines changed

5 files changed

+314
-9
lines changed
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
from django.test import TestCase
2+
from django.test import override_settings
3+
4+
from circuits.api.serializers import ProviderSerializer
5+
from circuits.forms import ProviderForm
6+
from circuits.models import Provider
7+
from ipam.models import ASN, RIR
8+
from utilities.choices import CSVDelimiterChoices, ImportFormatChoices
9+
from utilities.testing import APITestCase, ModelViewTestCase, create_tags, post_data
10+
11+
12+
class ModelFormCustomValidationTest(TestCase):
13+
14+
@override_settings(CUSTOM_VALIDATORS={
15+
'circuits.provider': [
16+
{'tags': {'required': True}}
17+
]
18+
})
19+
def test_tags_validation(self):
20+
"""
21+
Check that custom validation rules work for tag assignment.
22+
"""
23+
data = {
24+
'name': 'Provider 1',
25+
'slug': 'provider-1',
26+
}
27+
form = ProviderForm(data)
28+
self.assertFalse(form.is_valid())
29+
30+
tags = create_tags('Tag1', 'Tag2', 'Tag3')
31+
data['tags'] = [tag.pk for tag in tags]
32+
form = ProviderForm(data)
33+
self.assertTrue(form.is_valid())
34+
35+
@override_settings(CUSTOM_VALIDATORS={
36+
'circuits.provider': [
37+
{'asns': {'required': True}}
38+
]
39+
})
40+
def test_m2m_validation(self):
41+
"""
42+
Check that custom validation rules work for many-to-many fields.
43+
"""
44+
data = {
45+
'name': 'Provider 1',
46+
'slug': 'provider-1',
47+
}
48+
form = ProviderForm(data)
49+
self.assertFalse(form.is_valid())
50+
51+
rir = RIR.objects.create(name='RIR 1', slug='rir-1')
52+
asns = ASN.objects.bulk_create((
53+
ASN(rir=rir, asn=65001),
54+
ASN(rir=rir, asn=65002),
55+
ASN(rir=rir, asn=65003),
56+
))
57+
data['asns'] = [asn.pk for asn in asns]
58+
form = ProviderForm(data)
59+
self.assertTrue(form.is_valid())
60+
61+
62+
class BulkEditCustomValidationTest(ModelViewTestCase):
63+
model = Provider
64+
65+
@classmethod
66+
def setUpTestData(cls):
67+
rir = RIR.objects.create(name='RIR 1', slug='rir-1')
68+
asns = ASN.objects.bulk_create((
69+
ASN(rir=rir, asn=65001),
70+
ASN(rir=rir, asn=65002),
71+
ASN(rir=rir, asn=65003),
72+
))
73+
74+
providers = (
75+
Provider(name='Provider 1', slug='provider-1'),
76+
Provider(name='Provider 2', slug='provider-2'),
77+
Provider(name='Provider 3', slug='provider-3'),
78+
)
79+
Provider.objects.bulk_create(providers)
80+
for provider in providers:
81+
provider.asns.set(asns)
82+
83+
@override_settings(CUSTOM_VALIDATORS={
84+
'circuits.provider': [
85+
{'asns': {'required': True}}
86+
]
87+
})
88+
def test_bulk_edit_without_m2m(self):
89+
"""
90+
Check that custom validation rules do not interfere with bulk editing.
91+
"""
92+
data = {
93+
'pk': list(Provider.objects.values_list('pk', flat=True)),
94+
'_apply': '',
95+
'description': 'New description',
96+
}
97+
self.add_permissions(
98+
'circuits.view_provider',
99+
'circuits.change_provider',
100+
)
101+
102+
# Bulk edit the description without changing ASN assignments
103+
request = {
104+
'path': self._get_url('bulk_edit'),
105+
'data': post_data(data),
106+
}
107+
response = self.client.post(**request)
108+
self.assertHttpStatus(response, 302)
109+
self.assertEqual(
110+
Provider.objects.filter(description=data['description']).count(),
111+
len(data['pk'])
112+
)
113+
114+
@override_settings(CUSTOM_VALIDATORS={
115+
'circuits.provider': [
116+
{'asns': {'required': True}}
117+
]
118+
})
119+
def test_bulk_edit_m2m(self):
120+
"""
121+
Test that custom validation rules are enforced during bulk editing.
122+
"""
123+
data = {
124+
'pk': list(Provider.objects.values_list('pk', flat=True)),
125+
'_apply': '',
126+
'description': 'New description',
127+
}
128+
self.add_permissions(
129+
'circuits.view_provider',
130+
'circuits.change_provider',
131+
'ipam.view_asn',
132+
)
133+
134+
# Change the ASN assignments
135+
asn = ASN.objects.first()
136+
data['asns'] = [asn.pk]
137+
request = {
138+
'path': self._get_url('bulk_edit'),
139+
'data': post_data(data),
140+
}
141+
response = self.client.post(**request)
142+
self.assertHttpStatus(response, 302)
143+
for provider in Provider.objects.all():
144+
self.assertEqual(len(provider.asns.all()), 1)
145+
146+
# Attempt to remove the ASN assignments
147+
data.pop('asns')
148+
data['_nullify'] = 'asns'
149+
request = {
150+
'path': self._get_url('bulk_edit'),
151+
'data': post_data(data),
152+
}
153+
response = self.client.post(**request)
154+
self.assertHttpStatus(response, 200)
155+
for provider in Provider.objects.all():
156+
self.assertTrue(provider.asns.exists())
157+
158+
159+
class BulkImportCustomValidationTest(ModelViewTestCase):
160+
model = Provider
161+
162+
@classmethod
163+
def setUpTestData(cls):
164+
create_tags('Tag1', 'Tag2', 'Tag3')
165+
166+
@override_settings(CUSTOM_VALIDATORS={
167+
'circuits.provider': [
168+
{'tags': {'required': True}}
169+
]
170+
})
171+
def test_bulk_import_invalid(self):
172+
"""
173+
Test that custom validation rules are enforced during bulk import.
174+
"""
175+
csv_data = (
176+
"name,slug",
177+
"Provider 1,provider-1",
178+
"Provider 2,provider-2",
179+
"Provider 3,provider-3",
180+
)
181+
data = {
182+
'data': '\n'.join(csv_data),
183+
'format': ImportFormatChoices.CSV,
184+
'csv_delimiter': CSVDelimiterChoices.COMMA,
185+
}
186+
self.add_permissions(
187+
'circuits.view_provider',
188+
'circuits.add_provider',
189+
'extras.view_tag',
190+
)
191+
192+
# Attempt to import providers without tags
193+
request = {
194+
'path': self._get_url('import'),
195+
'data': post_data(data),
196+
}
197+
response = self.client.post(**request)
198+
self.assertHttpStatus(response, 200)
199+
self.assertFalse(Provider.objects.exists())
200+
201+
# Import providers successfully with tag assignments
202+
csv_data = (
203+
"name,slug,tags",
204+
"Provider 1,provider-1,tag1",
205+
"Provider 2,provider-2,tag2",
206+
"Provider 3,provider-3,tag3",
207+
)
208+
data['data'] = '\n'.join(csv_data)
209+
request = {
210+
'path': self._get_url('import'),
211+
'data': post_data(data),
212+
}
213+
response = self.client.post(**request)
214+
self.assertHttpStatus(response, 302)
215+
self.assertTrue(Provider.objects.exists())
216+
217+
218+
class APISerializerCustomValidationTest(APITestCase):
219+
220+
@override_settings(CUSTOM_VALIDATORS={
221+
'circuits.provider': [
222+
{'tags': {'required': True}}
223+
]
224+
})
225+
def test_tags_validation(self):
226+
"""
227+
Check that custom validation rules work for tag assignment.
228+
"""
229+
data = {
230+
'name': 'Provider 1',
231+
'slug': 'provider-1',
232+
}
233+
serializer = ProviderSerializer(data=data)
234+
self.assertFalse(serializer.is_valid())
235+
236+
tags = create_tags('Tag1', 'Tag2', 'Tag3')
237+
data['tags'] = [tag.pk for tag in tags]
238+
serializer = ProviderSerializer(data=data)
239+
self.assertTrue(serializer.is_valid())
240+
241+
@override_settings(CUSTOM_VALIDATORS={
242+
'circuits.provider': [
243+
{'asns': {'required': True}}
244+
]
245+
})
246+
def test_m2m_validation(self):
247+
"""
248+
Check that custom validation rules work for many-to-many fields.
249+
"""
250+
data = {
251+
'name': 'Provider 1',
252+
'slug': 'provider-1',
253+
}
254+
serializer = ProviderSerializer(data=data)
255+
self.assertFalse(serializer.is_valid())
256+
257+
rir = RIR.objects.create(name='RIR 1', slug='rir-1')
258+
asns = ASN.objects.bulk_create((
259+
ASN(rir=rir, asn=65001),
260+
ASN(rir=rir, asn=65002),
261+
ASN(rir=rir, asn=65003),
262+
))
263+
data['asns'] = [asn.pk for asn in asns]
264+
serializer = ProviderSerializer(data=data)
265+
self.assertTrue(serializer.is_valid())

netbox/extras/validators.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from django.core.exceptions import ValidationError
21
from django.core import validators
2+
from django.core.exceptions import ValidationError
3+
from django.utils.translation import gettext_lazy as _
34

45
# NOTE: As this module may be imported by configuration.py, we cannot import
56
# anything from NetBox itself.
@@ -66,8 +67,7 @@ def __init__(self, validation_rules=None):
6667
def __call__(self, instance):
6768
# Validate instance attributes per validation rules
6869
for attr_name, rules in self.validation_rules.items():
69-
assert hasattr(instance, attr_name), f"Invalid attribute '{attr_name}' for {instance.__class__.__name__}"
70-
attr = getattr(instance, attr_name)
70+
attr = self._getattr(instance, attr_name)
7171
for descriptor, value in rules.items():
7272
validator = self.get_validator(descriptor, value)
7373
try:
@@ -79,6 +79,26 @@ def __call__(self, instance):
7979
# Execute custom validation logic (if any)
8080
self.validate(instance)
8181

82+
@staticmethod
83+
def _getattr(instance, name):
84+
# Attempt to resolve many-to-many fields to their stored values
85+
m2m_fields = [f.name for f in instance._meta.local_many_to_many]
86+
if name in m2m_fields:
87+
if name in getattr(instance, '_m2m_values', []):
88+
return instance._m2m_values[name]
89+
if instance.pk:
90+
return list(getattr(instance, name).all())
91+
return []
92+
93+
# Raise a ValidationError for unknown attributes
94+
if not hasattr(instance, name):
95+
raise ValidationError(_('Invalid attribute "{name}" for {model}').format(
96+
name=name,
97+
model=instance.__class__.__name__
98+
))
99+
100+
return getattr(instance, name)
101+
82102
def get_validator(self, descriptor, value):
83103
"""
84104
Instantiate and return the appropriate validator based on the descriptor given. For

netbox/netbox/api/serializers/base.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,16 @@ class ValidatedModelSerializer(BaseModelSerializer):
2323
validation. (DRF does not do this by default; see https://github.com/encode/django-rest-framework/issues/3144)
2424
"""
2525
def validate(self, data):
26-
27-
# Remove custom fields data and tags (if any) prior to model validation
2826
attrs = data.copy()
27+
28+
# Remove custom field data (if any) prior to model validation
2929
attrs.pop('custom_fields', None)
30-
attrs.pop('tags', None)
3130

3231
# Skip ManyToManyFields
33-
for field in self.Meta.model._meta.get_fields():
34-
if isinstance(field, ManyToManyField):
35-
attrs.pop(field.name, None)
32+
m2m_values = {}
33+
for field in self.Meta.model._meta.local_many_to_many:
34+
if field.name in attrs:
35+
m2m_values[field.name] = attrs.pop(field.name)
3636

3737
# Run clean() on an instance of the model
3838
if self.instance is None:
@@ -41,6 +41,7 @@ def validate(self, data):
4141
instance = self.instance
4242
for k, v in attrs.items():
4343
setattr(instance, k, v)
44+
instance._m2m_values = m2m_values
4445
instance.full_clean()
4546

4647
return data

netbox/netbox/forms/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,17 @@ def clean(self):
5757

5858
return super().clean()
5959

60+
def _post_clean(self):
61+
"""
62+
Override BaseModelForm's _post_clean() to store many-to-many field values on the model instance.
63+
"""
64+
self.instance._m2m_values = {}
65+
for field in self.instance._meta.local_many_to_many:
66+
if field.name in self.cleaned_data:
67+
self.instance._m2m_values[field.name] = list(self.cleaned_data[field.name])
68+
69+
return super()._post_clean()
70+
6071

6172
class NetBoxModelImportForm(CSVModelForm, NetBoxModelForm):
6273
"""

netbox/netbox/views/generic/bulk_views.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,14 @@ def _update_objects(self, form, request):
557557
elif name in form.changed_data:
558558
obj.custom_field_data[cf_name] = customfield.serialize(form.cleaned_data[name])
559559

560+
# Store M2M values for validation
561+
obj._m2m_values = {}
562+
for field in obj._meta.local_many_to_many:
563+
if value := form.cleaned_data.get(field.name):
564+
obj._m2m_values[field.name] = list(value)
565+
elif field.name in nullified_fields:
566+
obj._m2m_values[field.name] = []
567+
560568
obj.full_clean()
561569
obj.save()
562570
updated_objects.append(obj)

0 commit comments

Comments
 (0)