Skip to content

Commit

Permalink
use elemwise_shape_func for scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart committed Jun 8, 2020
1 parent ed8708a commit 0191a8d
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import op as _reg
from . import strategy
from .op import OpPattern
from ._tensor import elemwise_shape_func

_reg.register_broadcast_schedule("broadcast_to")
_reg.register_broadcast_schedule("broadcast_to_like")
Expand Down Expand Up @@ -392,12 +393,7 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
return [_argwhere_shape_func_5d(inputs[0])]
return ValueError("Does not support rank higher than 5 in argwhere")

@_reg.register_shape_func("scatter", True)
def scatter_shape_func(attrs, inputs, out_ndims):
"""
Shape function for scatter.
"""
return inputs[0].shape
_reg.register_shape_func("scatter", False, elemwise_shape_func)

@script
def _layout_transform_shape_func(data_shape,
Expand Down

0 comments on commit 0191a8d

Please sign in to comment.