diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index 16fa171353..e19bc8c68b 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -136,6 +136,20 @@ def __call__(self, model: ir.Model) -> PassResult: f"The result of the pass '{self.__class__.__name__}' should be type PassResult. " "Please create one with ir.passes.PassResult()." ) + + # Checks that the declared in-place property is respected + if self.in_place and result.model is not model: + raise PassError( + f"The pass '{self.__class__.__name__}' is declared in-place, " + "but the model returned is *not* the same object as the input model. " + "Pass developer: Pass should return the same model object or the in_place property should return False." + ) + if not self.in_place and result.model is model: + raise PassError( + f"The pass '{self.__class__.__name__}' is declared not in-place, " + "but the model returned *is* the same object as the input model. " + "Pass developer: Pass should return a new model object or the in_place property should return True." + ) return result @abc.abstractmethod