Skip to content
165 changes: 116 additions & 49 deletions netbox/dcim/models/cables.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from utilities.exceptions import AbortRequest
from utilities.fields import ColorField, GenericArrayForeignKey
from utilities.querysets import RestrictedQuerySet
from utilities.serialization import deserialize_object, serialize_object
from wireless.models import WirelessLink
from .device_components import FrontPort, RearPort, PathEndpoint

Expand Down Expand Up @@ -119,43 +120,61 @@ def __str__(self):
pk = self.pk or self._pk
return self.label or f'#{pk}'

@property
def a_terminations(self):
if hasattr(self, '_a_terminations'):
return self._a_terminations
def get_status_color(self):
return LinkStatusChoices.colors.get(self.status)

def _get_x_terminations(self, side):
"""
Return the terminating objects for the given cable end (A or B).
"""
if side not in (CableEndChoices.SIDE_A, CableEndChoices.SIDE_B):
raise ValueError(f"Unknown cable side: {side}")
attr = f'_{side.lower()}_terminations'

if hasattr(self, attr):
return getattr(self, attr)
if not self.pk:
return []

# Query self.terminations.all() to leverage cached results
return [
ct.termination for ct in self.terminations.all() if ct.cable_end == CableEndChoices.SIDE_A
# Query self.terminations.all() to leverage cached results
ct.termination for ct in self.terminations.all() if ct.cable_end == side
]

@a_terminations.setter
def a_terminations(self, value):
if not self.pk or self.a_terminations != list(value):
def _set_x_terminations(self, side, value):
"""
Set the terminating objects for the given cable end (A or B).
"""
if side not in (CableEndChoices.SIDE_A, CableEndChoices.SIDE_B):
raise ValueError(f"Unknown cable side: {side}")
_attr = f'_{side.lower()}_terminations'

# If the provided value is a list of CableTermination IDs, resolve them
# to their corresponding termination objects.
if all(isinstance(item, int) for item in value):
value = [
ct.termination for ct in CableTermination.objects.filter(pk__in=value).prefetch_related('termination')
]

if not self.pk or getattr(self, _attr, []) != list(value):
self._terminations_modified = True
self._a_terminations = value

setattr(self, _attr, value)

@property
def b_terminations(self):
if hasattr(self, '_b_terminations'):
return self._b_terminations
def a_terminations(self):
return self._get_x_terminations(CableEndChoices.SIDE_A)

if not self.pk:
return []
@a_terminations.setter
def a_terminations(self, value):
self._set_x_terminations(CableEndChoices.SIDE_A, value)

# Query self.terminations.all() to leverage cached results
return [
ct.termination for ct in self.terminations.all() if ct.cable_end == CableEndChoices.SIDE_B
]
@property
def b_terminations(self):
return self._get_x_terminations(CableEndChoices.SIDE_B)

@b_terminations.setter
def b_terminations(self, value):
if not self.pk or self.b_terminations != list(value):
self._terminations_modified = True
self._b_terminations = value
self._set_x_terminations(CableEndChoices.SIDE_B, value)

@property
def color_name(self):
Expand Down Expand Up @@ -208,7 +227,7 @@ def clean(self):
for termination in self.b_terminations:
CableTermination(cable=self, cable_end='B', termination=termination).clean()

def save(self, *args, **kwargs):
def save(self, *args, force_insert=False, force_update=False, using=None, update_fields=None):
_created = self.pk is None

# Store the given length (if any) in meters for use in database ordering
Expand All @@ -221,39 +240,87 @@ def save(self, *args, **kwargs):
if self.length is None:
self.length_unit = None

super().save(*args, **kwargs)
# If this is a new Cable, save it before attempting to create its CableTerminations
if self._state.adding:
super().save(*args, force_insert=True, using=using, update_fields=update_fields)
# Update the private PK used in __str__()
self._pk = self.pk

# Update the private pk used in __str__ in case this is a new object (i.e. just got its pk)
self._pk = self.pk
if self._terminations_modified:
self.update_terminations()

# Retrieve existing A/B terminations for the Cable
a_terminations = {ct.termination: ct for ct in self.terminations.filter(cable_end='A')}
b_terminations = {ct.termination: ct for ct in self.terminations.filter(cable_end='B')}
super().save(*args, force_update=True, using=using, update_fields=update_fields)

# Delete stale CableTerminations
if self._terminations_modified:
for termination, ct in a_terminations.items():
if termination.pk and termination not in self.a_terminations:
ct.delete()
for termination, ct in b_terminations.items():
if termination.pk and termination not in self.b_terminations:
ct.delete()

