diff --git a/meshroom/core/desc.py b/meshroom/core/desc.py index 07466d15f7..7618fa671f 100755 --- a/meshroom/core/desc.py +++ b/meshroom/core/desc.py @@ -38,8 +38,13 @@ def validateValue(self, value): """ return value - def matchDescription(self, value): - """ Returns whether the value perfectly match attribute's description. """ + def matchDescription(self, value, conform=False): + """ Returns whether the value perfectly match attribute's description. + + Args: + value: the value + conform: try to adapt value to match the description + """ try: self.validateValue(value) except ValueError: @@ -66,13 +71,13 @@ def validateValue(self, value): raise ValueError('ListAttribute only supports list/tuple input values (param:{}, value:{}, type:{})'.format(self.name, value, type(value))) return value - def matchDescription(self, value): + def matchDescription(self, value, conform=False): """ Check that 'value' content matches ListAttribute's element description. """ - if not super(ListAttribute, self).matchDescription(value): + if not super(ListAttribute, self).matchDescription(value, conform): return False # list must be homogeneous: only test first element if value: - return self._elementDesc.matchDescription(value[0]) + return self._elementDesc.matchDescription(value[0], conform) return True @@ -97,20 +102,32 @@ def validateValue(self, value): raise ValueError('Value contains key that does not match group description : {}'.format(invalidKeys)) return value - def matchDescription(self, value): + def matchDescription(self, value, conform=False): """ Check that 'value' contains the exact same set of keys as GroupAttribute's group description and that every child value match corresponding child attribute description. + + Args: + value: the value + conform: remove entries that don't exist in the description. """ if not super(GroupAttribute, self).matchDescription(value): return False attrMap = {attr.name: attr for attr in self._groupDesc} - # must have the exact same child attributes - if sorted(value.keys()) != sorted(attrMap.keys()): - return False + + if conform: + # remove invalid keys + invalidKeys = set(value.keys()).difference([attr.name for attr in self._groupDesc]) + for k in invalidKeys: + del self._groupDesc[k] + else: + # must have the exact same child attributes + if sorted(value.keys()) != sorted(attrMap.keys()): + return False + for k, v in value.items(): # each child value must match corresponding child attribute description - if not attrMap[k].matchDescription(v): + if not attrMap[k].matchDescription(v, conform): return False return True diff --git a/meshroom/core/node.py b/meshroom/core/node.py index 0eeb13d402..8fc50d6606 100644 --- a/meshroom/core/node.py +++ b/meshroom/core/node.py @@ -880,14 +880,9 @@ def __init__(self, nodeType, nodeDict, position=None, issue=CompatibilityIssue.U self.splitCount = self.parallelization.get("split", 1) self.setSize(self.parallelization.get("size", 1)) - # inputs matching current type description - self._commonInputs = [] # create input attributes for attrName, value in self._inputs.items(): - matchDesc = self._addAttribute(attrName, value, False) - # store attributes that could be used during node upgrade - if matchDesc: - self._commonInputs.append(attrName) + self._addAttribute(attrName, value, False) # create outputs attributes for attrName, value in self.outputs.items(): @@ -951,7 +946,7 @@ def attributeDescFromValue(attrName, value, isOutput): return desc.StringParam(**params) @staticmethod - def attributeDescFromName(refAttributes, name, value): + def attributeDescFromName(refAttributes, name, value, conform=False): """ Try to find a matching attribute description in refAttributes for given attribute 'name' and 'value'. @@ -968,8 +963,9 @@ def attributeDescFromName(refAttributes, name, value): # consider this value matches description: # - if it's a serialized link expression (no proper value to set/evaluate) # - or if it passes the 'matchDescription' test - if attrDesc and (Attribute.isLinkExpression(value) or attrDesc.matchDescription(value)): + if attrDesc and (Attribute.isLinkExpression(value) or attrDesc.matchDescription(value, conform)): return attrDesc + return None def _addAttribute(self, name, val, isOutput): @@ -1043,8 +1039,16 @@ def upgrade(self): if not self.canUpgrade: raise NodeUpgradeError(self.name, "no matching node type") # TODO: use upgrade method of node description if available + + # inputs matching current type description + commonInputs = [] + for attrName, value in self._inputs.items(): + if self.attributeDescFromName(self.nodeDesc.inputs, attrName, value, conform=True): + # store attributes that could be used during node upgrade + commonInputs.append(attrName) + return Node(self.nodeType, position=self.position, - **{key: value for key, value in self.inputs.items() if key in self._commonInputs}) + **{key: value for key, value in self.inputs.items() if key in commonInputs}) compatibilityIssue = Property(int, lambda self: self.issue.value, constant=True) canUpgrade = Property(bool, canUpgrade.fget, constant=True) diff --git a/tests/test_compatibility.py b/tests/test_compatibility.py index 0e3ede52b8..170192eeee 100644 --- a/tests/test_compatibility.py +++ b/tests/test_compatibility.py @@ -34,6 +34,18 @@ ) ] +#SampleGroupV3 is SampleGroupV2 with one more int parameter +SampleGroupV3 = [ + desc.IntParam(name="a", label="a", description="", value=0, uid=[0], range=None), + desc.IntParam(name="notInSampleGroupV2", label="notInSampleGroupV2", description="", value=0, uid=[0], range=None), + desc.ListAttribute( + name="b", + elementDesc=desc.GroupAttribute(name="p", label="", description="", groupDesc=SampleGroupV1), + label="b", + description="", + ) +] + class SampleNodeV1(desc.Node): """ Version 1 Sample Node """ @@ -103,6 +115,21 @@ class SampleNodeV5(desc.Node): desc.File(name='output', label='Output', description='', value=desc.Node.internalFolder, uid=[]) ] +class SampleNodeV6(desc.Node): + """ + Changes from V5: + * 'paramA' elementDesc has changed from SampleGroupV2 to SampleGroupV3 + """ + inputs = [ + desc.File(name='in', label='Input', description='', value='', uid=[0]), + desc.ListAttribute(name='paramA', label='ParamA', + elementDesc=desc.GroupAttribute( + groupDesc=SampleGroupV3, name='gA', label='gA', description=''), + description='') + ] + outputs = [ + desc.File(name='output', label='Output', description='', value=desc.Node.internalFolder, uid=[]) + ] def test_unknown_node_type(): """ @@ -289,3 +316,48 @@ def test_upgradeAllNodes(): assert n2Name in g.compatibilityNodes.keys() unregisterNodeType(SampleNodeV1) + +def test_conformUpgrade(): + registerNodeType(SampleNodeV5) + registerNodeType(SampleNodeV6) + + g = Graph('') + n1 = g.addNewNode("SampleNodeV5") + n1.paramA.value = [{'a': 0, 'b': [{'a': 0, 'b': [1.0, 2.0]}, {'a': 1, 'b': [1.0, 2.0]}]}] + n1Name = n1.name + graphFile = os.path.join(tempfile.mkdtemp(), "test_conform_upgrade.mg") + g.save(graphFile) + + # replace SampleNodeV5 by SampleNodeV6 + meshroom.core.nodesDesc[SampleNodeV5.__name__] = SampleNodeV6 + + # reload file + g = loadGraph(graphFile) + os.remove(graphFile) + + # node is a CompatibilityNode + assert len(g.compatibilityNodes) == 1 + assert g.node(n1Name).canUpgrade + + # upgrade all upgradable nodes + g.upgradeAllNodes() + + # only the node with an unknown type has not been upgraded + assert len(g.compatibilityNodes) == 0 + + upgradedNode = g.node(n1Name) + + # check upgrade + assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV6) + + # check conformation + assert len(upgradedNode.paramA.value) == 1 + + unregisterNodeType(SampleNodeV5) + unregisterNodeType(SampleNodeV6) + + + + + +