Skip to content

Commit

Permalink
Don't fuse take with dynamic inputs (apache#6979)
Browse files Browse the repository at this point in the history
* add a regression test for fusing dynamic take

* add legalize for take that stops fusion on dynamic inputs

* fix lint

* fix typo
  • Loading branch information
Matthew Brookhart authored and Trevor Morris committed Dec 4, 2020
1 parent 5f37380 commit 48f4f64
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 0 deletions.
19 changes: 19 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,25 @@ def take_shape_func(attrs, inputs, out_ndims):
return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]


@_reg.register_legalize("take")
def legalize_dyn_topk(attrs, inputs, types):
"""Legalize take op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current op
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return topi.take_legalize(attrs, inputs, types)


@script
def _argwhere_shape_func_1d(condition):
out = output_tensor((2,), "int64")
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,28 @@ def take(a, indices, axis=None, mode="clip"):
return cpp.take(a, indices, int(axis), mode)


@tvm.target.generic_func
def take_legalize(attrs, inputs, types):
"""Legalizes dyn.topk op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current op
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
if tvm.relay.ty.is_dynamic(types[0]):
return tvm.relay.take(tvm.relay.annotation.stop_fusion(inputs[0]), inputs[1], **attrs)
return None


def gather(data, axis, indices):
"""Gather values along given axis from given indices.
Expand Down
27 changes: 27 additions & 0 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np

import tvm
from tvm import relay
from tvm.relay import transform
Expand Down Expand Up @@ -757,6 +759,31 @@ def create_diamond_func(inp):
assert tvm.ir.structural_equal(fused, expected)


def test_fuse_dynamic_squeeze_slice_take():
input_data = [
np.random.random([1, 2, 4]).astype("float32"),
np.array([0]).astype("int64"),
]

x = relay.var("p0107", shape=(relay.Any(), relay.Any(), 4), dtype="float32")
take_val = relay.var("p166", shape=(relay.Any(),), dtype="int64")

squeeze = relay.op.squeeze(x, axis=[0])
strided_slice = relay.op.strided_slice(
squeeze, begin=[0, 0], end=[15130, 9223372036854775807], strides=[1, 1]
)
take = relay.op.take(strided_slice, take_val, axis=0)

mod = tvm.IRModule.from_expr(take)
ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")

result = ex.evaluate()(*input_data)

np_result = np.squeeze(input_data[0][:, input_data[1][0], :], axis=0)

assert np.allclose(result.asnumpy(), np_result)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
Expand Down

0 comments on commit 48f4f64

Please sign in to comment.