From 70b5ccf3577021572e67a7c456ec6a3f6fd54b60 Mon Sep 17 00:00:00 2001 From: slingshotvfx <146885925+slingshotvfx@users.noreply.github.com> Date: Fri, 9 Feb 2024 14:46:44 -0800 Subject: [PATCH] add multi_entity_update_modes support to mockgun --- shotgun_api3/lib/mockgun/mockgun.py | 28 ++++++-- tests/test_mockgun.py | 100 ++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 5 deletions(-) diff --git a/shotgun_api3/lib/mockgun/mockgun.py b/shotgun_api3/lib/mockgun/mockgun.py index 72665bba..4bf1e4bf 100644 --- a/shotgun_api3/lib/mockgun/mockgun.py +++ b/shotgun_api3/lib/mockgun/mockgun.py @@ -335,7 +335,12 @@ def batch(self, requests): results.append(self.create(request["entity_type"], request["data"])) elif request["request_type"] == "update": # note: Shotgun.update returns a list of a single item - results.append(self.update(request["entity_type"], request["entity_id"], request["data"])[0]) + results.append( + self.update(request["entity_type"], + request["entity_id"], + request["data"], + request.get("multi_entity_update_modes"))[0] + ) elif request["request_type"] == "delete": results.append(self.delete(request["entity_type"], request["entity_id"])) else: @@ -387,13 +392,13 @@ def create(self, entity_type, data, return_fields=None): return result - def update(self, entity_type, entity_id, data): + def update(self, entity_type, entity_id, data, multi_entity_update_modes=None): self._validate_entity_type(entity_type) self._validate_entity_data(entity_type, data) self._validate_entity_exists(entity_type, entity_id) row = self._db[entity_type][entity_id] - self._update_row(entity_type, row, data) + self._update_row(entity_type, row, data, multi_entity_update_modes) return [dict((field, item) for field, item in row.items() if field in data or field in ("type", "id"))] @@ -818,13 +823,26 @@ def _row_matches_filters(self, entity_type, row, filters, filter_operator, retir else: raise ShotgunError("%s is not a valid filter operator" % filter_operator) - def _update_row(self, entity_type, row, data): + def _update_row(self, entity_type, row, data, multi_entity_update_modes=None): for field in data: field_type = self._get_field_type(entity_type, field) if field_type == "entity" and data[field]: row[field] = {"type": data[field]["type"], "id": data[field]["id"]} elif field_type == "multi_entity": - row[field] = [{"type": item["type"], "id": item["id"]} for item in data[field]] + update_mode = multi_entity_update_modes.get(field, "set") if multi_entity_update_modes else "set" + + if update_mode == "add": + row[field] += [{"type": item["type"], "id": item["id"]} for item in data[field]] + elif update_mode == "remove": + row[field] = [ + item + for item in row[field] + for new_item in data[field] + if item["id"] != new_item["id"] + or item["type"] != new_item["type"] + ] + elif update_mode == "set": + row[field] = [{"type": item["type"], "id": item["id"]} for item in data[field]] else: row[field] = data[field] diff --git a/tests/test_mockgun.py b/tests/test_mockgun.py index 08976d2a..84e5cb2e 100644 --- a/tests/test_mockgun.py +++ b/tests/test_mockgun.py @@ -270,6 +270,106 @@ def test_find_with_none(self): for item in items: self.assertTrue(len(item["users"]) > 0) + +class TestMultiEntityFieldUpdate(unittest.TestCase): + """ + Ensures multi entity field update modes work. + """ + + def setUp(self): + """ + Creates test data. + """ + + self._mockgun = Mockgun("https://test.shotgunstudio.com", login="user", password="1234") + + # Create two versions to assign to the shot. + self._version1 = self._mockgun.create("Version", {"code": "version1"}) + self._version2 = self._mockgun.create("Version", {"code": "version2"}) + self._version3 = self._mockgun.create("Version", {"code": "version3"}) + + # remove 'code' field for later comparisons + del self._version1["code"] + del self._version2["code"] + del self._version3["code"] + + # Create playlists + self._add_playlist = self._mockgun.create( + "Playlist", + {"code": "playlist1", "versions": [self._version1, self._version2]} + ) + self._remove_playlist = self._mockgun.create( + "Playlist", + {"code": "playlist1", "versions": [self._version1, self._version2, self._version3]} + ) + self._set_playlist = self._mockgun.create( + "Playlist", + {"code": "playlist1", "versions": [self._version1, self._version2]} + ) + + def test_update_add(self): + """ + Ensures that "add" multi_entity_update_mode works. + """ + self._mockgun.update( + "Playlist", self._add_playlist["id"], {"versions": [self._version3]}, + multi_entity_update_modes={"versions": "add"} + ) + + playlist = self._mockgun.find_one( + "Playlist", [["id", "is", self._add_playlist["id"]]], ["versions"] + ) + self.assertEqual( + playlist["versions"], [self._version1, self._version2, self._version3] + ) + + def test_update_remove(self): + """ + Ensures that "remove" multi_entity_update_mode works. + """ + self._mockgun.update( + "Playlist", self._remove_playlist["id"], {"versions": [self._version2]}, + multi_entity_update_modes={"versions": "remove"} + ) + + playlist = self._mockgun.find_one( + "Playlist", [["id", "is", self._remove_playlist["id"]]], ["versions"] + ) + self.assertEqual(playlist["versions"], [self._version1, self._version3]) + + def test_update_set(self): + """ + Ensures that "set" multi_entity_update_mode works. + """ + self._mockgun.update( + "Playlist", + self._set_playlist["id"], + {"versions": [self._version2, self._version3]}, + multi_entity_update_modes={"versions": "set"} + ) + + playlist = self._mockgun.find_one( + "Playlist", [["id", "is", self._set_playlist["id"]]], ["versions"] + ) + self.assertEqual(playlist["versions"], [self._version2, self._version3]) + + def test_batch_update(self): + self._mockgun.batch( + [ + { + "request_type": "update", + "entity_type": "Playlist", + "entity_id": self._set_playlist["id"], + "data": {"versions": [self._version1, self._version2]}, + "multi_entity_update_modes": {"versions": "set"} + } + ] + ) + playlist = self._mockgun.find_one( + "Playlist", [["id", "is", self._set_playlist["id"]]], ["versions"] + ) + self.assertEqual(playlist["versions"], [self._version1, self._version2]) + class TestFilterOperator(unittest.TestCase): """