From 9dfd9cfc28645c799f8361bea0047230dd45868d Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:11:41 +0100 Subject: [PATCH 01/18] [core] Move `nodeFactory` to its own module --- meshroom/core/graph.py | 3 +- meshroom/core/node.py | 110 --------------------------------- meshroom/core/nodeFactory.py | 116 +++++++++++++++++++++++++++++++++++ meshroom/ui/commands.py | 3 +- 4 files changed, 120 insertions(+), 112 deletions(-) create mode 100644 meshroom/core/nodeFactory.py diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index e63aceca1a..0a672b388c 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -16,7 +16,8 @@ from meshroom.core import Version from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit -from meshroom.core.node import nodeFactory, Status, Node, CompatibilityNode +from meshroom.core.node import Status, Node, CompatibilityNode +from meshroom.core.nodeFactory import nodeFactory # Replace default encoder to support Enums diff --git a/meshroom/core/node.py b/meshroom/core/node.py index 1b8806e2e4..a4c7f76ddf 100644 --- a/meshroom/core/node.py +++ b/meshroom/core/node.py @@ -1851,113 +1851,3 @@ def upgrade(self): canUpgrade = Property(bool, canUpgrade.fget, constant=True) issueDetails = Property(str, issueDetails.fget, constant=True) - -def nodeFactory(nodeDict, name=None, template=False, uidConflict=False): - """ - Create a node instance by deserializing the given node data. - If the serialized data matches the corresponding node type description, a Node instance is created. - If any compatibility issue occurs, a NodeCompatibility instance is created instead. - - Args: - nodeDict (dict): the serialization of the node - name (str): (optional) the node's name - template (bool): (optional) true if the node is part of a template, false otherwise - uidConflict (bool): (optional) true if a UID conflict has been detected externally on that node - - Returns: - BaseNode: the created node - """ - nodeType = nodeDict["nodeType"] - - # Retro-compatibility: inputs were previously saved as "attributes" - if "inputs" not in nodeDict and "attributes" in nodeDict: - nodeDict["inputs"] = nodeDict["attributes"] - del nodeDict["attributes"] - - # Get node inputs/outputs - inputs = nodeDict.get("inputs", {}) - internalInputs = nodeDict.get("internalInputs", {}) - outputs = nodeDict.get("outputs", {}) - version = nodeDict.get("version", None) - internalFolder = nodeDict.get("internalFolder", None) - position = Position(*nodeDict.get("position", [])) - uid = nodeDict.get("uid", None) - - compatibilityIssue = None - - nodeDesc = None - try: - nodeDesc = meshroom.core.nodesDesc[nodeType] - except KeyError: - # Unknown node type - compatibilityIssue = CompatibilityIssue.UnknownNodeType - - # Unknown node type should take precedence over UID conflict, as it cannot be resolved - if uidConflict and nodeDesc: - compatibilityIssue = CompatibilityIssue.UidConflict - - if nodeDesc and not uidConflict: # if uidConflict, there is no need to look for another compatibility issue - # Compare serialized node version with current node version - currentNodeVersion = meshroom.core.nodeVersion(nodeDesc) - # If both versions are available, check for incompatibility in major version - if version and currentNodeVersion and Version(version).major != Version(currentNodeVersion).major: - compatibilityIssue = CompatibilityIssue.VersionConflict - # In other cases, check attributes compatibility between serialized node and its description - else: - # Check that the node has the exact same set of inputs/outputs as its description, except - # if the node is described in a template file, in which only non-default parameters are saved; - # do not perform that check for internal attributes because there is no point in - # raising compatibility issues if their number differs: in that case, it is only useful - # if some internal attributes do not exist or are invalid - if not template and (sorted([attr.name for attr in nodeDesc.inputs - if not isinstance(attr, desc.PushButtonParam)]) != sorted(inputs.keys()) or - sorted([attr.name for attr in nodeDesc.outputs if not attr.isDynamicValue]) != - sorted(outputs.keys())): - compatibilityIssue = CompatibilityIssue.DescriptionConflict - - # Check whether there are any internal attributes that are invalidating in the node description: if there - # are, then check that these internal attributes are part of nodeDict; if they are not, a compatibility - # issue must be raised to warn the user, as this will automatically change the node's UID - if not template: - invalidatingIntInputs = [] - for attr in nodeDesc.internalInputs: - if attr.invalidate: - invalidatingIntInputs.append(attr.name) - for attr in invalidatingIntInputs: - if attr not in internalInputs.keys(): - compatibilityIssue = CompatibilityIssue.DescriptionConflict - break - - # Verify that all inputs match their descriptions - for attrName, value in inputs.items(): - if not CompatibilityNode.attributeDescFromName(nodeDesc.inputs, attrName, value): - compatibilityIssue = CompatibilityIssue.DescriptionConflict - break - # Verify that all internal inputs match their description - for attrName, value in internalInputs.items(): - if not CompatibilityNode.attributeDescFromName(nodeDesc.internalInputs, attrName, value): - compatibilityIssue = CompatibilityIssue.DescriptionConflict - break - # Verify that all outputs match their descriptions - for attrName, value in outputs.items(): - if not CompatibilityNode.attributeDescFromName(nodeDesc.outputs, attrName, value): - compatibilityIssue = CompatibilityIssue.DescriptionConflict - break - - if compatibilityIssue is None: - node = Node(nodeType, position, uid=uid, **inputs, **internalInputs, **outputs) - else: - logging.debug("Compatibility issue detected for node '{}': {}".format(name, compatibilityIssue.name)) - node = CompatibilityNode(nodeType, nodeDict, position, compatibilityIssue) - # Retro-compatibility: no internal folder saved - # can't spawn meaningful CompatibilityNode with precomputed outputs - # => automatically try to perform node upgrade - if not internalFolder and nodeDesc: - logging.warning("No serialized output data: performing automatic upgrade on '{}'".format(name)) - node = node.upgrade() - # If the node comes from a template file and there is a conflict, it should be upgraded anyway unless it is - # an "unknown node type" conflict (in which case the upgrade would fail) - elif template and compatibilityIssue is not CompatibilityIssue.UnknownNodeType: - node = node.upgrade() - - return node diff --git a/meshroom/core/nodeFactory.py b/meshroom/core/nodeFactory.py new file mode 100644 index 0000000000..2c54f1b44a --- /dev/null +++ b/meshroom/core/nodeFactory.py @@ -0,0 +1,116 @@ +import logging + +import meshroom.core +from meshroom.core import Version, desc +from meshroom.core.node import CompatibilityIssue, CompatibilityNode, Node, Position + + +def nodeFactory(nodeDict, name=None, template=False, uidConflict=False): + """ + Create a node instance by deserializing the given node data. + If the serialized data matches the corresponding node type description, a Node instance is created. + If any compatibility issue occurs, a NodeCompatibility instance is created instead. + + Args: + nodeDict (dict): the serialization of the node + name (str): (optional) the node's name + template (bool): (optional) true if the node is part of a template, false otherwise + uidConflict (bool): (optional) true if a UID conflict has been detected externally on that node + + Returns: + BaseNode: the created node + """ + nodeType = nodeDict["nodeType"] + + # Retro-compatibility: inputs were previously saved as "attributes" + if "inputs" not in nodeDict and "attributes" in nodeDict: + nodeDict["inputs"] = nodeDict["attributes"] + del nodeDict["attributes"] + + # Get node inputs/outputs + inputs = nodeDict.get("inputs", {}) + internalInputs = nodeDict.get("internalInputs", {}) + outputs = nodeDict.get("outputs", {}) + version = nodeDict.get("version", None) + internalFolder = nodeDict.get("internalFolder", None) + position = Position(*nodeDict.get("position", [])) + uid = nodeDict.get("uid", None) + + compatibilityIssue = None + + nodeDesc = None + try: + nodeDesc = meshroom.core.nodesDesc[nodeType] + except KeyError: + # Unknown node type + compatibilityIssue = CompatibilityIssue.UnknownNodeType + + # Unknown node type should take precedence over UID conflict, as it cannot be resolved + if uidConflict and nodeDesc: + compatibilityIssue = CompatibilityIssue.UidConflict + + if nodeDesc and not uidConflict: # if uidConflict, there is no need to look for another compatibility issue + # Compare serialized node version with current node version + currentNodeVersion = meshroom.core.nodeVersion(nodeDesc) + # If both versions are available, check for incompatibility in major version + if version and currentNodeVersion and Version(version).major != Version(currentNodeVersion).major: + compatibilityIssue = CompatibilityIssue.VersionConflict + # In other cases, check attributes compatibility between serialized node and its description + else: + # Check that the node has the exact same set of inputs/outputs as its description, except + # if the node is described in a template file, in which only non-default parameters are saved; + # do not perform that check for internal attributes because there is no point in + # raising compatibility issues if their number differs: in that case, it is only useful + # if some internal attributes do not exist or are invalid + if not template and (sorted([attr.name for attr in nodeDesc.inputs + if not isinstance(attr, desc.PushButtonParam)]) != sorted(inputs.keys()) or + sorted([attr.name for attr in nodeDesc.outputs if not attr.isDynamicValue]) != + sorted(outputs.keys())): + compatibilityIssue = CompatibilityIssue.DescriptionConflict + + # Check whether there are any internal attributes that are invalidating in the node description: if there + # are, then check that these internal attributes are part of nodeDict; if they are not, a compatibility + # issue must be raised to warn the user, as this will automatically change the node's UID + if not template: + invalidatingIntInputs = [] + for attr in nodeDesc.internalInputs: + if attr.invalidate: + invalidatingIntInputs.append(attr.name) + for attr in invalidatingIntInputs: + if attr not in internalInputs.keys(): + compatibilityIssue = CompatibilityIssue.DescriptionConflict + break + + # Verify that all inputs match their descriptions + for attrName, value in inputs.items(): + if not CompatibilityNode.attributeDescFromName(nodeDesc.inputs, attrName, value): + compatibilityIssue = CompatibilityIssue.DescriptionConflict + break + # Verify that all internal inputs match their description + for attrName, value in internalInputs.items(): + if not CompatibilityNode.attributeDescFromName(nodeDesc.internalInputs, attrName, value): + compatibilityIssue = CompatibilityIssue.DescriptionConflict + break + # Verify that all outputs match their descriptions + for attrName, value in outputs.items(): + if not CompatibilityNode.attributeDescFromName(nodeDesc.outputs, attrName, value): + compatibilityIssue = CompatibilityIssue.DescriptionConflict + break + + if compatibilityIssue is None: + node = Node(nodeType, position, uid=uid, **inputs, **internalInputs, **outputs) + else: + logging.debug("Compatibility issue detected for node '{}': {}".format(name, compatibilityIssue.name)) + node = CompatibilityNode(nodeType, nodeDict, position, compatibilityIssue) + # Retro-compatibility: no internal folder saved + # can't spawn meaningful CompatibilityNode with precomputed outputs + # => automatically try to perform node upgrade + if not internalFolder and nodeDesc: + logging.warning("No serialized output data: performing automatic upgrade on '{}'".format(name)) + node = node.upgrade() + # If the node comes from a template file and there is a conflict, it should be upgraded anyway unless it is + # an "unknown node type" conflict (in which case the upgrade would fail) + elif template and compatibilityIssue is not CompatibilityIssue.UnknownNodeType: + node = node.upgrade() + + return node diff --git a/meshroom/ui/commands.py b/meshroom/ui/commands.py index 47be0a3385..4865ad6f2f 100755 --- a/meshroom/ui/commands.py +++ b/meshroom/ui/commands.py @@ -7,7 +7,8 @@ from meshroom.core.attribute import ListAttribute, Attribute from meshroom.core.graph import GraphModification -from meshroom.core.node import nodeFactory, Position +from meshroom.core.node import Position +from meshroom.core.nodeFactory import nodeFactory class UndoCommand(QUndoCommand): From d7f403401a3da88532d18f75523cb7bd09872be1 Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:11:41 +0100 Subject: [PATCH 02/18] [tests] Add extra compatibility tests Add a new test suite for graph template loading. --- tests/test_compatibility.py | 78 +++++++++++++++++++++++++++++++------ tests/utils.py | 15 +++++++ 2 files changed, 82 insertions(+), 11 deletions(-) create mode 100644 tests/utils.py diff --git a/tests/test_compatibility.py b/tests/test_compatibility.py index 00505352e7..ac7c3002d8 100644 --- a/tests/test_compatibility.py +++ b/tests/test_compatibility.py @@ -4,6 +4,7 @@ import os import copy +from typing import Type import pytest import meshroom.core @@ -12,6 +13,8 @@ from meshroom.core.graph import Graph, loadGraph from meshroom.core.node import CompatibilityNode, CompatibilityIssue, Node +from .utils import registeredNodeTypes + SampleGroupV1 = [ desc.IntParam(name="a", label="a", description="", value=0, range=None), @@ -156,6 +159,12 @@ class SampleInputNodeV2(desc.InputNode): ] + +def replaceNodeTypeDesc(nodeType: str, nodeDesc: Type[desc.Node]): + """Change the `nodeDesc` associated to `nodeType`.""" + meshroom.core.nodesDesc[nodeType] = nodeDesc + + def test_unknown_node_type(): """ Test compatibility behavior for unknown node type. @@ -218,8 +227,7 @@ def test_description_conflict(): g.save(graphFile) # reload file as-is, ensure no compatibility issue is detected (no CompatibilityNode instances) - g = loadGraph(graphFile) - assert all(isinstance(n, Node) for n in g.nodes) + loadGraph(graphFile, strictCompatibility=True) # offset node types register to create description conflicts # each node type name now reference the next one's implementation @@ -399,20 +407,68 @@ def test_conformUpgrade(): class TestGraphLoadingWithStrictCompatibility: + def test_failsOnUnknownNodeType(self, graphSavedOnDisk): + with registeredNodeTypes([SampleNodeV1]): + graph: Graph = graphSavedOnDisk + graph.addNewNode(SampleNodeV1.__name__) + graph.save() + + with pytest.raises(GraphCompatibilityError): + loadGraph(graph.filepath, strictCompatibility=True) + + def test_failsOnNodeDescriptionCompatibilityIssue(self, graphSavedOnDisk): - registerNodeType(SampleNodeV1) - registerNodeType(SampleNodeV2) - graph: Graph = graphSavedOnDisk - graph.addNewNode(SampleNodeV1.__name__) - graph.save() + with registeredNodeTypes([SampleNodeV1, SampleNodeV2]): + graph: Graph = graphSavedOnDisk + graph.addNewNode(SampleNodeV1.__name__) + graph.save() + + replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2) + + with pytest.raises(GraphCompatibilityError): + loadGraph(graph.filepath, strictCompatibility=True) - # Replace saved node description by V2 - meshroom.core.nodesDesc[SampleNodeV1.__name__] = SampleNodeV2 + +class TestGraphTemplateLoading: + + def test_failsOnUnknownNodeTypeError(self, graphSavedOnDisk): + + with registeredNodeTypes([SampleNodeV1, SampleNodeV2]): + graph: Graph = graphSavedOnDisk + graph.addNewNode(SampleNodeV1.__name__) + graph.save(template=True) with pytest.raises(GraphCompatibilityError): loadGraph(graph.filepath, strictCompatibility=True) - unregisterNodeType(SampleNodeV1) - unregisterNodeType(SampleNodeV2) + def test_loadsIfIncompatibleNodeHasDefaultAttributeValues(self, graphSavedOnDisk): + with registeredNodeTypes([SampleNodeV1, SampleNodeV2]): + graph: Graph = graphSavedOnDisk + graph.addNewNode(SampleNodeV1.__name__) + graph.save(template=True) + + replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2) + + loadGraph(graph.filepath, strictCompatibility=True) + + def test_loadsIfValueSetOnCompatibleAttribute(self, graphSavedOnDisk): + with registeredNodeTypes([SampleNodeV1, SampleNodeV2]): + graph: Graph = graphSavedOnDisk + node = graph.addNewNode(SampleNodeV1.__name__, paramA="foo") + graph.save(template=True) + + replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2) + + loadedGraph = loadGraph(graph.filepath, strictCompatibility=True) + assert loadedGraph.nodes.get(node.name).paramA.value == "foo" + def test_loadsIfValueSetOnIncompatibleAttribute(self, graphSavedOnDisk): + with registeredNodeTypes([SampleNodeV1, SampleNodeV2]): + graph: Graph = graphSavedOnDisk + graph.addNewNode(SampleNodeV1.__name__, input="foo") + graph.save(template=True) + + replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2) + + loadGraph(graph.filepath, strictCompatibility=True) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000000..30745c5f43 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,15 @@ +from contextlib import contextmanager +from typing import Type +from meshroom.core import registerNodeType, unregisterNodeType + +from meshroom.core import desc + +@contextmanager +def registeredNodeTypes(nodeTypes: list[Type[desc.Node]]): + for nodeType in nodeTypes: + registerNodeType(nodeType) + + yield + + for nodeType in nodeTypes: + unregisterNodeType(nodeType) From 5e8e9009c10a2b8342e1d1d1931aba95d715b39b Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:11:41 +0100 Subject: [PATCH 03/18] [core] Refactor `nodeFactory` function Rewrite `nodeFactory` to reduce cognitive complexity, while preserving the current behavior. --- meshroom/core/graph.py | 6 +- meshroom/core/nodeFactory.py | 279 ++++++++++++++++++++++------------- meshroom/ui/commands.py | 3 +- 3 files changed, 185 insertions(+), 103 deletions(-) diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index 0a672b388c..12cca07cb3 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -333,7 +333,7 @@ def _load(self, filepath, setupProjectFile, importProject, publishOutputs): if isTemplate and not publishOutputs and nodeData["nodeType"] == "Publish": continue - n = nodeFactory(nodeData, nodeName, template=isTemplate) + n = nodeFactory(nodeData, nodeName, inTemplate=isTemplate) # Add node to the graph with raw attributes values self._addNode(n, nodeName) @@ -386,14 +386,14 @@ def _evaluateUidConflicts(self, data): # Different UIDs, remove the existing node from the graph and replace it with a CompatibilityNode logging.debug("UID conflict detected for {}".format(nodeName)) self.removeNode(nodeName) - n = nodeFactory(nodeData, nodeName, template=False, uidConflict=True) + n = nodeFactory(nodeData, nodeName, expectedUid=graphUid) self._addNode(n, nodeName) else: # f connecting nodes have UID conflicts and are removed/re-added to the graph, some edges may be lost: # the links will be erroneously updated, and any further resolution will fail. # Recreating the entire graph as it was ensures that all edges will be correctly preserved. self.removeNode(nodeName) - n = nodeFactory(nodeData, nodeName, template=False, uidConflict=False) + n = nodeFactory(nodeData, nodeName) self._addNode(n, nodeName) def updateImportedProject(self, data): diff --git a/meshroom/core/nodeFactory.py b/meshroom/core/nodeFactory.py index 2c54f1b44a..f030b9c5b2 100644 --- a/meshroom/core/nodeFactory.py +++ b/meshroom/core/nodeFactory.py @@ -1,116 +1,197 @@ import logging +from typing import Any, Iterable, Optional, Union import meshroom.core from meshroom.core import Version, desc from meshroom.core.node import CompatibilityIssue, CompatibilityNode, Node, Position -def nodeFactory(nodeDict, name=None, template=False, uidConflict=False): +def nodeFactory( + nodeData: dict, + name: Optional[str] = None, + inTemplate: bool = False, + expectedUid: Optional[str] = None, +) -> Union[Node, CompatibilityNode]: """ Create a node instance by deserializing the given node data. If the serialized data matches the corresponding node type description, a Node instance is created. If any compatibility issue occurs, a NodeCompatibility instance is created instead. Args: - nodeDict (dict): the serialization of the node - name (str): (optional) the node's name - template (bool): (optional) true if the node is part of a template, false otherwise - uidConflict (bool): (optional) true if a UID conflict has been detected externally on that node + nodeDict: The serialized Node data. + name: (optional) The node's name. + inTemplate: (optional) True if the node is created as part of a graph template. + expectedUid: (optional) The expected UID of the node within the context of a Graph. Returns: - BaseNode: the created node + The created Node instance. """ - nodeType = nodeDict["nodeType"] - - # Retro-compatibility: inputs were previously saved as "attributes" - if "inputs" not in nodeDict and "attributes" in nodeDict: - nodeDict["inputs"] = nodeDict["attributes"] - del nodeDict["attributes"] - - # Get node inputs/outputs - inputs = nodeDict.get("inputs", {}) - internalInputs = nodeDict.get("internalInputs", {}) - outputs = nodeDict.get("outputs", {}) - version = nodeDict.get("version", None) - internalFolder = nodeDict.get("internalFolder", None) - position = Position(*nodeDict.get("position", [])) - uid = nodeDict.get("uid", None) - - compatibilityIssue = None - - nodeDesc = None - try: - nodeDesc = meshroom.core.nodesDesc[nodeType] - except KeyError: - # Unknown node type - compatibilityIssue = CompatibilityIssue.UnknownNodeType - - # Unknown node type should take precedence over UID conflict, as it cannot be resolved - if uidConflict and nodeDesc: - compatibilityIssue = CompatibilityIssue.UidConflict - - if nodeDesc and not uidConflict: # if uidConflict, there is no need to look for another compatibility issue - # Compare serialized node version with current node version - currentNodeVersion = meshroom.core.nodeVersion(nodeDesc) - # If both versions are available, check for incompatibility in major version - if version and currentNodeVersion and Version(version).major != Version(currentNodeVersion).major: - compatibilityIssue = CompatibilityIssue.VersionConflict - # In other cases, check attributes compatibility between serialized node and its description + return _NodeCreator(nodeData, name, inTemplate, expectedUid).create() + + +class _NodeCreator: + + def __init__( + self, + nodeData: dict, + name: Optional[str] = None, + inTemplate: bool = False, + expectedUid: Optional[str] = None, + ): + self.nodeData = nodeData + self.name = name + self.inTemplate = inTemplate + self.expectedUid = expectedUid + + self._normalizeNodeData() + + self.nodeType = self.nodeData["nodeType"] + self.inputs = self.nodeData.get("inputs", {}) + self.internalInputs = self.nodeData.get("internalInputs", {}) + self.outputs = self.nodeData.get("outputs", {}) + self.version = self.nodeData.get("version", None) + self.internalFolder = self.nodeData.get("internalFolder") + self.position = Position(*self.nodeData.get("position", [])) + self.uid = self.nodeData.get("uid", None) + self.nodeDesc = meshroom.core.nodesDesc.get(self.nodeType, None) + + def create(self) -> Union[Node, CompatibilityNode]: + compatibilityIssue = self._checkCompatibilityIssues() + if compatibilityIssue: + node = self._createCompatibilityNode(compatibilityIssue) + node = self._tryUpgradeCompatibilityNode(node) else: - # Check that the node has the exact same set of inputs/outputs as its description, except - # if the node is described in a template file, in which only non-default parameters are saved; - # do not perform that check for internal attributes because there is no point in - # raising compatibility issues if their number differs: in that case, it is only useful - # if some internal attributes do not exist or are invalid - if not template and (sorted([attr.name for attr in nodeDesc.inputs - if not isinstance(attr, desc.PushButtonParam)]) != sorted(inputs.keys()) or - sorted([attr.name for attr in nodeDesc.outputs if not attr.isDynamicValue]) != - sorted(outputs.keys())): - compatibilityIssue = CompatibilityIssue.DescriptionConflict - - # Check whether there are any internal attributes that are invalidating in the node description: if there - # are, then check that these internal attributes are part of nodeDict; if they are not, a compatibility - # issue must be raised to warn the user, as this will automatically change the node's UID - if not template: - invalidatingIntInputs = [] - for attr in nodeDesc.internalInputs: - if attr.invalidate: - invalidatingIntInputs.append(attr.name) - for attr in invalidatingIntInputs: - if attr not in internalInputs.keys(): - compatibilityIssue = CompatibilityIssue.DescriptionConflict - break - - # Verify that all inputs match their descriptions - for attrName, value in inputs.items(): - if not CompatibilityNode.attributeDescFromName(nodeDesc.inputs, attrName, value): - compatibilityIssue = CompatibilityIssue.DescriptionConflict - break - # Verify that all internal inputs match their description - for attrName, value in internalInputs.items(): - if not CompatibilityNode.attributeDescFromName(nodeDesc.internalInputs, attrName, value): - compatibilityIssue = CompatibilityIssue.DescriptionConflict - break - # Verify that all outputs match their descriptions - for attrName, value in outputs.items(): - if not CompatibilityNode.attributeDescFromName(nodeDesc.outputs, attrName, value): - compatibilityIssue = CompatibilityIssue.DescriptionConflict - break - - if compatibilityIssue is None: - node = Node(nodeType, position, uid=uid, **inputs, **internalInputs, **outputs) - else: - logging.debug("Compatibility issue detected for node '{}': {}".format(name, compatibilityIssue.name)) - node = CompatibilityNode(nodeType, nodeDict, position, compatibilityIssue) - # Retro-compatibility: no internal folder saved - # can't spawn meaningful CompatibilityNode with precomputed outputs - # => automatically try to perform node upgrade - if not internalFolder and nodeDesc: - logging.warning("No serialized output data: performing automatic upgrade on '{}'".format(name)) - node = node.upgrade() - # If the node comes from a template file and there is a conflict, it should be upgraded anyway unless it is - # an "unknown node type" conflict (in which case the upgrade would fail) - elif template and compatibilityIssue is not CompatibilityIssue.UnknownNodeType: - node = node.upgrade() - - return node + node = self._createNode() + return node + + def _normalizeNodeData(self): + """Consistency fixes for backward compatibility with older serialized data.""" + # Inputs were previously saved as "attributes". + if "inputs" not in self.nodeData and "attributes" in self.nodeData: + self.nodeData["inputs"] = self.nodeData["attributes"] + del self.nodeData["attributes"] + + def _checkCompatibilityIssues(self) -> Optional[CompatibilityIssue]: + if self.nodeDesc is None: + return CompatibilityIssue.UnknownNodeType + + if not self._checkUidCompatibility(): + return CompatibilityIssue.UidConflict + + if not self._checkVersionCompatibility(): + return CompatibilityIssue.VersionConflict + + if not self._checkDescriptionCompatibility(): + return CompatibilityIssue.DescriptionConflict + + return None + + def _checkUidCompatibility(self) -> bool: + return self.expectedUid is None or self.expectedUid == self.uid + + def _checkVersionCompatibility(self) -> bool: + # Special case: a node with a version set to None indicates + # that it has been created from the current version of the node type. + nodeCreatedFromCurrentVersion = self.version is None + if nodeCreatedFromCurrentVersion: + return True + nodeTypeCurrentVersion = meshroom.core.nodeVersion(self.nodeDesc, "0.0") + return Version(self.version).major == Version(nodeTypeCurrentVersion).major + + def _checkDescriptionCompatibility(self) -> bool: + # Only perform strict attribute name matching for non-template graphs, + # since only non-default-value input attributes are serialized in templates. + if not self.inTemplate: + if not self._checkAttributesNamesMatchDescription(): + return False + + return self._checkAttributesAreCompatibleWithDescription() + + def _checkAttributesNamesMatchDescription(self) -> bool: + return ( + self._checkInputAttributesNames() + and self._checkOutputAttributesNames() + and self._checkInternalAttributesNames() + ) + + def _checkAttributesAreCompatibleWithDescription(self) -> bool: + return ( + self._checkAttributesCompatibility(self.nodeDesc.inputs, self.inputs) + and self._checkAttributesCompatibility(self.nodeDesc.internalInputs, self.internalInputs) + and self._checkAttributesCompatibility(self.nodeDesc.outputs, self.outputs) + ) + + def _checkInputAttributesNames(self) -> bool: + def serializedInput(attr: desc.Attribute) -> bool: + """Filter that excludes not-serialized desc input attributes.""" + if isinstance(attr, desc.PushButtonParam): + # PushButtonParam are not serialized has they do not hold a value. + return False + return True + + refAttributes = filter(serializedInput, self.nodeDesc.inputs) + return self._checkAttributesNamesStrictlyMatch(refAttributes, self.inputs) + + def _checkOutputAttributesNames(self) -> bool: + def serializedOutput(attr: desc.Attribute) -> bool: + """Filter that excludes not-serialized desc output attributes.""" + if attr.isDynamicValue: + # Dynamic outputs values are not serialized with the node, + # as their value is written in the computed output data. + return False + return True + + refAttributes = filter(serializedOutput, self.nodeDesc.outputs) + return self._checkAttributesNamesStrictlyMatch(refAttributes, self.outputs) + + def _checkInternalAttributesNames(self) -> bool: + invalidatingDescAttributes = [attr.name for attr in self.nodeDesc.internalInputs if attr.invalidate] + return all(attr in self.internalInputs.keys() for attr in invalidatingDescAttributes) + + def _checkAttributesNamesStrictlyMatch( + self, descAttributes: Iterable[desc.Attribute], attributesDict: dict[str, Any] + ) -> bool: + refNames = sorted([attr.name for attr in descAttributes]) + attrNames = sorted(attributesDict.keys()) + return refNames == attrNames + + def _checkAttributesCompatibility( + self, descAttributes: list[desc.Attribute], attributesDict: dict[str, Any] + ) -> bool: + return all( + CompatibilityNode.attributeDescFromName(descAttributes, attrName, value) is not None + for attrName, value in attributesDict.items() + ) + + def _createNode(self) -> Node: + logging.info(f"Creating node '{self.name}'") + return Node( + self.nodeType, + position=self.position, + uid=self.uid, + **self.inputs, + **self.internalInputs, + **self.outputs, + ) + + def _createCompatibilityNode(self, compatibilityIssue) -> CompatibilityNode: + logging.warning(f"Compatibility issue detected for node '{self.name}': {compatibilityIssue.name}") + return CompatibilityNode( + self.nodeType, self.nodeData, position=self.position, issue=compatibilityIssue + ) + + def _tryUpgradeCompatibilityNode(self, node: CompatibilityNode) -> Union[Node, CompatibilityNode]: + """Handle possible upgrades of CompatibilityNodes, when no computed data is associated to the Node.""" + if node.issue == CompatibilityIssue.UnknownNodeType: + return node + + # Nodes in templates are not meant to hold computation data. + if self.inTemplate: + logging.warning(f"Compatibility issue in template: performing automatic upgrade on '{self.name}'") + return node.upgrade() + + # Backward compatibility: "internalFolder" was not serialized. + if not self.internalFolder: + logging.warning(f"No serialized output data: performing automatic upgrade on '{self.name}'") + + return node diff --git a/meshroom/ui/commands.py b/meshroom/ui/commands.py index 4865ad6f2f..52e3151e24 100755 --- a/meshroom/ui/commands.py +++ b/meshroom/ui/commands.py @@ -432,11 +432,12 @@ def redoImpl(self): def undoImpl(self): # delete upgraded node + expectedUid = self.graph.node(self.nodeName)._uid self.graph.removeNode(self.nodeName) # recreate compatibility node with GraphModification(self.graph): # We come back from an upgrade, so we enforce uidConflict=True as there was a uid conflict before - node = nodeFactory(self.nodeDict, name=self.nodeName, uidConflict=True) + node = nodeFactory(self.nodeDict, name=self.nodeName, expectedUid=expectedUid) self.graph.addNode(node, self.nodeName) # recreate out edges for dstAttr, srcAttr in self.outEdges.items(): From e6160bf8a0759d63ef2f6f79b11d83deac0ce2c7 Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:12:00 +0100 Subject: [PATCH 04/18] [core] Graph: initial refactoring of graph loading API and logic * API Instead of having a single `load` function that exposes in its API some elements only applicable to initializing a graph from a templates, split it into 2 distinct functions: `load` and `initFromTemplate`. Apply those changes to users of the API (UI, CLI), and simplify Graph wrapper classes to better align with those concepts. * Deserialization Reduce the cognitive complexity of the deserizalization process by splitting it into more atomic functions, while maintaining the current behavior. --- bin/meshroom_batch | 4 +- meshroom/core/graph.py | 191 ++++++++++++++++++-------------- meshroom/core/typing.py | 8 ++ meshroom/ui/graph.py | 18 +-- meshroom/ui/qml/Application.qml | 6 +- meshroom/ui/qml/Homepage.qml | 2 +- meshroom/ui/qml/main.qml | 2 +- meshroom/ui/reconstruction.py | 45 ++++---- 8 files changed, 151 insertions(+), 125 deletions(-) create mode 100644 meshroom/core/typing.py diff --git a/bin/meshroom_batch b/bin/meshroom_batch index 36b8fef689..6bee4c1f6d 100755 --- a/bin/meshroom_batch +++ b/bin/meshroom_batch @@ -154,10 +154,10 @@ with meshroom.core.graph.GraphModification(graph): # initialize template pipeline loweredPipelineTemplates = dict((k.lower(), v) for k, v in meshroom.core.pipelineTemplates.items()) if args.pipeline.lower() in loweredPipelineTemplates: - graph.load(loweredPipelineTemplates[args.pipeline.lower()], setupProjectFile=False, publishOutputs=True if args.output else False) + graph.initFromTemplate(loweredPipelineTemplates[args.pipeline.lower()], publishOutputs=True if args.output else False) else: # custom pipeline - graph.load(args.pipeline, setupProjectFile=False, publishOutputs=True if args.output else False) + graph.initFromTemplate(args.pipeline, publishOutputs=True if args.output else False) def parseInputs(inputs, uniqueInitNode): """Utility method for parsing the input and inputRecursive arguments.""" diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index 12cca07cb3..0910569006 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -4,6 +4,7 @@ import logging import os import re +from typing import Optional import weakref from collections import defaultdict, OrderedDict from contextlib import contextmanager @@ -18,6 +19,7 @@ from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit from meshroom.core.node import Status, Node, CompatibilityNode from meshroom.core.nodeFactory import nodeFactory +from meshroom.core.typing import PathLike # Replace default encoder to support Enums @@ -149,6 +151,21 @@ def decorator(self, *args, **kwargs): return decorator +def blockNodeCallbacks(func): + """ + Graph methods loading serialized graph content must be decorated with 'blockNodeCallbacks', + to avoid attribute changed callbacks defined on node descriptions to be triggered during + this process. + """ + def inner(self, *args, **kwargs): + self._loading = True + try: + return func(self, *args, **kwargs) + finally: + self._loading = False + return inner + + class Graph(BaseObject): """ _________________ _________________ _________________ @@ -254,116 +271,118 @@ def isLoading(self): return self._loading @Slot(str) - def load(self, filepath, setupProjectFile=True, importProject=False, publishOutputs=False): + def load(self, filepath: PathLike): """ - Load a Meshroom graph ".mg" file. + Load a Meshroom Graph ".mg" file in place. Args: - filepath: project filepath to load - setupProjectFile: Store the reference to the project file and setup the cache directory. - If false, it only loads the graph of the project file as a template. - importProject: True if the project that is loaded will be imported in the current graph, instead - of opened. - publishOutputs: True if "Publish" nodes from templates should not be ignored. + filepath: The path to the Meshroom Graph file to load. """ - self._loading = True - try: - return self._load(filepath, setupProjectFile, importProject, publishOutputs) - finally: - self._loading = False - - def _load(self, filepath, setupProjectFile, importProject, publishOutputs): - if not importProject: - self.clear() - with open(filepath) as jsonFile: - fileData = json.load(jsonFile) - - self.header = fileData.get(Graph.IO.Keys.Header, {}) - - fileVersion = self.header.get(Graph.IO.Keys.FileVersion, "0.0") - # Retro-compatibility for all project files with the previous UID format - if Version(fileVersion) < Version("2.0"): - # For internal folders, all "{uid0}" keys should be replaced with "{uid}" - updatedFileData = json.dumps(fileData).replace("{uid0}", "{uid}") + self._deserialize(Graph._loadGraphData(filepath)) + self._setFilepath(filepath) + self._fileDateVersion = os.path.getmtime(filepath) - # For fileVersion < 2.0, the nodes' UID is stored as: - # "uids": {"0": "hashvalue"} - # These should be identified and replaced with: - # "uid": "hashvalue" - uidPattern = re.compile(r'"uids": \{"0":.*?\}') - uidOccurrences = uidPattern.findall(updatedFileData) - for occ in uidOccurrences: - uid = occ.split("\"")[-2] # UID is second to last element - newUidStr = r'"uid": "{}"'.format(uid) - updatedFileData = updatedFileData.replace(occ, newUidStr) - fileData = json.loads(updatedFileData) + def initFromTemplate(self, filepath: PathLike, publishOutputs: bool = False): + """ + Deserialize a template Meshroom Graph ".mg" file in place. - # Older versions of Meshroom files only contained the serialized nodes - graphData = fileData.get(Graph.IO.Keys.Graph, fileData) + When initializing from a template, the internal filepath of the graph instance is not set. + Saving the file on disk will require to specify a filepath. - if importProject: - self._importedNodes.clear() - graphData = self.updateImportedProject(graphData) + Args: + filepath: The path to the Meshroom Graph file to load. + publishOutputs: (optional) Whether to keep 'Publish' nodes. + """ + self._deserialize(Graph._loadGraphData(filepath)) - if not isinstance(graphData, dict): - raise RuntimeError('loadGraph error: Graph is not a dict. File: {}'.format(filepath)) + if not publishOutputs: + for node in [node for node in self.nodes if node.nodeType == "Publish"]: + self.removeNode(node.name) - nodesVersions = self.header.get(Graph.IO.Keys.NodesVersions, {}) + @staticmethod + def _loadGraphData(filepath: PathLike) -> dict: + """Deserialize the content of the Meshroom Graph file at `filepath` to a dictionnary.""" + with open(filepath) as file: + graphData = json.load(file) + return graphData - self._fileDateVersion = os.path.getmtime(filepath) + @blockNodeCallbacks + def _deserialize(self, graphData: dict): + """Deserialize `graphData` in the current Graph instance. - # Check whether the file was saved as a template in minimal mode + Args: + graphData: The serialized Graph. + """ + self.clear() + self.header = graphData.get(Graph.IO.Keys.Header, {}) + fileVersion = Version(self.header.get(Graph.IO.Keys.FileVersion, "0.0")) + graphContent = self._normalizeGraphContent(graphData, fileVersion) isTemplate = self.header.get("template", False) with GraphModification(self): # iterate over nodes sorted by suffix index in their names - for nodeName, nodeData in sorted(graphData.items(), key=lambda x: self.getNodeIndexFromName(x[0])): - if not isinstance(nodeData, dict): - raise RuntimeError('loadGraph error: Node is not a dict. File: {}'.format(filepath)) - - # retrieve version from - # 1. nodeData: node saved from a CompatibilityNode - # 2. nodesVersion in file header: node saved from a Node - # 3. fallback to no version "0.0": retro-compatibility - if "version" not in nodeData: - nodeData["version"] = nodesVersions.get(nodeData["nodeType"], "0.0") - - # if the node is a "Publish" node and comes from a template file, it should be ignored - # unless publishOutputs is True - if isTemplate and not publishOutputs and nodeData["nodeType"] == "Publish": - continue - - n = nodeFactory(nodeData, nodeName, inTemplate=isTemplate) - - # Add node to the graph with raw attributes values - self._addNode(n, nodeName) - - if importProject: - self._importedNodes.add(n) + for nodeName, nodeData in sorted( + graphContent.items(), key=lambda x: self.getNodeIndexFromName(x[0]) + ): + self._deserializeNode(nodeData, nodeName) # Create graph edges by resolving attributes expressions self._applyExpr() - - if setupProjectFile: - # Update filepath related members - # Note: needs to be done at the end as it will trigger an updateInternals. - self._setFilepath(filepath) - elif not isTemplate: - # If no filepath is being set but the graph is not a template, trigger an updateInternals either way. - self.updateInternals() + + # Templates are specific: they contain only the minimal amount of + # serialized data to describe the graph structure. + # They are not meant to be computed: therefore, we can early return here, + # as uid conflict evaluation is only meaningful for nodes with computed data. + if isTemplate: + return # By this point, the graph has been fully loaded and an updateInternals has been triggered, so all the # nodes' links have been resolved and their UID computations are all complete. # It is now possible to check whether the UIDs stored in the graph file for each node correspond to the ones # that were computed. - if not isTemplate: # UIDs are not stored in templates - self._evaluateUidConflicts(graphData) - try: - self._applyExpr() - except Exception as e: - logging.warning(e) + self.updateInternals() + self._evaluateUidConflicts(graphContent) + try: + self._applyExpr() + except Exception as e: + logging.warning(e) + + def _normalizeGraphContent(self, graphData: dict, fileVersion: Version) -> dict: + graphContent = graphData.get(Graph.IO.Keys.Graph, graphData) + + if fileVersion < Version("2.0"): + # For internal folders, all "{uid0}" keys should be replaced with "{uid}" + updatedFileData = json.dumps(graphContent).replace("{uid0}", "{uid}") + + # For fileVersion < 2.0, the nodes' UID is stored as: + # "uids": {"0": "hashvalue"} + # These should be identified and replaced with: + # "uid": "hashvalue" + uidPattern = re.compile(r'"uids": \{"0":.*?\}') + uidOccurrences = uidPattern.findall(updatedFileData) + for occ in uidOccurrences: + uid = occ.split("\"")[-2] # UID is second to last element + newUidStr = r'"uid": "{}"'.format(uid) + updatedFileData = updatedFileData.replace(occ, newUidStr) + graphContent = json.loads(updatedFileData) + + return graphContent + + def _deserializeNode(self, nodeData: dict, nodeName: str): + # Retrieve version from + # 1. nodeData: node saved from a CompatibilityNode + # 2. nodesVersion in file header: node saved from a Node + # 3. fallback behavior: default to "0.0" + if "version" not in nodeData: + nodeData["version"] = self._getNodeTypeVersionFromHeader(nodeData["nodeType"], "0.0") + inTemplate = self.header.get("template", False) + node = nodeFactory(nodeData, nodeName, inTemplate=inTemplate) + self._addNode(node, nodeName) + return node - return True + def _getNodeTypeVersionFromHeader(self, nodeType: str, default: Optional[str] = None) -> Optional[str]: + nodeVersions = self.header.get(Graph.IO.Keys.NodesVersions, {}) + return nodeVersions.get(nodeType, default) def _evaluateUidConflicts(self, data): """ diff --git a/meshroom/core/typing.py b/meshroom/core/typing.py new file mode 100644 index 0000000000..f526fb3e3d --- /dev/null +++ b/meshroom/core/typing.py @@ -0,0 +1,8 @@ +""" +Common typing aliases used in Meshroom. +""" + +from pathlib import Path +from typing import Union + +PathLike = Union[Path, str] diff --git a/meshroom/ui/graph.py b/meshroom/ui/graph.py index 0cde6b2ddd..2b69981616 100644 --- a/meshroom/ui/graph.py +++ b/meshroom/ui/graph.py @@ -451,17 +451,21 @@ def stopChildThreads(self): self.stopExecution() self._chunksMonitor.stop() - @Slot(str, result=bool) - def loadGraph(self, filepath, setupProjectFile=True, publishOutputs=False): - g = Graph('') - status = True + @Slot(str) + def loadGraph(self, filepath): + g = Graph("") if filepath: - status = g.load(filepath, setupProjectFile, importProject=False, publishOutputs=publishOutputs) + g.load(filepath) if not os.path.exists(g.cacheDir): os.mkdir(g.cacheDir) - g.fileDateVersion = os.path.getmtime(filepath) self.setGraph(g) - return status + + @Slot(str, bool, result=bool) + def initFromTemplate(self, filepath, publishOutputs=False): + graph = Graph("") + if filepath: + graph.initFromTemplate(filepath, publishOutputs=publishOutputs) + self.setGraph(graph) @Slot(QUrl, result="QVariantList") @Slot(QUrl, QPoint, result="QVariantList") diff --git a/meshroom/ui/qml/Application.qml b/meshroom/ui/qml/Application.qml index 48884e2f33..fb5fc67c67 100644 --- a/meshroom/ui/qml/Application.qml +++ b/meshroom/ui/qml/Application.qml @@ -141,7 +141,7 @@ Page { nameFilters: ["Meshroom Graphs (*.mg)"] onAccepted: { // Open the template as a regular file - if (_reconstruction.loadUrl(currentFile, true, true)) { + if (_reconstruction.load(currentFile)) { MeshroomApp.addRecentProjectFile(currentFile.toString()) } } @@ -356,7 +356,7 @@ Page { text: "Reload File" onClicked: { - _reconstruction.loadUrl(_reconstruction.graph.filepath) + _reconstruction.load(_reconstruction.graph.filepath) fileModifiedDialog.close() } } @@ -661,7 +661,7 @@ Page { MenuItem { onTriggered: ensureSaved(function() { openRecentMenu.dismiss() - if (_reconstruction.loadUrl(modelData["path"])) { + if (_reconstruction.load(modelData["path"])) { MeshroomApp.addRecentProjectFile(modelData["path"]) } else { MeshroomApp.removeRecentProjectFile(modelData["path"]) diff --git a/meshroom/ui/qml/Homepage.qml b/meshroom/ui/qml/Homepage.qml index ef27a22dac..fe7f9ff4a8 100644 --- a/meshroom/ui/qml/Homepage.qml +++ b/meshroom/ui/qml/Homepage.qml @@ -384,7 +384,7 @@ Page { } else { // Open project mainStack.push("Application.qml") - if (_reconstruction.loadUrl(modelData["path"])) { + if (_reconstruction.load(modelData["path"])) { MeshroomApp.addRecentProjectFile(modelData["path"]) } else { MeshroomApp.removeRecentProjectFile(modelData["path"]) diff --git a/meshroom/ui/qml/main.qml b/meshroom/ui/qml/main.qml index 16940a74c5..20c2f81fa1 100644 --- a/meshroom/ui/qml/main.qml +++ b/meshroom/ui/qml/main.qml @@ -128,7 +128,7 @@ ApplicationWindow { if (mainStack.currentItem instanceof Homepage) { mainStack.push("Application.qml") } - if (_reconstruction.loadUrl(currentFile)) { + if (_reconstruction.load(currentFile)) { MeshroomApp.addRecentProjectFile(currentFile.toString()) } } diff --git a/meshroom/ui/reconstruction.py b/meshroom/ui/reconstruction.py index c774527f34..94d926a0c1 100755 --- a/meshroom/ui/reconstruction.py +++ b/meshroom/ui/reconstruction.py @@ -5,6 +5,7 @@ from collections.abc import Iterable from multiprocessing.pool import ThreadPool from threading import Thread +from typing import Callable from PySide6.QtCore import QObject, Slot, Property, Signal, QUrl, QSizeF, QPoint from PySide6.QtGui import QMatrix4x4, QMatrix3x3, QQuaternion, QVector3D, QVector2D @@ -534,17 +535,24 @@ def new(self, pipeline=None): # - correct pipeline name but the case does not match (e.g. panoramaHDR instead of panoramaHdr) # - lowercase pipeline name given through the "New Pipeline" menu loweredPipelineTemplates = dict((k.lower(), v) for k, v in meshroom.core.pipelineTemplates.items()) - if p.lower() in loweredPipelineTemplates: - self.load(loweredPipelineTemplates[p.lower()], setupProjectFile=False) - else: - # use the user-provided default project file - self.load(p, setupProjectFile=False) + filepath = loweredPipelineTemplates.get(p.lower(), p) + return self._loadWithErrorReport(self.initFromTemplate, filepath) @Slot(str, result=bool) - def load(self, filepath, setupProjectFile=True, publishOutputs=False): + @Slot(QUrl, result=bool) + def load(self, url): + if isinstance(url, QUrl): + # depending how the QUrl has been initialized, + # toLocalFile() may return the local path or an empty string + localFile = url.toLocalFile() or url.toString() + else: + localFile = url + return self._loadWithErrorReport(self.loadGraph, localFile) + + def _loadWithErrorReport(self, loadFunction: Callable[[str], None], filepath: str): logging.info(f"Load project file: '{filepath}'") try: - status = super(Reconstruction, self).loadGraph(filepath, setupProjectFile, publishOutputs) + loadFunction(filepath) # warn about pre-release projects being automatically upgraded if Version(self._graph.fileReleaseVersion).major == "0": self.warning.emit(Message( @@ -554,8 +562,8 @@ def load(self, filepath, setupProjectFile=True, publishOutputs=False): "Open it with the corresponding version of Meshroom to recover your data." )) self.setActive(True) - return status - except FileNotFoundError as e: + return True + except FileNotFoundError: self.error.emit( Message( "No Such File", @@ -564,8 +572,7 @@ def load(self, filepath, setupProjectFile=True, publishOutputs=False): ) ) logging.error("Error while loading '{}': No Such File.".format(filepath)) - return False - except Exception as e: + except Exception: import traceback trace = traceback.format_exc() self.error.emit( @@ -577,20 +584,8 @@ def load(self, filepath, setupProjectFile=True, publishOutputs=False): ) logging.error("Error while loading '{}'.".format(filepath)) logging.error(trace) - return False - @Slot(QUrl, result=bool) - @Slot(QUrl, bool, bool, result=bool) - def loadUrl(self, url, setupProjectFile=True, publishOutputs=False): - if isinstance(url, (QUrl)): - # depending how the QUrl has been initialized, - # toLocalFile() may return the local path or an empty string - localFile = url.toLocalFile() - if not localFile: - localFile = url.toString() - else: - localFile = url - return self.load(localFile, setupProjectFile, publishOutputs) + return False def onGraphChanged(self): """ React to the change of the internal graph. """ @@ -860,7 +855,7 @@ def handleFilesUrl(self, filesByType, cameraInit=None, position=None): ) ) else: - return self.loadUrl(filesByType["meshroomScenes"][0]) + return self.load(filesByType["meshroomScenes"][0]) From 3fdb91d9e5a0c5fe3a289b914378b19e04790526 Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:12:06 +0100 Subject: [PATCH 05/18] [core] Graph: add importGraphContent API Extract the logic of importing the content of a graph within a graph instance from the graph loading logic. Add `Graph.importGraphContent` and `Graph.importGraphContentFromFile` methods. Use the deserialization API to load the content in another temporary graph instance, to handle the renaming of nodes using the Graph API, rather than manipulating entries in a raw dictionnary. --- meshroom/core/graph.py | 96 ++++++++++--- meshroom/ui/commands.py | 11 +- tests/test_graphIO.py | 152 +++++++++++++++++++++ tests/test_nodeAttributeChangedCallback.py | 25 ++++ 4 files changed, 258 insertions(+), 26 deletions(-) create mode 100644 tests/test_graphIO.py diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index 0910569006..eed53c06eb 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -242,7 +242,6 @@ def __init__(self, name, parent=None): self._nodes = DictModel(keyAttrName='name', parent=self) # Edges: use dst attribute as unique key since it can only have one input connection self._edges = DictModel(keyAttrName='dst', parent=self) - self._importedNodes = DictModel(keyAttrName='name', parent=self) self._compatibilityNodes = DictModel(keyAttrName='name', parent=self) self.cacheDir = meshroom.core.defaultCacheFolder self._filepath = '' @@ -250,15 +249,17 @@ def __init__(self, name, parent=None): self.header = {} def clear(self): + self._clearGraphContent() self.header.clear() - self._compatibilityNodes.clear() + self._unsetFilepath() + + def _clearGraphContent(self): self._edges.clear() # Tell QML nodes are going to be deleted for node in self._nodes: node.alive = False - self._importedNodes.clear() self._nodes.clear() - self._unsetFilepath() + self._compatibilityNodes.clear() @property def fileFeatures(self): @@ -324,7 +325,7 @@ def _deserialize(self, graphData: dict): for nodeName, nodeData in sorted( graphContent.items(), key=lambda x: self.getNodeIndexFromName(x[0]) ): - self._deserializeNode(nodeData, nodeName) + self._deserializeNode(nodeData, nodeName, self) # Create graph edges by resolving attributes expressions self._applyExpr() @@ -368,14 +369,14 @@ def _normalizeGraphContent(self, graphData: dict, fileVersion: Version) -> dict: return graphContent - def _deserializeNode(self, nodeData: dict, nodeName: str): + def _deserializeNode(self, nodeData: dict, nodeName: str, fromGraph: "Graph"): # Retrieve version from # 1. nodeData: node saved from a CompatibilityNode # 2. nodesVersion in file header: node saved from a Node # 3. fallback behavior: default to "0.0" if "version" not in nodeData: - nodeData["version"] = self._getNodeTypeVersionFromHeader(nodeData["nodeType"], "0.0") - inTemplate = self.header.get("template", False) + nodeData["version"] = fromGraph._getNodeTypeVersionFromHeader(nodeData["nodeType"], "0.0") + inTemplate = fromGraph.header.get("template", False) node = nodeFactory(nodeData, nodeName, inTemplate=inTemplate) self._addNode(node, nodeName) return node @@ -549,6 +550,58 @@ def resetExternalLinks(attributes, nodeDesc, newNames): return attributes + def importGraphContentFromFile(self, filepath: PathLike) -> list[Node]: + """Import the content (nodes and edges) of another Graph file into this Graph instance. + + Args: + filepath: The path to the Graph file to import. + + Returns: + The list of newly created Nodes. + """ + graph = loadGraph(filepath) + return self.importGraphContent(graph) + + @blockNodeCallbacks + def importGraphContent(self, graph: "Graph") -> list[Node]: + """ + Import the content (node and edges) of another `graph` into this Graph instance. + + Nodes are imported with their original names if possible, otherwise a new unique name is generated + from their node type. + + Args: + graph: The graph to import. + + Returns: + The list of newly created Nodes. + """ + + def _renameClashingNodes(): + if not self.nodes: + return + unavailableNames = set(self.nodes.keys()) + for node in graph.nodes: + if node._name in unavailableNames: + node._name = self._createUniqueNodeName(node.nodeType, unavailableNames) + unavailableNames.add(node._name) + + def _importNodeAndEdges() -> list[Node]: + importedNodes = [] + # If we import the content of the graph within itself, + # iterate over a copy of the nodes as the graph is modified during the iteration. + nodes = graph.nodes if graph is not self else list(graph.nodes) + with GraphModification(self): + for srcNode in nodes: + node = self._deserializeNode(srcNode.toDict(), srcNode.name, graph) + importedNodes.append(node) + self._applyExpr() + return importedNodes + + _renameClashingNodes() + importedNodes = _importNodeAndEdges() + return importedNodes + @property def updateEnabled(self): return self._updateEnabled @@ -760,8 +813,6 @@ def removeNode(self, nodeName): node.alive = False self._nodes.remove(node) - if node in self._importedNodes: - self._importedNodes.remove(node) self.update() return inEdges, outEdges, outListAttributes @@ -786,13 +837,21 @@ def addNewNode(self, nodeType, name=None, position=None, **kwargs): n.updateInternals() return n - def _createUniqueNodeName(self, inputName): - i = 1 - while i: - newName = "{name}_{index}".format(name=inputName, index=i) - if newName not in self._nodes.objects: + def _createUniqueNodeName(self, inputName: str, existingNames: Optional[set[str]] = None): + """Create a unique node name based on the input name. + + Args: + inputName: The desired node name. + existingNames: (optional) If specified, consider this set for uniqueness check, instead of the list of nodes. + """ + existingNodeNames = existingNames or set(self._nodes.objects.keys()) + + idx = 1 + while idx: + newName = f"{inputName}_{idx}" + if newName not in existingNodeNames: return newName - i += 1 + idx += 1 def node(self, nodeName): return self._nodes.get(nodeName) @@ -1612,11 +1671,6 @@ def nodes(self): def edges(self): return self._edges - @property - def importedNodes(self): - """" Return the list of nodes that were added to the graph with the latest 'Import Project' action. """ - return self._importedNodes - @property def cacheDir(self): return self._cacheDir diff --git a/meshroom/ui/commands.py b/meshroom/ui/commands.py index 52e3151e24..7d8ccc1f52 100755 --- a/meshroom/ui/commands.py +++ b/meshroom/ui/commands.py @@ -6,9 +6,10 @@ from PySide6.QtCore import Property, Signal from meshroom.core.attribute import ListAttribute, Attribute -from meshroom.core.graph import GraphModification +from meshroom.core.graph import Graph, GraphModification from meshroom.core.node import Position from meshroom.core.nodeFactory import nodeFactory +from meshroom.core.typing import PathLike class UndoCommand(QUndoCommand): @@ -232,7 +233,8 @@ class ImportProjectCommand(GraphCommand): """ Handle the import of a project into a Graph. """ - def __init__(self, graph, filepath=None, position=None, yOffset=0, parent=None): + + def __init__(self, graph: Graph, filepath: PathLike, position=None, yOffset=0, parent=None): super(ImportProjectCommand, self).__init__(graph, parent) self.filepath = filepath self.importedNames = [] @@ -240,9 +242,8 @@ def __init__(self, graph, filepath=None, position=None, yOffset=0, parent=None): self.yOffset = yOffset def redoImpl(self): - status = self.graph.load(self.filepath, setupProjectFile=False, importProject=True) - importedNodes = self.graph.importedNodes - self.setText("Import Project ({} nodes)".format(importedNodes.count)) + importedNodes = self.graph.importGraphContentFromFile(self.filepath) + self.setText(f"Import Project ({len(importedNodes)} nodes)") lowestY = 0 for node in self.graph.nodes: diff --git a/tests/test_graphIO.py b/tests/test_graphIO.py new file mode 100644 index 0000000000..01d3a89eeb --- /dev/null +++ b/tests/test_graphIO.py @@ -0,0 +1,152 @@ +from meshroom.core import desc +from meshroom.core.graph import Graph + +from .utils import registeredNodeTypes + + +class SimpleNode(desc.Node): + inputs = [ + desc.File(name="input", label="Input", description="", value=""), + ] + outputs = [ + desc.File(name="output", label="Output", description="", value=""), + ] + + +def compareGraphsContent(graphA: Graph, graphB: Graph) -> bool: + """Returns whether the content (node and deges) of two graphs are considered identical. + + Similar nodes: nodes with the same name, type and compatibility status. + Similar edges: edges with the same source and destination attribute names. + """ + + def _buildNodesSet(graph: Graph): + return set([(node.name, node.nodeType, node.isCompatibilityNode) for node in graph.nodes]) + + def _buildEdgesSet(graph: Graph): + return set([(edge.src.fullName, edge.dst.fullName) for edge in graph.edges]) + + return _buildNodesSet(graphA) == _buildNodesSet(graphB) and _buildEdgesSet(graphA) == _buildEdgesSet( + graphB + ) + + +class TestImportGraphContent: + def test_importEmptyGraph(self): + graph = Graph("") + + otherGraph = Graph("") + nodes = otherGraph.importGraphContent(graph) + + assert len(nodes) == 0 + assert len(graph.nodes) == 0 + + def test_importGraphWithSingleNode(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + graph.addNewNode(SimpleNode.__name__) + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + + assert compareGraphsContent(graph, otherGraph) + + def test_importGraphWithSeveralNodes(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + graph.addNewNode(SimpleNode.__name__) + graph.addNewNode(SimpleNode.__name__) + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + + assert compareGraphsContent(graph, otherGraph) + + def test_importingGraphWithNodesAndEdges(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + assert compareGraphsContent(graph, otherGraph) + + def test_edgeRemappingOnImportingGraphSeveralTimes(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + otherGraph.importGraphContent(graph) + + def test_importGraphWithUnknownNodeTypesCreatesCompatibilityNodes(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + graph.addNewNode(SimpleNode.__name__) + + otherGraph = Graph("") + importedNode = otherGraph.importGraphContent(graph) + + assert len(importedNode) == 1 + assert importedNode[0].isCompatibilityNode + + def test_importGraphContentInPlace(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + + graph.importGraphContent(graph) + + assert len(graph.nodes) == 4 + + def test_importGraphContentFromFile(self, graphSavedOnDisk): + graph: Graph = graphSavedOnDisk + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + graph.save() + + otherGraph = Graph("") + nodes = otherGraph.importGraphContentFromFile(graph.filepath) + + assert len(nodes) == 2 + + assert compareGraphsContent(graph, otherGraph) + + def test_importGraphContentFromFileWithCompatibilityNodes(self, graphSavedOnDisk): + graph: Graph = graphSavedOnDisk + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + graph.save() + + otherGraph = Graph("") + nodes = otherGraph.importGraphContentFromFile(graph.filepath) + + assert len(nodes) == 2 + assert len(otherGraph.compatibilityNodes) == 2 + assert not compareGraphsContent(graph, otherGraph) + + diff --git a/tests/test_nodeAttributeChangedCallback.py b/tests/test_nodeAttributeChangedCallback.py index edd14bc8dc..faee0e00ba 100644 --- a/tests/test_nodeAttributeChangedCallback.py +++ b/tests/test_nodeAttributeChangedCallback.py @@ -431,3 +431,28 @@ def test_loadingGraphWithComputedDynamicOutputValueDoesNotTriggerDownstreamAttri assert nodeB.affectedInput.value == 0 +class TestAttributeCallbackBehaviorOnGraphImport: + @classmethod + def setup_class(cls): + registerNodeType(NodeWithAttributeChangedCallback) + + @classmethod + def teardown_class(cls): + unregisterNodeType(NodeWithAttributeChangedCallback) + + def test_importingGraphDoesNotTriggerAttributeChangedCallbacks(self): + graph = Graph("") + + nodeA = graph.addNewNode(NodeWithAttributeChangedCallback.__name__) + nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__) + + graph.addEdge(nodeA.affectedInput, nodeB.input) + + nodeA.input.value = 5 + nodeB.affectedInput.value = 2 + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + + assert otherGraph.node(nodeB.name).affectedInput.value == 2 + From 035625fd0794bb7457d5135d29d4ac38aaef5e12 Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:12:06 +0100 Subject: [PATCH 06/18] [core] CompatibilityNode: do not use link expressions as default values for unknown File attributes When creating a compatibility description for an unknown attribute, do not consider a link expression as the default value for a File attribute. This is breaking how the link expression solving system works, as it's resetting the attribute to its default value after applying the link. If that expression is kept as the default value, it can be re-evaluated several times incorrectly. Added a test case that was failing before that change to illustrate the issue. --- meshroom/core/node.py | 12 +++++++++++- tests/test_graphIO.py | 17 +++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/meshroom/core/node.py b/meshroom/core/node.py index a4c7f76ddf..6f42272481 100644 --- a/meshroom/core/node.py +++ b/meshroom/core/node.py @@ -1668,7 +1668,17 @@ def attributeDescFromValue(attrName, value, isOutput): elif isinstance(value, float): return desc.FloatParam(range=None, **params) elif isinstance(value, str): - if isOutput or os.path.isabs(value) or Attribute.isLinkExpression(value): + if isOutput or os.path.isabs(value): + return desc.File(**params) + elif Attribute.isLinkExpression(value): + # Do not consider link expression as a valid default desc value. + # When the link expression is applied and transformed to an actual link, + # the systems resets the value using `Attribute.resetToDefaultValue` to indicate + # that this link expression has been handled. + # If the link expression is stored as the default value, it will never be cleared, + # leading to unexpected behavior where the link expression on a CompatibilityNode + # could be evaluated several times and/or incorrectly. + params["value"] = "" return desc.File(**params) else: return desc.StringParam(**params) diff --git a/tests/test_graphIO.py b/tests/test_graphIO.py index 01d3a89eeb..b374275908 100644 --- a/tests/test_graphIO.py +++ b/tests/test_graphIO.py @@ -90,6 +90,23 @@ def test_edgeRemappingOnImportingGraphSeveralTimes(self): otherGraph.importGraphContent(graph) otherGraph.importGraphContent(graph) + def test_edgeRemappingOnImportingGraphWithUnkownNodeTypesSeveralTimes(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + otherGraph.importGraphContent(graph) + + assert len(otherGraph.nodes) == 4 + assert len(otherGraph.compatibilityNodes) == 4 + assert len(otherGraph.edges) == 2 + def test_importGraphWithUnknownNodeTypesCreatesCompatibilityNodes(self): graph = Graph("") From 7465823423c42440482c4e5cf9ad7f33192b0a7f Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:12:06 +0100 Subject: [PATCH 07/18] [core] Introducing new graphIO module Move Graph.IO internal class to its own module, and rename it to `GraphIO`. This avoid nested classes within the core Graph class, and starts decoupling the management of graph's IO from the logic of the graph itself. --- meshroom/core/graph.py | 73 +++++++--------------------------- meshroom/core/graphIO.py | 56 ++++++++++++++++++++++++++ meshroom/ui/graph.py | 3 +- tests/test_templatesVersion.py | 7 ++-- 4 files changed, 76 insertions(+), 63 deletions(-) create mode 100644 meshroom/core/graphIO.py diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index eed53c06eb..9c9ba8eebb 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -17,6 +17,7 @@ from meshroom.core import Version from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit +from meshroom.core.graphIO import GraphIO from meshroom.core.node import Status, Node, CompatibilityNode from meshroom.core.nodeFactory import nodeFactory from meshroom.core.typing import PathLike @@ -183,52 +184,6 @@ class Graph(BaseObject): """ _cacheDir = "" - class IO(object): - """ Centralize Graph file keys and IO version. """ - __version__ = "2.0" - - class Keys(object): - """ File Keys. """ - # Doesn't inherit enum to simplify usage (Graph.IO.Keys.XX, without .value) - Header = "header" - NodesVersions = "nodesVersions" - ReleaseVersion = "releaseVersion" - FileVersion = "fileVersion" - Graph = "graph" - - class Features(Enum): - """ File Features. """ - Graph = "graph" - Header = "header" - NodesVersions = "nodesVersions" - PrecomputedOutputs = "precomputedOutputs" - NodesPositions = "nodesPositions" - - @staticmethod - def getFeaturesForVersion(fileVersion): - """ Return the list of supported features based on a file version. - - Args: - fileVersion (str, Version): the file version - - Returns: - tuple of Graph.IO.Features: the list of supported features - """ - if isinstance(fileVersion, str): - fileVersion = Version(fileVersion) - - features = [Graph.IO.Features.Graph] - if fileVersion >= Version("1.0"): - features += [Graph.IO.Features.Header, - Graph.IO.Features.NodesVersions, - Graph.IO.Features.PrecomputedOutputs, - ] - - if fileVersion >= Version("1.1"): - features += [Graph.IO.Features.NodesPositions] - - return tuple(features) - def __init__(self, name, parent=None): super(Graph, self).__init__(parent) self.name = name @@ -264,7 +219,7 @@ def _clearGraphContent(self): @property def fileFeatures(self): """ Get loaded file supported features based on its version. """ - return Graph.IO.getFeaturesForVersion(self.header.get(Graph.IO.Keys.FileVersion, "0.0")) + return GraphIO.getFeaturesForVersion(self.header.get(GraphIO.Keys.FileVersion, "0.0")) @property def isLoading(self): @@ -315,8 +270,8 @@ def _deserialize(self, graphData: dict): graphData: The serialized Graph. """ self.clear() - self.header = graphData.get(Graph.IO.Keys.Header, {}) - fileVersion = Version(self.header.get(Graph.IO.Keys.FileVersion, "0.0")) + self.header = graphData.get(GraphIO.Keys.Header, {}) + fileVersion = Version(self.header.get(GraphIO.Keys.FileVersion, "0.0")) graphContent = self._normalizeGraphContent(graphData, fileVersion) isTemplate = self.header.get("template", False) @@ -349,7 +304,7 @@ def _deserialize(self, graphData: dict): logging.warning(e) def _normalizeGraphContent(self, graphData: dict, fileVersion: Version) -> dict: - graphContent = graphData.get(Graph.IO.Keys.Graph, graphData) + graphContent = graphData.get(GraphIO.Keys.Graph, graphData) if fileVersion < Version("2.0"): # For internal folders, all "{uid0}" keys should be replaced with "{uid}" @@ -382,7 +337,7 @@ def _deserializeNode(self, nodeData: dict, nodeName: str, fromGraph: "Graph"): return node def _getNodeTypeVersionFromHeader(self, nodeType: str, default: Optional[str] = None) -> Optional[str]: - nodeVersions = self.header.get(Graph.IO.Keys.NodesVersions, {}) + nodeVersions = self.header.get(GraphIO.Keys.NodesVersions, {}) return nodeVersions.get(nodeType, default) def _evaluateUidConflicts(self, data): @@ -1430,8 +1385,8 @@ def save(self, filepath=None, setupProjectFile=True, template=False): if not path: raise ValueError("filepath must be specified for unsaved files.") - self.header[Graph.IO.Keys.ReleaseVersion] = meshroom.__version__ - self.header[Graph.IO.Keys.FileVersion] = Graph.IO.__version__ + self.header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__ + self.header[GraphIO.Keys.FileVersion] = GraphIO.__version__ # Store versions of node types present in the graph (excluding CompatibilityNode instances) # and remove duplicates @@ -1444,19 +1399,19 @@ def save(self, filepath=None, setupProjectFile=True, template=False): # Sort them by name (to avoid random order changing from one save to another) nodesVersions = dict(sorted(nodesVersions.items())) # Add it the header - self.header[Graph.IO.Keys.NodesVersions] = nodesVersions + self.header[GraphIO.Keys.NodesVersions] = nodesVersions self.header["template"] = template data = {} if template: data = { - Graph.IO.Keys.Header: self.header, - Graph.IO.Keys.Graph: self.getNonDefaultInputAttributes() + GraphIO.Keys.Header: self.header, + GraphIO.Keys.Graph: self.getNonDefaultInputAttributes() } else: data = { - Graph.IO.Keys.Header: self.header, - Graph.IO.Keys.Graph: self.toDict() + GraphIO.Keys.Header: self.header, + GraphIO.Keys.Graph: self.toDict() } with open(path, 'w') as jsonFile: @@ -1710,7 +1665,7 @@ def setVerbose(self, v): edges = Property(BaseObject, edges.fget, constant=True) filepathChanged = Signal() filepath = Property(str, lambda self: self._filepath, notify=filepathChanged) - fileReleaseVersion = Property(str, lambda self: self.header.get(Graph.IO.Keys.ReleaseVersion, "0.0"), + fileReleaseVersion = Property(str, lambda self: self.header.get(GraphIO.Keys.ReleaseVersion, "0.0"), notify=filepathChanged) fileDateVersion = Property(float, fileDateVersion.fget, fileDateVersion.fset, notify=filepathChanged) cacheDirChanged = Signal() diff --git a/meshroom/core/graphIO.py b/meshroom/core/graphIO.py new file mode 100644 index 0000000000..b7f7ad5a12 --- /dev/null +++ b/meshroom/core/graphIO.py @@ -0,0 +1,56 @@ +from enum import Enum +from typing import Union + +from meshroom.core import Version + + +class GraphIO: + """Centralize Graph file keys and IO version.""" + + __version__ = "2.0" + + class Keys(object): + """File Keys.""" + + # Doesn't inherit enum to simplify usage (GraphIO.Keys.XX, without .value) + Header = "header" + NodesVersions = "nodesVersions" + ReleaseVersion = "releaseVersion" + FileVersion = "fileVersion" + Graph = "graph" + + class Features(Enum): + """File Features.""" + + Graph = "graph" + Header = "header" + NodesVersions = "nodesVersions" + PrecomputedOutputs = "precomputedOutputs" + NodesPositions = "nodesPositions" + + @staticmethod + def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features",...]: + """Return the list of supported features based on a file version. + + Args: + fileVersion (str, Version): the file version + + Returns: + tuple of GraphIO.Features: the list of supported features + """ + if isinstance(fileVersion, str): + fileVersion = Version(fileVersion) + + features = [GraphIO.Features.Graph] + if fileVersion >= Version("1.0"): + features += [ + GraphIO.Features.Header, + GraphIO.Features.NodesVersions, + GraphIO.Features.PrecomputedOutputs, + ] + + if fileVersion >= Version("1.1"): + features += [GraphIO.Features.NodesPositions] + + return tuple(features) + diff --git a/meshroom/ui/graph.py b/meshroom/ui/graph.py index 2b69981616..3727fc8ee4 100644 --- a/meshroom/ui/graph.py +++ b/meshroom/ui/graph.py @@ -25,6 +25,7 @@ from meshroom.common.qt import QObjectListModel from meshroom.core.attribute import Attribute, ListAttribute from meshroom.core.graph import Graph, Edge +from meshroom.core.graphIO import GraphIO from meshroom.core.taskManager import TaskManager @@ -396,7 +397,7 @@ def setGraph(self, g): self.updateChunks() # perform auto-layout if graph does not provide nodes positions - if Graph.IO.Features.NodesPositions not in self._graph.fileFeatures: + if GraphIO.Features.NodesPositions not in self._graph.fileFeatures: self._layout.reset() # clear undo-stack after layout self._undoStack.clear() diff --git a/tests/test_templatesVersion.py b/tests/test_templatesVersion.py index 402a228ac4..eb23628f72 100644 --- a/tests/test_templatesVersion.py +++ b/tests/test_templatesVersion.py @@ -4,6 +4,7 @@ from meshroom.core.graph import Graph from meshroom.core import pipelineTemplates, Version from meshroom.core.node import CompatibilityIssue, CompatibilityNode +from meshroom.core.graphIO import GraphIO import json import meshroom @@ -24,13 +25,13 @@ def test_templateVersions(): with open(path) as jsonFile: fileData = json.load(jsonFile) - graphData = fileData.get(Graph.IO.Keys.Graph, fileData) + graphData = fileData.get(GraphIO.Keys.Graph, fileData) assert isinstance(graphData, dict) - header = fileData.get(Graph.IO.Keys.Header, {}) + header = fileData.get(GraphIO.Keys.Header, {}) assert header.get("template", False) - nodesVersions = header.get(Graph.IO.Keys.NodesVersions, {}) + nodesVersions = header.get(GraphIO.Keys.NodesVersions, {}) for _, nodeData in graphData.items(): nodeType = nodeData["nodeType"] From f7ae76dc694a95550228eb977f53e067b53c08da Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:12:06 +0100 Subject: [PATCH 08/18] [graphIO] Introduce graph serializer classes Move the serialization logic to dedicated serializer classes. Implement both `GraphSerializer` and `TemplateGraphSerializer` to cover for the existing serialization use-cases. --- meshroom/core/graph.py | 88 +++++------------------------ meshroom/core/graphIO.py | 116 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 128 insertions(+), 76 deletions(-) diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index 9c9ba8eebb..0c09255118 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -17,7 +17,7 @@ from meshroom.core import Version from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit -from meshroom.core.graphIO import GraphIO +from meshroom.core.graphIO import GraphIO, GraphSerializer, TemplateGraphSerializer from meshroom.core.node import Status, Node, CompatibilityNode from meshroom.core.nodeFactory import nodeFactory from meshroom.core.typing import PathLike @@ -1380,39 +1380,24 @@ def toDict(self): def asString(self): return str(self.toDict()) + def serialize(self, asTemplate: bool = False) -> dict: + """Serialize this Graph instance. + + Args: + asTemplate: Whether to use the template serialization. + + Returns: + The serialized graph data. + """ + SerializerClass = TemplateGraphSerializer if asTemplate else GraphSerializer + return SerializerClass(self).serialize() + def save(self, filepath=None, setupProjectFile=True, template=False): path = filepath or self._filepath if not path: raise ValueError("filepath must be specified for unsaved files.") - self.header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__ - self.header[GraphIO.Keys.FileVersion] = GraphIO.__version__ - - # Store versions of node types present in the graph (excluding CompatibilityNode instances) - # and remove duplicates - usedNodeTypes = set([n.nodeDesc.__class__ for n in self._nodes if isinstance(n, Node)]) - # Convert to node types to "name: version" - nodesVersions = { - "{}".format(p.__name__): meshroom.core.nodeVersion(p, "0.0") - for p in usedNodeTypes - } - # Sort them by name (to avoid random order changing from one save to another) - nodesVersions = dict(sorted(nodesVersions.items())) - # Add it the header - self.header[GraphIO.Keys.NodesVersions] = nodesVersions - self.header["template"] = template - - data = {} - if template: - data = { - GraphIO.Keys.Header: self.header, - GraphIO.Keys.Graph: self.getNonDefaultInputAttributes() - } - else: - data = { - GraphIO.Keys.Header: self.header, - GraphIO.Keys.Graph: self.toDict() - } + data = self.serialize(template) with open(path, 'w') as jsonFile: json.dump(data, jsonFile, indent=4) @@ -1423,51 +1408,6 @@ def save(self, filepath=None, setupProjectFile=True, template=False): # update the file date version self._fileDateVersion = os.path.getmtime(path) - def getNonDefaultInputAttributes(self): - """ - Instead of getting all the inputs and internal attribute keys, only get the keys of - the attributes whose value is not the default one. - The output attributes, UIDs, parallelization parameters and internal folder are - not relevant for templates, so they are explicitly removed from the returned dictionary. - - Returns: - dict: self.toDict() with the output attributes, UIDs, parallelization parameters, internal folder - and input/internal attributes with default values removed - """ - graph = self.toDict() - for nodeName in graph.keys(): - node = self.node(nodeName) - - inputKeys = list(graph[nodeName]["inputs"].keys()) - - internalInputKeys = [] - internalInputs = graph[nodeName].get("internalInputs", None) - if internalInputs: - internalInputKeys = list(internalInputs.keys()) - - for attrName in inputKeys: - attribute = node.attribute(attrName) - # check that attribute is not a link for choice attributes - if attribute.isDefault and not attribute.isLink: - del graph[nodeName]["inputs"][attrName] - - for attrName in internalInputKeys: - attribute = node.internalAttribute(attrName) - # check that internal attribute is not a link for choice attributes - if attribute.isDefault and not attribute.isLink: - del graph[nodeName]["internalInputs"][attrName] - - # If all the internal attributes are set to their default values, remove the entry - if len(graph[nodeName]["internalInputs"]) == 0: - del graph[nodeName]["internalInputs"] - - del graph[nodeName]["outputs"] - del graph[nodeName]["uid"] - del graph[nodeName]["internalFolder"] - del graph[nodeName]["parallelization"] - - return graph - def _setFilepath(self, filepath): """ Set the internal filepath of this Graph. diff --git a/meshroom/core/graphIO.py b/meshroom/core/graphIO.py index b7f7ad5a12..bc65629212 100644 --- a/meshroom/core/graphIO.py +++ b/meshroom/core/graphIO.py @@ -1,7 +1,12 @@ from enum import Enum -from typing import Union +from typing import Any, TYPE_CHECKING, Union +import meshroom from meshroom.core import Version +from meshroom.core.node import Node + +if TYPE_CHECKING: + from meshroom.core.graph import Graph class GraphIO: @@ -29,7 +34,7 @@ class Features(Enum): NodesPositions = "nodesPositions" @staticmethod - def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features",...]: + def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features", ...]: """Return the list of supported features based on a file version. Args: @@ -54,3 +59,110 @@ def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Fe return tuple(features) + +class GraphSerializer: + """Standard Graph serializer.""" + + def __init__(self, graph: "Graph") -> None: + self._graph = graph + + def serialize(self) -> dict: + """ + Serialize the Graph. + """ + return { + GraphIO.Keys.Header: self.serializeHeader(), + GraphIO.Keys.Graph: self.serializeContent(), + } + + @property + def nodes(self) -> list[Node]: + return self._graph.nodes + + def serializeHeader(self) -> dict: + """Build and return the graph serialization header. + + The header contains metadata about the graph, such as the: + - version of the software used to create it. + - version of the file format. + - version of the nodes types used in the graph. + - template flag. + + Args: + nodes: (optional) The list of nodes to consider for node types versions - use all nodes if not specified. + template: Whether the graph is going to be serialized as a template. + """ + header: dict[str, Any] = {} + header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__ + header[GraphIO.Keys.FileVersion] = GraphIO.__version__ + header[GraphIO.Keys.NodesVersions] = self._getNodeTypesVersions() + return header + + def _getNodeTypesVersions(self) -> dict[str, str]: + """Get registered versions of each node types in `nodes`, excluding CompatibilityNode instances.""" + nodeTypes = set([node.nodeDesc.__class__ for node in self.nodes if isinstance(node, Node)]) + nodeTypesVersions = { + nodeType.__name__: meshroom.core.nodeVersion(nodeType, "0.0") for nodeType in nodeTypes + } + # Sort them by name (to avoid random order changing from one save to another). + return dict(sorted(nodeTypesVersions.items())) + + def serializeContent(self) -> dict: + """Graph content serialization logic.""" + return {node.name: self.serializeNode(node) for node in sorted(self.nodes, key=lambda n: n.name)} + + def serializeNode(self, node: Node) -> dict: + """Node serialization logic.""" + return node.toDict() + + +class TemplateGraphSerializer(GraphSerializer): + """Serializer for serializing a graph as a template.""" + + def serializeHeader(self) -> dict: + header = super().serializeHeader() + header["template"] = True + return header + + def serializeNode(self, node: Node) -> dict: + """Adapt node serialization to template graphs. + + Instead of getting all the inputs and internal attribute keys, only get the keys of + the attributes whose value is not the default one. + The output attributes, UIDs, parallelization parameters and internal folder are + not relevant for templates, so they are explicitly removed from the returned dictionary. + """ + # For now, implemented as a post-process to update the default serialization. + nodeData = super().serializeNode(node) + + inputKeys = list(nodeData["inputs"].keys()) + + internalInputKeys = [] + internalInputs = nodeData.get("internalInputs", None) + if internalInputs: + internalInputKeys = list(internalInputs.keys()) + + for attrName in inputKeys: + attribute = node.attribute(attrName) + # check that attribute is not a link for choice attributes + if attribute.isDefault and not attribute.isLink: + del nodeData["inputs"][attrName] + + for attrName in internalInputKeys: + attribute = node.internalAttribute(attrName) + # check that internal attribute is not a link for choice attributes + if attribute.isDefault and not attribute.isLink: + del nodeData["internalInputs"][attrName] + + # If all the internal attributes are set to their default values, remove the entry + if len(nodeData["internalInputs"]) == 0: + del nodeData["internalInputs"] + + del nodeData["outputs"] + del nodeData["uid"] + del nodeData["internalFolder"] + del nodeData["parallelization"] + + return nodeData + + From 4aa6eac9139f78a9249a9a61717f06896a974086 Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:12:06 +0100 Subject: [PATCH 09/18] [core][graphIO] Introduce PartialGraphSerializer Add a new serializer class to manage partial graph serialization logic, ensuring to remove link expressions on attributes refering to nodes that are not in the subset of nodes to serialize. --- meshroom/core/graph.py | 13 ++++++++- meshroom/core/graphIO.py | 60 ++++++++++++++++++++++++++++++++++++++++ tests/test_graphIO.py | 47 +++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+), 1 deletion(-) diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index 0c09255118..7056973c25 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -17,7 +17,7 @@ from meshroom.core import Version from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit -from meshroom.core.graphIO import GraphIO, GraphSerializer, TemplateGraphSerializer +from meshroom.core.graphIO import GraphIO, GraphSerializer, TemplateGraphSerializer, PartialGraphSerializer from meshroom.core.node import Status, Node, CompatibilityNode from meshroom.core.nodeFactory import nodeFactory from meshroom.core.typing import PathLike @@ -1392,6 +1392,17 @@ def serialize(self, asTemplate: bool = False) -> dict: SerializerClass = TemplateGraphSerializer if asTemplate else GraphSerializer return SerializerClass(self).serialize() + def serializePartial(self, nodes: list[Node]) -> dict: + """Partially serialize this graph considering only the given list of `nodes`. + + Args: + nodes: The list of nodes to serialize. + + Returns: + The serialized graph data. + """ + return PartialGraphSerializer(self, nodes=nodes).serialize() + def save(self, filepath=None, setupProjectFile=True, template=False): path = filepath or self._filepath if not path: diff --git a/meshroom/core/graphIO.py b/meshroom/core/graphIO.py index bc65629212..5f81db2648 100644 --- a/meshroom/core/graphIO.py +++ b/meshroom/core/graphIO.py @@ -3,6 +3,7 @@ import meshroom from meshroom.core import Version +from meshroom.core.attribute import Attribute, GroupAttribute, ListAttribute from meshroom.core.node import Node if TYPE_CHECKING: @@ -166,3 +167,62 @@ def serializeNode(self, node: Node) -> dict: return nodeData +class PartialGraphSerializer(GraphSerializer): + """Serializer to serialize a partial graph (a subset of nodes).""" + + def __init__(self, graph: "Graph", nodes: list[Node]): + super().__init__(graph) + self._nodes = nodes + + @property + def nodes(self) -> list[Node]: + """Override to consider only the subset of nodes.""" + return self._nodes + + def serializeNode(self, node: Node) -> dict: + """Adapt node serialization to partial graph serialization.""" + # NOTE: For now, implemented as a post-process to the default serialization. + nodeData = super().serializeNode(node) + + # Override input attributes with custom serialization logic, to handle attributes + # connected to nodes that are not in the list of nodes to serialize. + for attributeName in nodeData["inputs"]: + nodeData["inputs"][attributeName] = self._serializeAttribute(node.attribute(attributeName)) + + # Clear UID for non-compatibility nodes, as the custom attribute serialization + # can be impacting the UID by removing connections to missing nodes. + if not node.isCompatibilityNode: + del nodeData["uid"] + + return nodeData + + def _serializeAttribute(self, attribute: Attribute) -> Any: + """ + Serialize `attribute` (recursively for list/groups) and deal with attributes being connected + to nodes that are not part of the partial list of nodes to serialize. + """ + # If the attribute is connected to a node that is not in the list of nodes to serialize, + # the link expression should not be serialized. + if attribute.isLink and attribute.getLinkParam().node not in self.nodes: + # If part of a list, this entry can be discarded. + if isinstance(attribute.root, ListAttribute): + return None + # Otherwise, return the default value for this attribute. + return attribute.defaultValue() + + if isinstance(attribute, ListAttribute): + # Recusively serialize each child of the ListAttribute, skipping those for which the attribute + # serialization logic above returns None. + return [ + exportValue + for child in attribute + if (exportValue := self._serializeAttribute(child)) is not None + ] + + if isinstance(attribute, GroupAttribute): + # Recursively serialize each child of the group attribute. + return {name: self._serializeAttribute(child) for name, child in attribute.value.items()} + + return attribute.getExportValue() + + diff --git a/tests/test_graphIO.py b/tests/test_graphIO.py index b374275908..5e7da48eb0 100644 --- a/tests/test_graphIO.py +++ b/tests/test_graphIO.py @@ -167,3 +167,50 @@ def test_importGraphContentFromFileWithCompatibilityNodes(self, graphSavedOnDisk assert not compareGraphsContent(graph, otherGraph) +class TestGraphPartialSerialization: + def test_emptyGraph(self): + graph = Graph("") + serializedGraph = graph.serializePartial([]) + + otherGraph = Graph("") + otherGraph._deserialize(serializedGraph) + assert compareGraphsContent(graph, otherGraph) + + def test_serializeAllNodesIsSimilarToStandardSerialization(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA = graph.addNewNode(SimpleNode.__name__) + nodeB = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA.output, nodeB.input) + + partialSerializedGraph = graph.serializePartial([nodeA, nodeB]) + standardSerializedGraph = graph.serialize() + + graphA = Graph("") + graphA._deserialize(partialSerializedGraph) + + graphB = Graph("") + graphB._deserialize(standardSerializedGraph) + + assert compareGraphsContent(graph, graphA) + assert compareGraphsContent(graphA, graphB) + + def test_serializeSingleNodeWithInputConnection(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA = graph.addNewNode(SimpleNode.__name__) + nodeB = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA.output, nodeB.input) + + serializedGraph = graph.serializePartial([nodeB]) + + otherGraph = Graph("") + otherGraph._deserialize(serializedGraph) + + assert len(otherGraph.compatibilityNodes) == 0 + assert len(otherGraph.nodes) == 1 + assert len(otherGraph.edges) == 0 From 44dc09b6755a5cd60f1447b4e5dd719ddb7a8363 Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:12:06 +0100 Subject: [PATCH 10/18] [core] Graph: improve uid conflicts check on deserialization Only perform uid check when we have both a serialized and a computed UID. If the node has not been serialized with a UID, it means that it does not expect to match a specific value on deserialization. --- meshroom/core/graph.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index 7056973c25..94a10ba910 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -340,7 +340,7 @@ def _getNodeTypeVersionFromHeader(self, nodeType: str, default: Optional[str] = nodeVersions = self.header.get(GraphIO.Keys.NodesVersions, {}) return nodeVersions.get(nodeType, default) - def _evaluateUidConflicts(self, data): + def _evaluateUidConflicts(self, graphContent: dict): """ Compare the UIDs of all the nodes in the graph with the UID that is expected in the graph file. If there are mismatches, the nodes with the unexpected UID are replaced with "UidConflict" compatibility nodes. @@ -351,17 +351,17 @@ def _evaluateUidConflicts(self, data): Args: data (dict): the dictionary containing all the nodes to import and their data """ - for nodeName, nodeData in sorted(data.items(), key=lambda x: self.getNodeIndexFromName(x[0])): + for nodeName, nodeData in sorted(graphContent.items(), key=lambda x: self.getNodeIndexFromName(x[0])): node = self.node(nodeName) - savedUid = nodeData.get("uid", None) - graphUid = node._uid # Node's UID from the graph itself + serializedUid = nodeData.get("uid", None) + computedUid = node._uid # Node's UID from the graph itself - if savedUid != graphUid and graphUid is not None: + if serializedUid and computedUid and serializedUid != computedUid: # Different UIDs, remove the existing node from the graph and replace it with a CompatibilityNode logging.debug("UID conflict detected for {}".format(nodeName)) self.removeNode(nodeName) - n = nodeFactory(nodeData, nodeName, expectedUid=graphUid) + n = nodeFactory(nodeData, nodeName, expectedUid=computedUid) self._addNode(n, nodeName) else: # f connecting nodes have UID conflicts and are removed/re-added to the graph, some edges may be lost: From 957638948e7f35a44d88f37b51eb693bd3e3359b Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:12:06 +0100 Subject: [PATCH 11/18] [ui] Refactor node pasting using graph partial serialization Re-implement node pasting by relying on the graph partial serializer, to serialize only the subset of selected nodes. On pasting, use standard graph deserialization and import the content of the serialized graph in the active graph instance. Simplify the positioning of pasted nodes to only consider mouse position or center of the graph, which works well for the major variety of use-cases. Compute the offset to apply to imported nodes by using the de-serialized graph content's bounding box. --- meshroom/core/graph.py | 35 ------ meshroom/ui/commands.py | 38 +++++- meshroom/ui/graph.py | 129 ++++---------------- meshroom/ui/qml/GraphEditor/GraphEditor.qml | 21 ++-- 4 files changed, 64 insertions(+), 159 deletions(-) diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index 94a10ba910..6fe9b153a3 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -670,41 +670,6 @@ def duplicateNodes(self, srcNodes): return duplicates - def pasteNodes(self, data, position): - """ - Paste node(s) in the graph with their connections. The connections can only be between - the pasted nodes and not with the rest of the graph. - - Args: - data (dict): the dictionary containing the information about the nodes to paste, with their names and - links already updated to be added to the graph - position (list): the list of positions for each node to paste - - Returns: - list: the list of Node objects that were pasted and added to the graph - """ - nodes = [] - with GraphModification(self): - positionCnt = 0 # always valid because we know the data is sorted the same way as the position list - for key in sorted(data): - nodeType = data[key].get("nodeType", None) - if not nodeType: # this case should never occur, as the data should have been prefiltered first - pass - - attributes = {} - attributes.update(data[key].get("inputs", {})) - attributes.update(data[key].get("outputs", {})) - attributes.update(data[key].get("internalInputs", {})) - - node = Node(nodeType, position=position[positionCnt], **attributes) - self._addNode(node, key) - - nodes.append(node) - positionCnt += 1 - - self._applyExpr() - return nodes - def outEdges(self, attribute): """ Return the list of edges starting from the given attribute """ # type: (Attribute,) -> [Edge] diff --git a/meshroom/ui/commands.py b/meshroom/ui/commands.py index 7d8ccc1f52..d5d0abe036 100755 --- a/meshroom/ui/commands.py +++ b/meshroom/ui/commands.py @@ -211,15 +211,27 @@ class PasteNodesCommand(GraphCommand): """ Handle node pasting in a Graph. """ - def __init__(self, graph, data, position=None, parent=None): + def __init__(self, graph: "Graph", data: dict, position: Position, parent=None): super(PasteNodesCommand, self).__init__(graph, parent) self.data = data self.position = position - self.nodeNames = [] + self.nodeNames: list[str] = [] def redoImpl(self): - data = self.graph.updateImportedProject(self.data) - nodes = self.graph.pasteNodes(data, self.position) + graph = Graph("") + try: + graph._deserialize(self.data) + except: + return False + + boundingBoxCenter = self._boundingBoxCenter(graph.nodes) + offset = Position(self.position.x - boundingBoxCenter.x, self.position.y - boundingBoxCenter.y) + + for node in graph.nodes: + node.position = Position(node.position.x + offset.x, node.position.y + offset.y) + + nodes = self.graph.importGraphContent(graph) + self.nodeNames = [node.name for node in nodes] self.setText("Paste Node{} ({})".format("s" if len(self.nodeNames) > 1 else "", ", ".join(self.nodeNames))) return nodes @@ -228,6 +240,24 @@ def undoImpl(self): for name in self.nodeNames: self.graph.removeNode(name) + def _boundingBox(self, nodes) -> tuple[int, int, int, int]: + if not nodes: + return (0, 0, 0 , 0) + + minX = maxX = nodes[0].x + minY = maxY = nodes[0].y + + for node in nodes[1:]: + minX = min(minX, node.x) + minY = min(minY, node.y) + maxX = max(maxX, node.x) + maxY = max(maxY, node.y) + + return (minX, minY, maxX, maxY) + + def _boundingBoxCenter(self, nodes): + minX, minY, maxX, maxY = self._boundingBox(nodes) + return Position((minX + maxX) / 2, (minY + maxY) / 2) class ImportProjectCommand(GraphCommand): """ diff --git a/meshroom/ui/graph.py b/meshroom/ui/graph.py index 3727fc8ee4..c7dabc14bc 100644 --- a/meshroom/ui/graph.py +++ b/meshroom/ui/graph.py @@ -1050,126 +1050,43 @@ def getSelectedNodesContent(self) -> str: """ if not self._nodeSelection.hasSelection(): return "" - serializedSelection = {node.name: node.toDict() for node in self.iterSelectedNodes()} - return json.dumps(serializedSelection, indent=4) + graphData = self._graph.serializePartial(self.getSelectedNodes()) + return json.dumps(graphData, indent=4) - @Slot(str, QPoint, bool, result=list) - def pasteNodes(self, clipboardContent, position=None, centerPosition=False) -> list[Node]: + @Slot(str, QPoint, result=list) + def pasteNodes(self, serializedData: str, position: Optional[QPoint]=None) -> list[Node]: """ - Parse the content of the clipboard to see whether it contains - valid node descriptions. If that is the case, the nodes described - in the clipboard are built with the available information. - Otherwise, nothing is done. + Import string-serialized graph content `serializedData` in the current graph, optionally at the given + `position`. + If the `serializedData` does not contain valid serialized graph data, nothing is done. - This function does not need to be preceded by a call to "getSelectedNodesContent". - Any clipboard content that contains at least a node type with a valid JSON - formatting (dictionary form with double quotes around the keys and values) - can be used to generate a node. + This method can be used with the result of "getSelectedNodesContent". + But it also accepts any serialized content that matches the graph data or graph content format. For example, it is enough to have: {"nodeName_1": {"nodeType":"CameraInit"}, "nodeName_2": {"nodeType":"FeatureMatching"}} - in the clipboard to create a default CameraInit and a default FeatureMatching nodes. + in `serializedData` to create a default CameraInit and a default FeatureMatching nodes. Args: - clipboardContent (str): the string contained in the clipboard, that may or may not contain valid - node information - position (QPoint): the position of the mouse in the Graph Editor when the function was called - centerPosition (bool): whether the provided position is not the top-left corner of the pasting - zone, but its center + serializedData: The string-serialized graph data. + position: The position where to paste the nodes. If None, the nodes are pasted at (0, 0). Returns: list: the list of Node objects that were pasted and added to the graph """ - if not clipboardContent: - return - try: - d = json.loads(clipboardContent) - except ValueError as e: - raise ValueError(e) - - if not isinstance(d, dict): - raise ValueError("The clipboard does not contain a valid node. Cannot paste it.") - - # If the clipboard contains a header, then a whole file is contained in the clipboard - # Extract the "graph" part and paste it all, ignore the rest - if d.get("header", None): - d = d.get("graph", None) - if not d: - return - - if isinstance(position, QPoint): - position = Position(position.x(), position.y()) - if self.hoveredNode: - # If a node is hovered, add an offset to prevent complete occlusion - position = Position(position.x + self.layout.gridSpacing, position.y + self.layout.gridSpacing) - - # Get the position of the first node in a zone whose top-left corner is the mouse and the bottom-right - # corner the (x, y) coordinates, with x the maximum of all the nodes' position along the x-axis, and y the - # maximum of all the nodes' position along the y-axis. All nodes with a position will be placed relatively - # to the first node within that zone. - firstNodePos = None - minX = 0 - maxX = 0 - minY = 0 - maxY = 0 - for key in sorted(d): - nodeType = d[key].get("nodeType", None) - if not nodeType: - raise ValueError("Invalid node description: no provided node type for '{}'".format(key)) - - pos = d[key].get("position", None) - if pos: - if not firstNodePos: - firstNodePos = pos - minX = pos[0] - maxX = pos[0] - minY = pos[1] - maxY = pos[1] - else: - if minX > pos[0]: - minX = pos[0] - if maxX < pos[0]: - maxX = pos[0] - if minY > pos[1]: - minY = pos[1] - if maxY < pos[1]: - maxY = pos[1] - - # Ensure there will not be an error if no node has a specified position - if not firstNodePos: - firstNodePos = [0, 0] - - # Position of the first node within the zone - position = Position(position.x + firstNodePos[0] - minX, position.y + firstNodePos[1] - minY) - - if centerPosition: # Center the zone around the mouse's position (mouse's position might be artificial) - maxX = maxX + self.layout.nodeWidth # maxX and maxY are the position of the furthest node's top-left corner - maxY = maxY + self.layout.nodeHeight # We want the position of the furthest node's bottom-right corner - position = Position(position.x - ((maxX - minX) / 2), position.y - ((maxY - minY) / 2)) - - finalPosition = None - prevPosition = None - positions = [] - - for key in sorted(d): - currentPosition = d[key].get("position", None) - if not finalPosition: - finalPosition = position - else: - if prevPosition and currentPosition: - # If the nodes both have a position, recreate the distance between them with a different - # starting point - x = finalPosition.x + (currentPosition[0] - prevPosition[0]) - y = finalPosition.y + (currentPosition[1] - prevPosition[1]) - finalPosition = Position(x, y) - else: - # If either the current node or previous one lacks a position, use a custom one - finalPosition = Position(finalPosition.x + self.layout.gridSpacing + self.layout.nodeWidth, finalPosition.y) - prevPosition = currentPosition - positions.append(finalPosition) + graphData = json.loads(serializedData) + except json.JSONDecodeError: + logging.warning("Content is not a valid JSON string.") + return [] + + pos = Position(position.x(), position.y()) if position else Position(0, 0) + result = self.push(commands.PasteNodesCommand(self._graph, graphData, pos)) + if result is False: + logging.warning("Content is not a valid graph data.") + return [] + return result - return self.push(commands.PasteNodesCommand(self.graph, d, position=positions)) undoStack = Property(QObject, lambda self: self._undoStack, constant=True) graphChanged = Signal() diff --git a/meshroom/ui/qml/GraphEditor/GraphEditor.qml b/meshroom/ui/qml/GraphEditor/GraphEditor.qml index c74acbc7d1..1a7813ac5b 100755 --- a/meshroom/ui/qml/GraphEditor/GraphEditor.qml +++ b/meshroom/ui/qml/GraphEditor/GraphEditor.qml @@ -82,25 +82,18 @@ Item { /// Paste content of clipboard to graph editor and create new node if valid function pasteNodes() { - var finalPosition = undefined - var centerPosition = false + let finalPosition = undefined; if (mouseArea.containsMouse) { - if (uigraph.hoveredNode !== null) { - var node = nodeDelegate(uigraph.hoveredNode) - finalPosition = Qt.point(node.mousePosition.x + node.x, node.mousePosition.y + node.y) - } else { - finalPosition = mapToItem(draggable, mouseArea.mouseX, mouseArea.mouseY) - } + finalPosition = mapToItem(draggable, mouseArea.mouseX, mouseArea.mouseY); } else { - finalPosition = getCenterPosition() - centerPosition = true + finalPosition = getCenterPosition(); } - var copiedContent = Clipboard.getText() - var nodes = uigraph.pasteNodes(copiedContent, finalPosition, centerPosition) + const copiedContent = Clipboard.getText(); + const nodes = uigraph.pasteNodes(copiedContent, finalPosition); if (nodes.length > 0) { - uigraph.selectedNode = nodes[0] - uigraph.selectNodes(nodes) + uigraph.selectedNode = nodes[0]; + uigraph.selectNodes(nodes); } } From 85c0812f676488be3f6243583b125615b64146a8 Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:12:06 +0100 Subject: [PATCH 12/18] [core] Add `Graph.copy` method Add a new method to create a copy of a graph instance, relying on chaining serialization and deserialization operations. Add test suite to validate its behavior, and the underlying serialization processes. --- meshroom/core/graph.py | 6 ++++++ tests/test_graphIO.py | 27 +++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index 6fe9b153a3..fc6fc5f19a 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -1345,6 +1345,12 @@ def toDict(self): def asString(self): return str(self.toDict()) + def copy(self) -> "Graph": + """Create a copy of this Graph instance.""" + graph = Graph("") + graph._deserialize(self.serialize()) + return graph + def serialize(self, asTemplate: bool = False) -> dict: """Serialize this Graph instance. diff --git a/tests/test_graphIO.py b/tests/test_graphIO.py index 5e7da48eb0..d2a8330475 100644 --- a/tests/test_graphIO.py +++ b/tests/test_graphIO.py @@ -214,3 +214,30 @@ def test_serializeSingleNodeWithInputConnection(self): assert len(otherGraph.compatibilityNodes) == 0 assert len(otherGraph.nodes) == 1 assert len(otherGraph.edges) == 0 + + +class TestGraphCopy: + def test_graphCopyIsIdenticalToOriginalGraph(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA = graph.addNewNode(SimpleNode.__name__) + nodeB = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA.output, nodeB.input) + + graphCopy = graph.copy() + assert compareGraphsContent(graph, graphCopy) + + def test_graphCopyWithUnknownNodeTypesDiffersFromOriginalGraph(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA = graph.addNewNode(SimpleNode.__name__) + nodeB = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA.output, nodeB.input) + + graphCopy = graph.copy() + assert not compareGraphsContent(graph, graphCopy) + From c97d40f81820c73d88500b1c6b284956dc2be076 Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:12:06 +0100 Subject: [PATCH 13/18] [core] Graph: cleanup unused methods --- meshroom/core/graph.py | 133 ----------------------------------------- 1 file changed, 133 deletions(-) diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index fc6fc5f19a..bbc0a34191 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -371,139 +371,6 @@ def _evaluateUidConflicts(self, graphContent: dict): n = nodeFactory(nodeData, nodeName) self._addNode(n, nodeName) - def updateImportedProject(self, data): - """ - Update the names and links of the project to import so that it can fit - correctly in the existing graph. - - Parse all the nodes from the project that is going to be imported. - If their name already exists in the graph, replace them with new names, - then parse all the nodes' inputs/outputs to replace the old names with - the new ones in the links. - - Args: - data (dict): the dictionary containing all the nodes to import and their data - - Returns: - updatedData (dict): the dictionary containing all the nodes to import with their updated names and data - """ - nameCorrespondences = {} # maps the old node name to its updated one - updatedData = {} # input data with updated node names and links - - def createUniqueNodeName(nodeNames, inputName): - """ - Create a unique name that does not already exist in the current graph or in the list - of nodes that will be imported. - """ - i = 1 - while i: - newName = "{name}_{index}".format(name=inputName, index=i) - if newName not in nodeNames and newName not in updatedData.keys(): - return newName - i += 1 - - # First pass to get all the names that already exist in the graph, update them, and keep track of the changes - for nodeName, nodeData in sorted(data.items(), key=lambda x: self.getNodeIndexFromName(x[0])): - if not isinstance(nodeData, dict): - raise RuntimeError('updateImportedProject error: Node is not a dict.') - - if nodeName in self._nodes.keys() or nodeName in updatedData.keys(): - newName = createUniqueNodeName(self._nodes.keys(), nodeData["nodeType"]) - updatedData[newName] = nodeData - nameCorrespondences[nodeName] = newName - - else: - updatedData[nodeName] = nodeData - - newNames = [nodeName for nodeName in updatedData] # names of all the nodes that will be added - - # Second pass to update all the links in the input/output attributes for every node with the new names - for nodeName, nodeData in updatedData.items(): - nodeType = nodeData.get("nodeType", None) - nodeDesc = meshroom.core.nodesDesc[nodeType] - - inputs = nodeData.get("inputs", {}) - outputs = nodeData.get("outputs", {}) - - if inputs: - inputs = self.updateLinks(inputs, nameCorrespondences) - inputs = self.resetExternalLinks(inputs, nodeDesc.inputs, newNames) - updatedData[nodeName]["inputs"] = inputs - if outputs: - outputs = self.updateLinks(outputs, nameCorrespondences) - outputs = self.resetExternalLinks(outputs, nodeDesc.outputs, newNames) - updatedData[nodeName]["outputs"] = outputs - - return updatedData - - @staticmethod - def updateLinks(attributes, nameCorrespondences): - """ - Update all the links that refer to nodes that are going to be imported and whose - names have to be updated. - - Args: - attributes (dict): attributes whose links need to be updated - nameCorrespondences (dict): node names to replace in the links with the name to replace them with - - Returns: - attributes (dict): the attributes with all the updated links - """ - for key, val in attributes.items(): - for corr in nameCorrespondences.keys(): - if isinstance(val, str) and corr in val: - attributes[key] = val.replace(corr, nameCorrespondences[corr]) - elif isinstance(val, list): - for v in val: - if isinstance(v, str): - if corr in v: - val[val.index(v)] = v.replace(corr, nameCorrespondences[corr]) - else: # the list does not contain strings, so there cannot be links to update - break - attributes[key] = val - - return attributes - - @staticmethod - def resetExternalLinks(attributes, nodeDesc, newNames): - """ - Reset all links to nodes that are not part of the nodes which are going to be imported: - if there are links to nodes that are not in the list, then it means that the references - are made to external nodes, and we want to get rid of those. - - Args: - attributes (dict): attributes whose links might need to be reset - nodeDesc (list): list with all the attributes' description (including their default value) - newNames (list): names of the nodes that are going to be imported; no node name should be referenced - in the links except those contained in this list - - Returns: - attributes (dict): the attributes with all the links referencing nodes outside those which will be imported - reset to their default values - """ - for key, val in attributes.items(): - defaultValue = None - for desc in nodeDesc: - if desc.name == key: - defaultValue = desc.value - break - - if isinstance(val, str): - if Attribute.isLinkExpression(val) and not any(name in val for name in newNames): - if defaultValue is not None: # prevents from not entering condition if defaultValue = '' - attributes[key] = defaultValue - - elif isinstance(val, list): - removedCnt = len(val) # counter to know whether all the list entries will be deemed invalid - tmpVal = list(val) # deep copy to ensure we iterate over the entire list (even if elements are removed) - for v in tmpVal: - if isinstance(v, str) and Attribute.isLinkExpression(v) and not any(name in v for name in newNames): - val.remove(v) - removedCnt -= 1 - if removedCnt == 0 and defaultValue is not None: # if all links were wrong, reset the attribute - attributes[key] = defaultValue - - return attributes def importGraphContentFromFile(self, filepath: PathLike) -> list[Node]: """Import the content (nodes and edges) of another Graph file into this Graph instance. From 5c556f950f1930f17c339fd157c388f1f3dc23a8 Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:12:06 +0100 Subject: [PATCH 14/18] [core][graphIO] Add "template" as an explicit key --- meshroom/core/graph.py | 4 ++-- meshroom/core/graphIO.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index bbc0a34191..02840fc321 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -273,7 +273,7 @@ def _deserialize(self, graphData: dict): self.header = graphData.get(GraphIO.Keys.Header, {}) fileVersion = Version(self.header.get(GraphIO.Keys.FileVersion, "0.0")) graphContent = self._normalizeGraphContent(graphData, fileVersion) - isTemplate = self.header.get("template", False) + isTemplate = self.header.get(GraphIO.Keys.Template, False) with GraphModification(self): # iterate over nodes sorted by suffix index in their names @@ -331,7 +331,7 @@ def _deserializeNode(self, nodeData: dict, nodeName: str, fromGraph: "Graph"): # 3. fallback behavior: default to "0.0" if "version" not in nodeData: nodeData["version"] = fromGraph._getNodeTypeVersionFromHeader(nodeData["nodeType"], "0.0") - inTemplate = fromGraph.header.get("template", False) + inTemplate = fromGraph.header.get(GraphIO.Keys.Template, False) node = nodeFactory(nodeData, nodeName, inTemplate=inTemplate) self._addNode(node, nodeName) return node diff --git a/meshroom/core/graphIO.py b/meshroom/core/graphIO.py index 5f81db2648..196888036a 100644 --- a/meshroom/core/graphIO.py +++ b/meshroom/core/graphIO.py @@ -24,6 +24,7 @@ class Keys(object): ReleaseVersion = "releaseVersion" FileVersion = "fileVersion" Graph = "graph" + Template = "template" class Features(Enum): """File Features.""" @@ -122,7 +123,7 @@ class TemplateGraphSerializer(GraphSerializer): def serializeHeader(self) -> dict: header = super().serializeHeader() - header["template"] = True + header[GraphIO.Keys.Template] = True return header def serializeNode(self, node: Node) -> dict: From 7ca0900cbcfee501ea17c457b501874540f5c4a6 Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:12:06 +0100 Subject: [PATCH 15/18] [core] Graph: add `replaceNode` method Factorize the logic of replacing a node with another one and re-creating output edges into `Graph.replaceNode` and `Graph._restoreOutEdges`. --- meshroom/core/graph.py | 62 +++++++++++++++++++++++++------------ meshroom/ui/commands.py | 37 ++-------------------- tests/test_compatibility.py | 8 ++--- 3 files changed, 50 insertions(+), 57 deletions(-) diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index 02840fc321..a8f7c21b14 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -643,7 +643,7 @@ def _createUniqueNodeName(self, inputName: str, existingNames: Optional[set[str] def node(self, nodeName): return self._nodes.get(nodeName) - def upgradeNode(self, nodeName): + def upgradeNode(self, nodeName) -> Node: """ Upgrade the CompatibilityNode identified as 'nodeName' Args: @@ -663,25 +663,49 @@ def upgradeNode(self, nodeName): if not isinstance(node, CompatibilityNode): raise ValueError("Upgrade is only available on CompatibilityNode instances.") upgradedNode = node.upgrade() + self.replaceNode(nodeName, upgradedNode) + return upgradedNode + + @changeTopology + def replaceNode(self, nodeName: str, newNode: BaseNode): + """Replace the node idenfitied by `nodeName` with `newNode`, while restoring compatible edges. + + Args: + nodeName: The name of the Node to replace. + newNode: The Node instance to replace it with. + """ with GraphModification(self): - inEdges, outEdges, outListAttributes = self.removeNode(nodeName) - self.addNode(upgradedNode, nodeName) - for dst, src in outEdges.items(): - # Re-create the entries in ListAttributes that were completely removed during the call to "removeNode" - # If they are not re-created first, adding their edges will lead to errors - # 0 = attribute name, 1 = attribute index, 2 = attribute value - if dst in outListAttributes.keys(): - listAttr = self.attribute(outListAttributes[dst][0]) - if isinstance(outListAttributes[dst][2], list): - listAttr[outListAttributes[dst][1]:outListAttributes[dst][1]] = outListAttributes[dst][2] - else: - listAttr.insert(outListAttributes[dst][1], outListAttributes[dst][2]) - try: - self.addEdge(self.attribute(src), self.attribute(dst)) - except (KeyError, ValueError) as e: - logging.warning("Failed to restore edge {} -> {}: {}".format(src, dst, str(e))) - - return upgradedNode, inEdges, outEdges, outListAttributes + _, outEdges, outListAttributes = self.removeNode(nodeName) + self.addNode(newNode, nodeName) + self._restoreOutEdges(outEdges, outListAttributes) + + def _restoreOutEdges(self, outEdges: dict[str, str], outListAttributes): + """Restore output edges that were removed during a call to "removeNode". + + Args: + outEdges: a dictionary containing the outgoing edges removed by a call to "removeNode". + {dstAttr.getFullNameToNode(), srcAttr.getFullNameToNode()} + outListAttributes: a dictionary containing the values, indices and keys of attributes that were connected + to a ListAttribute prior to the removal of all edges. + {dstAttr.getFullNameToNode(), (dstAttr.root.getFullNameToNode(), dstAttr.index, dstAttr.value)} + """ + def _recreateTargetListAttributeChildren(listAttrName: str, index: int, value: Any): + listAttr = self.attribute(listAttrName) + if not isinstance(listAttr, ListAttribute): + return + if isinstance(value, list): + listAttr[index:index] = value + else: + listAttr.insert(index, value) + + for dstName, srcName in outEdges.items(): + # Re-create the entries in ListAttributes that were completely removed during the call to "removeNode" + if dstName in outListAttributes: + _recreateTargetListAttributeChildren(*outListAttributes[dstName]) + try: + self.addEdge(self.attribute(srcName), self.attribute(dstName)) + except (KeyError, ValueError) as e: + logging.warning(f"Failed to restore edge {srcName} -> {dstName}: {str(e)}") def upgradeAllNodes(self): """ Upgrade all upgradable CompatibilityNode instances in the graph. """ diff --git a/meshroom/ui/commands.py b/meshroom/ui/commands.py index d5d0abe036..cfb9f58d2a 100755 --- a/meshroom/ui/commands.py +++ b/meshroom/ui/commands.py @@ -170,19 +170,7 @@ def undoImpl(self): node = nodeFactory(self.nodeDict, self.nodeName) self.graph.addNode(node, self.nodeName) assert (node.getName() == self.nodeName) - # recreate out edges deleted on node removal - for dstAttr, srcAttr in self.outEdges.items(): - # if edges were connected to ListAttributes, recreate their corresponding entry in said ListAttribute - # 0 = attribute name, 1 = attribute index, 2 = attribute value - if dstAttr in self.outListAttributes.keys(): - listAttr = self.graph.attribute(self.outListAttributes[dstAttr][0]) - if isinstance(self.outListAttributes[dstAttr][2], list): - listAttr[self.outListAttributes[dstAttr][1]:self.outListAttributes[dstAttr][1]] = self.outListAttributes[dstAttr][2] - else: - listAttr.insert(self.outListAttributes[dstAttr][1], self.outListAttributes[dstAttr][2]) - - self.graph.addEdge(self.graph.attribute(srcAttr), - self.graph.attribute(dstAttr)) + self.graph._restoreOutEdges(self.outEdges, self.outListAttributes) class DuplicateNodesCommand(GraphCommand): @@ -451,38 +439,19 @@ def __init__(self, graph, node, parent=None): super(UpgradeNodeCommand, self).__init__(graph, parent) self.nodeDict = node.toDict() self.nodeName = node.getName() - self.outEdges = {} - self.outListAttributes = {} self.setText("Upgrade Node {}".format(self.nodeName)) def redoImpl(self): if not self.graph.node(self.nodeName).canUpgrade: return False - upgradedNode, _, self.outEdges, self.outListAttributes = self.graph.upgradeNode(self.nodeName) - return upgradedNode + return self.graph.upgradeNode(self.nodeName) def undoImpl(self): - # delete upgraded node expectedUid = self.graph.node(self.nodeName)._uid - self.graph.removeNode(self.nodeName) # recreate compatibility node with GraphModification(self.graph): - # We come back from an upgrade, so we enforce uidConflict=True as there was a uid conflict before node = nodeFactory(self.nodeDict, name=self.nodeName, expectedUid=expectedUid) - self.graph.addNode(node, self.nodeName) - # recreate out edges - for dstAttr, srcAttr in self.outEdges.items(): - # if edges were connected to ListAttributes, recreate their corresponding entry in said ListAttribute - # 0 = attribute name, 1 = attribute index, 2 = attribute value - if dstAttr in self.outListAttributes.keys(): - listAttr = self.graph.attribute(self.outListAttributes[dstAttr][0]) - if isinstance(self.outListAttributes[dstAttr][2], list): - listAttr[self.outListAttributes[dstAttr][1]:self.outListAttributes[dstAttr][1]] = self.outListAttributes[dstAttr][2] - else: - listAttr.insert(self.outListAttributes[dstAttr][1], self.outListAttributes[dstAttr][2]) - - self.graph.addEdge(self.graph.attribute(srcAttr), - self.graph.attribute(dstAttr)) + self.graph.replaceNode(self.nodeName, node) class EnableGraphUpdateCommand(GraphCommand): diff --git a/tests/test_compatibility.py b/tests/test_compatibility.py index ac7c3002d8..5a5f195918 100644 --- a/tests/test_compatibility.py +++ b/tests/test_compatibility.py @@ -255,7 +255,7 @@ def test_description_conflict(): assert not hasattr(compatNode, "in") # perform upgrade - upgradedNode = g.upgradeNode(nodeName)[0] + upgradedNode = g.upgradeNode(nodeName) assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV2) assert list(upgradedNode.attributes.keys()) == ["in", "paramA", "output"] @@ -270,7 +270,7 @@ def test_description_conflict(): assert hasattr(compatNode, "paramA") # perform upgrade - upgradedNode = g.upgradeNode(nodeName)[0] + upgradedNode = g.upgradeNode(nodeName) assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV3) assert not hasattr(upgradedNode, "paramA") @@ -283,7 +283,7 @@ def test_description_conflict(): assert not hasattr(compatNode, "paramA") # perform upgrade - upgradedNode = g.upgradeNode(nodeName)[0] + upgradedNode = g.upgradeNode(nodeName) assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV4) assert hasattr(upgradedNode, "paramA") @@ -303,7 +303,7 @@ def test_description_conflict(): assert isinstance(elt, next(a for a in SampleGroupV1 if a.name == elt.name).__class__) # perform upgrade - upgradedNode = g.upgradeNode(nodeName)[0] + upgradedNode = g.upgradeNode(nodeName) assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV5) assert hasattr(upgradedNode, "paramA") From 42550b61ac824791b36b102f1f49dec65ab06389 Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:12:06 +0100 Subject: [PATCH 16/18] [core] Graph: improved uid conflicts evaluation on deserialization At the end of the deserialization process, solve node uid conflicts iteratively by node depths, and only replace the conflicting nodes with a CompatibilityNode. Add new test suite for testing uid conflict handling. --- meshroom/core/graph.py | 91 +++++++++++---------- tests/test_compatibility.py | 152 ++++++++++++++++++++++++++++++++++++ 2 files changed, 202 insertions(+), 41 deletions(-) diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index a8f7c21b14..f2f7d7dfb6 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -4,7 +4,7 @@ import logging import os import re -from typing import Optional +from typing import Any, Optional import weakref from collections import defaultdict, OrderedDict from contextlib import contextmanager @@ -18,7 +18,7 @@ from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit from meshroom.core.graphIO import GraphIO, GraphSerializer, TemplateGraphSerializer, PartialGraphSerializer -from meshroom.core.node import Status, Node, CompatibilityNode +from meshroom.core.node import BaseNode, Status, Node, CompatibilityNode from meshroom.core.nodeFactory import nodeFactory from meshroom.core.typing import PathLike @@ -285,23 +285,18 @@ def _deserialize(self, graphData: dict): # Create graph edges by resolving attributes expressions self._applyExpr() - # Templates are specific: they contain only the minimal amount of - # serialized data to describe the graph structure. - # They are not meant to be computed: therefore, we can early return here, - # as uid conflict evaluation is only meaningful for nodes with computed data. - if isTemplate: - return + # Templates are specific: they contain only the minimal amount of + # serialized data to describe the graph structure. + # They are not meant to be computed: therefore, we can early return here, + # as uid conflict evaluation is only meaningful for nodes with computed data. + if isTemplate: + return - # By this point, the graph has been fully loaded and an updateInternals has been triggered, so all the - # nodes' links have been resolved and their UID computations are all complete. - # It is now possible to check whether the UIDs stored in the graph file for each node correspond to the ones - # that were computed. - self.updateInternals() - self._evaluateUidConflicts(graphContent) - try: - self._applyExpr() - except Exception as e: - logging.warning(e) + # By this point, the graph has been fully loaded and an updateInternals has been triggered, so all the + # nodes' links have been resolved and their UID computations are all complete. + # It is now possible to check whether the UIDs stored in the graph file for each node correspond to the ones + # that were computed. + self._evaluateUidConflicts(graphContent) def _normalizeGraphContent(self, graphData: dict, fileVersion: Version) -> dict: graphContent = graphData.get(GraphIO.Keys.Graph, graphData) @@ -342,34 +337,48 @@ def _getNodeTypeVersionFromHeader(self, nodeType: str, default: Optional[str] = def _evaluateUidConflicts(self, graphContent: dict): """ - Compare the UIDs of all the nodes in the graph with the UID that is expected in the graph file. If there + Compare the computed UIDs of all the nodes in the graph with the UIDs serialized in `graphContent`. If there are mismatches, the nodes with the unexpected UID are replaced with "UidConflict" compatibility nodes. - Already existing nodes are removed and re-added to the graph identically to preserve all the edges, - which may otherwise be invalidated when a node with output edges but a UID conflict is re-generated as a - compatibility node. - + Args: - data (dict): the dictionary containing all the nodes to import and their data + graphContent: The serialized Graph content. """ - for nodeName, nodeData in sorted(graphContent.items(), key=lambda x: self.getNodeIndexFromName(x[0])): - node = self.node(nodeName) + def _serializedNodeUidMatchesComputedUid(nodeData: dict, node: BaseNode) -> bool: + """Returns whether the serialized UID matches the one computed in the `node` instance.""" + if isinstance(node, CompatibilityNode): + return True serializedUid = nodeData.get("uid", None) - computedUid = node._uid # Node's UID from the graph itself - - if serializedUid and computedUid and serializedUid != computedUid: - # Different UIDs, remove the existing node from the graph and replace it with a CompatibilityNode - logging.debug("UID conflict detected for {}".format(nodeName)) - self.removeNode(nodeName) - n = nodeFactory(nodeData, nodeName, expectedUid=computedUid) - self._addNode(n, nodeName) - else: - # f connecting nodes have UID conflicts and are removed/re-added to the graph, some edges may be lost: - # the links will be erroneously updated, and any further resolution will fail. - # Recreating the entire graph as it was ensures that all edges will be correctly preserved. - self.removeNode(nodeName) - n = nodeFactory(nodeData, nodeName) - self._addNode(n, nodeName) + computedUid = node._uid + return serializedUid is None or computedUid is None or serializedUid == computedUid + + uidConflictingNodes = [ + node + for node in self.nodes + if not _serializedNodeUidMatchesComputedUid(graphContent[node.name], node) + ] + + if not uidConflictingNodes: + return + + logging.warning("UID Compatibility issues found: recreating conflicting nodes as CompatibilityNodes.") + + # A uid conflict is contagious: if a node has a uid conflict, all of its downstream nodes may be + # impacted as well, as the uid flows through connections. + # Therefore, we deal with conflicting uid nodes by depth: replacing a node with a CompatibilityNode restores + # the serialized uid, which might solve "false-positives" downstream conflicts as well. + nodesSortedByDepth = sorted(uidConflictingNodes, key=lambda node: node.minDepth) + for node in nodesSortedByDepth: + nodeData = graphContent[node.name] + # Evaluate if the node uid is still conflicting at this point, or if it has been resolved by an + # upstream node replacement. + if _serializedNodeUidMatchesComputedUid(nodeData, node): + continue + expectedUid = node._uid + compatibilityNode = nodeFactory(graphContent[node.name], node.name, expectedUid=expectedUid) + # This operation will trigger a graph update that will recompute the uids of all nodes, + # allowing the iterative resolution of uid conflicts. + self.replaceNode(node.name, compatibilityNode) def importGraphContentFromFile(self, filepath: PathLike) -> list[Node]: diff --git a/tests/test_compatibility.py b/tests/test_compatibility.py index 5a5f195918..07ba2526ce 100644 --- a/tests/test_compatibility.py +++ b/tests/test_compatibility.py @@ -472,3 +472,155 @@ def test_loadsIfValueSetOnIncompatibleAttribute(self, graphSavedOnDisk): replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2) loadGraph(graph.filepath, strictCompatibility=True) + + +class UidTestingNodeV1(desc.Node): + inputs = [ + desc.File(name="input", label="Input", description="", value="", invalidate=True), + ] + outputs = [desc.File(name="output", label="Output", description="", value=desc.Node.internalFolder)] + + +class UidTestingNodeV2(desc.Node): + """ + Changes from SampleNodeBV1: + * 'param' has been added + """ + + inputs = [ + desc.File(name="input", label="Input", description="", value="", invalidate=True), + desc.ListAttribute( + name="param", + label="Param", + elementDesc=desc.File( + name="file", + label="File", + description="", + value="", + ), + description="", + ), + ] + outputs = [desc.File(name="output", label="Output", description="", value=desc.Node.internalFolder)] + + +class UidTestingNodeV3(desc.Node): + """ + Changes from SampleNodeBV2: + * 'input' is not invalidating the UID. + """ + + inputs = [ + desc.File(name="input", label="Input", description="", value="", invalidate=False), + desc.ListAttribute( + name="param", + label="Param", + elementDesc=desc.File( + name="file", + label="File", + description="", + value="", + ), + description="", + ), + ] + outputs = [desc.File(name="output", label="Output", description="", value=desc.Node.internalFolder)] + + +class TestUidConflict: + def test_changingInvalidateOnAttributeDescCreatesUidConflict(self, graphSavedOnDisk): + with registeredNodeTypes([UidTestingNodeV2]): + graph: Graph = graphSavedOnDisk + node = graph.addNewNode(UidTestingNodeV2.__name__) + + graph.save() + replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3) + + with pytest.raises(GraphCompatibilityError): + loadGraph(graph.filepath, strictCompatibility=True) + + loadedGraph = loadGraph(graph.filepath) + loadedNode = loadedGraph.node(node.name) + assert isinstance(loadedNode, CompatibilityNode) + assert loadedNode.issue == CompatibilityIssue.UidConflict + + def test_uidConflictingNodesPreserveConnectionsOnGraphLoad(self, graphSavedOnDisk): + with registeredNodeTypes([UidTestingNodeV2]): + graph: Graph = graphSavedOnDisk + nodeA = graph.addNewNode(UidTestingNodeV2.__name__) + nodeB = graph.addNewNode(UidTestingNodeV2.__name__) + + nodeB.param.append("") + graph.addEdge(nodeA.output, nodeB.param.at(0)) + + graph.save() + replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3) + + loadedGraph = loadGraph(graph.filepath) + assert len(loadedGraph.compatibilityNodes) == 2 + + loadedNodeA = loadedGraph.node(nodeA.name) + loadedNodeB = loadedGraph.node(nodeB.name) + + assert loadedNodeB.param.at(0).linkParam == loadedNodeA.output + + def test_upgradingConflictingNodesPreserveConnections(self, graphSavedOnDisk): + with registeredNodeTypes([UidTestingNodeV2]): + graph: Graph = graphSavedOnDisk + nodeA = graph.addNewNode(UidTestingNodeV2.__name__) + nodeB = graph.addNewNode(UidTestingNodeV2.__name__) + + # Double-connect nodeA.output to nodeB, on both a single attribute and a list attribute + nodeB.param.append("") + graph.addEdge(nodeA.output, nodeB.param.at(0)) + graph.addEdge(nodeA.output, nodeB.input) + + graph.save() + replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3) + + def checkNodeAConnectionsToNodeB(): + loadedNodeA = loadedGraph.node(nodeA.name) + loadedNodeB = loadedGraph.node(nodeB.name) + return ( + loadedNodeB.param.at(0).linkParam == loadedNodeA.output + and loadedNodeB.input.linkParam == loadedNodeA.output + ) + + loadedGraph = loadGraph(graph.filepath) + loadedGraph.upgradeNode(nodeA.name) + + assert checkNodeAConnectionsToNodeB() + loadedGraph.upgradeNode(nodeB.name) + + assert checkNodeAConnectionsToNodeB() + assert len(loadedGraph.compatibilityNodes) == 0 + + + def test_uidConflictDoesNotPropagateToValidDownstreamNodeThroughConnection(self, graphSavedOnDisk): + with registeredNodeTypes([UidTestingNodeV1, UidTestingNodeV2]): + graph: Graph = graphSavedOnDisk + nodeA = graph.addNewNode(UidTestingNodeV2.__name__) + nodeB = graph.addNewNode(UidTestingNodeV1.__name__) + + graph.addEdge(nodeA.output, nodeB.input) + + graph.save() + replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3) + + loadedGraph = loadGraph(graph.filepath) + assert len(loadedGraph.compatibilityNodes) == 1 + + def test_uidConflictDoesNotPropagateToValidDownstreamNodeThroughListConnection(self, graphSavedOnDisk): + with registeredNodeTypes([UidTestingNodeV2, UidTestingNodeV3]): + graph: Graph = graphSavedOnDisk + nodeA = graph.addNewNode(UidTestingNodeV2.__name__) + nodeB = graph.addNewNode(UidTestingNodeV3.__name__) + + nodeB.param.append("") + graph.addEdge(nodeA.output, nodeB.param.at(0)) + + graph.save() + replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3) + + loadedGraph = loadGraph(graph.filepath) + assert len(loadedGraph.compatibilityNodes) == 1 From 357575d640e414a29fcef74a3c02f186932e16c3 Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:12:06 +0100 Subject: [PATCH 17/18] [commands] UpgradeNode.undo: only set expected uid when "downgrading" UidConflict Only set the expectedUid when undoing the upgrade of a uid conflicting node. Otherwise, let the other type of conflicts take precedence. --- meshroom/ui/commands.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/meshroom/ui/commands.py b/meshroom/ui/commands.py index cfb9f58d2a..d1e8d8bf3e 100755 --- a/meshroom/ui/commands.py +++ b/meshroom/ui/commands.py @@ -7,7 +7,7 @@ from meshroom.core.attribute import ListAttribute, Attribute from meshroom.core.graph import Graph, GraphModification -from meshroom.core.node import Position +from meshroom.core.node import Position, CompatibilityIssue from meshroom.core.nodeFactory import nodeFactory from meshroom.core.typing import PathLike @@ -439,15 +439,20 @@ def __init__(self, graph, node, parent=None): super(UpgradeNodeCommand, self).__init__(graph, parent) self.nodeDict = node.toDict() self.nodeName = node.getName() + self.compatibilityIssue = None self.setText("Upgrade Node {}".format(self.nodeName)) def redoImpl(self): - if not self.graph.node(self.nodeName).canUpgrade: + if not (node := self.graph.node(self.nodeName)).canUpgrade: return False + self.compatibilityIssue = node.issue return self.graph.upgradeNode(self.nodeName) def undoImpl(self): - expectedUid = self.graph.node(self.nodeName)._uid + expectedUid = None + if self.compatibilityIssue == CompatibilityIssue.UidConflict: + expectedUid = self.graph.node(self.nodeName)._uid + # recreate compatibility node with GraphModification(self.graph): node = nodeFactory(self.nodeDict, name=self.nodeName, expectedUid=expectedUid) From dbafe843b9a139a1ca878f2e5c361f3c792b2bc6 Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Fri, 13 Dec 2024 20:12:06 +0100 Subject: [PATCH 18/18] [test] Extra partial serialization tests --- tests/test_graphIO.py | 66 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 61 insertions(+), 5 deletions(-) diff --git a/tests/test_graphIO.py b/tests/test_graphIO.py index d2a8330475..65835a5a69 100644 --- a/tests/test_graphIO.py +++ b/tests/test_graphIO.py @@ -13,6 +13,32 @@ class SimpleNode(desc.Node): ] +class NodeWithListAttributes(desc.Node): + inputs = [ + desc.ListAttribute( + name="listInput", + label="List Input", + description="", + elementDesc=desc.File(name="file", label="File", description="", value=""), + exposed=True, + ), + desc.GroupAttribute( + name="group", + label="Group", + description="", + groupDesc=[ + desc.ListAttribute( + name="listInput", + label="List Input", + description="", + elementDesc=desc.File(name="file", label="File", description="", value=""), + exposed=True, + ), + ], + ), + ] + + def compareGraphsContent(graphA: Graph, graphB: Graph) -> bool: """Returns whether the content (node and deges) of two graphs are considered identical. @@ -26,9 +52,10 @@ def _buildNodesSet(graph: Graph): def _buildEdgesSet(graph: Graph): return set([(edge.src.fullName, edge.dst.fullName) for edge in graph.edges]) - return _buildNodesSet(graphA) == _buildNodesSet(graphB) and _buildEdgesSet(graphA) == _buildEdgesSet( - graphB - ) + nodesSetA, edgesSetA = _buildNodesSet(graphA), _buildEdgesSet(graphA) + nodesSetB, edgesSetB = _buildNodesSet(graphB), _buildEdgesSet(graphB) + + return nodesSetA == nodesSetB and edgesSetA == edgesSetB class TestImportGraphContent: @@ -197,7 +224,7 @@ def test_serializeAllNodesIsSimilarToStandardSerialization(self): assert compareGraphsContent(graph, graphA) assert compareGraphsContent(graphA, graphB) - def test_serializeSingleNodeWithInputConnection(self): + def test_singleNodeWithInputConnectionFromNonSerializedNodeRemovesEdge(self): graph = Graph("") with registeredNodeTypes([SimpleNode]): @@ -215,6 +242,36 @@ def test_serializeSingleNodeWithInputConnection(self): assert len(otherGraph.nodes) == 1 assert len(otherGraph.edges) == 0 + def test_serializeSingleNodeWithInputConnectionToListAttributeRemovesListEntry(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode, NodeWithListAttributes]): + nodeA = graph.addNewNode(SimpleNode.__name__) + nodeB = graph.addNewNode(NodeWithListAttributes.__name__) + + nodeB.listInput.append("") + graph.addEdge(nodeA.output, nodeB.listInput.at(0)) + + otherGraph = Graph("") + otherGraph._deserialize(graph.serializePartial([nodeB])) + + assert len(otherGraph.node(nodeB.name).listInput) == 0 + + def test_serializeSingleNodeWithInputConnectionToNestedListAttributeRemovesListEntry(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode, NodeWithListAttributes]): + nodeA = graph.addNewNode(SimpleNode.__name__) + nodeB = graph.addNewNode(NodeWithListAttributes.__name__) + + nodeB.group.listInput.append("") + graph.addEdge(nodeA.output, nodeB.group.listInput.at(0)) + + otherGraph = Graph("") + otherGraph._deserialize(graph.serializePartial([nodeB])) + + assert len(otherGraph.node(nodeB.name).group.listInput) == 0 + class TestGraphCopy: def test_graphCopyIsIdenticalToOriginalGraph(self): @@ -240,4 +297,3 @@ def test_graphCopyWithUnknownNodeTypesDiffersFromOriginalGraph(self): graphCopy = graph.copy() assert not compareGraphsContent(graph, graphCopy) -