@@ -186,6 +186,23 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
186186 return node_rewriter
187187
188188
189+ def register_scalarize (
190+ node_rewriter : Union [RewriteDatabase , NodeRewriter , str ], * tags : str , ** kwargs
191+ ):
192+ if isinstance (node_rewriter , str ):
193+
194+ def register (inner_rewriter : Union [RewriteDatabase , Rewriter ]):
195+ return register_specialize (inner_rewriter , node_rewriter , * tags , ** kwargs )
196+
197+ return register
198+ else :
199+ name = kwargs .pop ("name" , None ) or node_rewriter .__name__
200+ compile .optdb ["scalarize" ].register (
201+ name , node_rewriter , "fast_run" , "fast_compile" , * tags , ** kwargs
202+ )
203+ return node_rewriter
204+
205+
189206def register_uncanonicalize (
190207 node_rewriter : Union [RewriteDatabase , NodeRewriter , str ], * tags : str , ** kwargs
191208):
@@ -226,30 +243,36 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
226243
227244@register_canonicalize
228245@register_specialize
246+ @register_scalarize
229247@node_rewriter ([TensorFromScalar ])
230248def local_tensor_scalar_tensor (fgraph , node ):
231249 """tensor_from_scalar(scalar_from_tensor(x)) -> x"""
232- if isinstance (node .op , TensorFromScalar ):
233- s = node .inputs [0 ]
234- if s .owner and isinstance (s .owner .op , ScalarFromTensor ):
235- t = s .owner .inputs [0 ]
250+ s = node .inputs [0 ]
251+ if s .owner and isinstance (s .owner .op , ScalarFromTensor ):
252+ t = s .owner .inputs [0 ]
236253
237- # We don't need to copy over any stack traces here
238- return [t ]
254+ # We don't need to copy over any stack traces here
255+ return [t ]
239256
240257
241258@register_canonicalize
242259@register_specialize
260+ @register_scalarize
243261@node_rewriter ([ScalarFromTensor ])
244262def local_scalar_tensor_scalar (fgraph , node ):
245- """scalar_from_tensor(tensor_from_scalar(x)) -> x"""
246- if isinstance (node .op , ScalarFromTensor ):
247- t = node .inputs [0 ]
248- if t .owner and isinstance (t .owner .op , TensorFromScalar ):
249- s = t .owner .inputs [0 ]
250-
251- # We don't need to copy over any stack traces here
252- return [s ]
263+ """scalar_from_tensor(tensor_from_scalar(x)) -> x
264+
265+ and scalar_from_tensor(TensorConstant(x)) -> x
266+ """
267+ t = node .inputs [0 ]
268+ if t .owner and isinstance (t .owner .op , TensorFromScalar ):
269+ s = t .owner .inputs [0 ]
270+
271+ # We don't need to copy over any stack traces here
272+ return [s ]
273+ if isinstance (t , TensorConstant ):
274+ assert t .ndim == 0
275+ return [aes .constant (t .value .item (), t .name , t .dtype )]
253276
254277
255278@register_specialize ("local_alloc_elemwise" )
0 commit comments