# Save new CableTerminations (if any)
if self._terminations_modified:
for termination in self.a_terminations:
if not termination.pk or termination not in a_terminations:
CableTermination(cable=self, cable_end='A', termination=termination).save()
for termination in self.b_terminations:
if not termination.pk or termination not in b_terminations:
CableTermination(cable=self, cable_end='B', termination=termination).save()
try:
trace_paths.send(Cable, instance=self, created=_created)
except UnsupportedCablePath as e:
raise AbortRequest(e)

def get_status_color(self):
return LinkStatusChoices.colors.get(self.status)
def serialize_object(self, exclude=None):
data = serialize_object(self, exclude=exclude or [])

# Add A & B terminations to the serialized data
a_terminations, b_terminations = self.get_terminations()
data['a_terminations'] = sorted([ct.pk for ct in a_terminations.values()])
data['b_terminations'] = sorted([ct.pk for ct in b_terminations.values()])

return data

@classmethod
def deserialize_object(cls, data, pk=None):
a_terminations = data.pop('a_terminations', [])
b_terminations = data.pop('b_terminations', [])

instance = deserialize_object(cls, data, pk=pk)

# Assign A & B termination objects to the Cable instance
queryset = CableTermination.objects.prefetch_related('termination')
instance.a_terminations = [
ct.termination for ct in queryset.filter(pk__in=a_terminations)
]
instance.b_terminations = [
ct.termination for ct in queryset.filter(pk__in=b_terminations)
]

return instance

def get_terminations(self):
"""
Return two dictionaries mapping A & B side terminating objects to their corresponding CableTerminations
for this Cable.
"""
a_terminations = {}
b_terminations = {}

for ct in CableTermination.objects.filter(cable=self).prefetch_related('termination'):
if ct.cable_end == CableEndChoices.SIDE_A:
a_terminations[ct.termination] = ct
else:
b_terminations[ct.termination] = ct

return a_terminations, b_terminations

def update_terminations(self):
"""
Create/delete CableTerminations for this Cable to reflect its current state.
"""
a_terminations, b_terminations = self.get_terminations()

# Delete any stale CableTerminations
for termination, ct in a_terminations.items():
if termination.pk and termination not in self.a_terminations:
ct.delete()
for termination, ct in b_terminations.items():
if termination.pk and termination not in self.b_terminations:
ct.delete()

# Save any new CableTerminations
for termination in self.a_terminations:
if not termination.pk or termination not in a_terminations:
CableTermination(cable=self, cable_end='A', termination=termination).save()
for termination in self.b_terminations:
if not termination.pk or termination not in b_terminations:
CableTermination(cable=self, cable_end='B', termination=termination).save()


class CableTermination(ChangeLoggedModel):
Expand Down
8 changes: 4 additions & 4 deletions netbox/utilities/testing/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,9 @@ def test_create_object(self):
if issubclass(self.model, ChangeLoggingMixin):
objectchange = ObjectChange.objects.get(
changed_object_type=ContentType.objects.get_for_model(instance),
changed_object_id=instance.pk
changed_object_id=instance.pk,
action=ObjectChangeActionChoices.ACTION_CREATE,
)
self.assertEqual(objectchange.action, ObjectChangeActionChoices.ACTION_CREATE)
self.assertEqual(objectchange.message, data['changelog_message'])

def test_bulk_create_objects(self):
Expand Down Expand Up @@ -298,11 +298,11 @@ def test_bulk_create_objects(self):
]
objectchanges = ObjectChange.objects.filter(
changed_object_type=ContentType.objects.get_for_model(self.model),
changed_object_id__in=id_list
changed_object_id__in=id_list,
action=ObjectChangeActionChoices.ACTION_CREATE,
)
self.assertEqual(len(objectchanges), len(self.create_data))
for oc in objectchanges:
self.assertEqual(oc.action, ObjectChangeActionChoices.ACTION_CREATE)
self.assertEqual(oc.message, changelog_message)

class UpdateObjectViewTestCase(APITestCase):
Expand Down
4 changes: 2 additions & 2 deletions netbox/utilities/testing/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,11 +655,11 @@ def test_bulk_import_objects_with_permission(self):
self.assertIsNotNone(request_id, "Unable to determine request ID from response")
objectchanges = ObjectChange.objects.filter(
changed_object_type=ContentType.objects.get_for_model(self.model),
request_id=request_id
request_id=request_id,
action=ObjectChangeActionChoices.ACTION_CREATE,
)
self.assertEqual(len(objectchanges), len(self.csv_data) - 1)
for oc in objectchanges:
self.assertEqual(oc.action, ObjectChangeActionChoices.ACTION_CREATE)
self.assertEqual(oc.message, data['changelog_message'])

@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
Expand Down