diff --git a/onnxscript/rewriter/fuse_relus_clips.py b/onnxscript/rewriter/fuse_relus_clips.py index ad2fdf28ef..1e0fe75178 100644 --- a/onnxscript/rewriter/fuse_relus_clips.py +++ b/onnxscript/rewriter/fuse_relus_clips.py @@ -7,6 +7,8 @@ - Clip(Clip(X)) -> Clip """ +from __future__ import annotations + import abc import numpy as np diff --git a/onnxscript/rewriter/redundant_scatter_nd.py b/onnxscript/rewriter/redundant_scatter_nd.py index 1ba6477f52..d7f559fdae 100644 --- a/onnxscript/rewriter/redundant_scatter_nd.py +++ b/onnxscript/rewriter/redundant_scatter_nd.py @@ -29,6 +29,9 @@ def fail(*args): class ScatterAllDynamic(orp.RewriteRuleClassBase): + def __init__(self): + super().__init__(remove_nodes=False) + def pattern(self, op, data, axis, transposed_data, updates): # Construct update-indices spanning an entire axis: shape = op.Shape(data, start=0)