Skip to content

Commit

Permalink
[TOPI][OP] cuda for argwhere (#6868)
Browse files Browse the repository at this point in the history
* argwhere

* cuda schedule

* sort argwhere result

* Use single block and thrust to fix flaky behavior

* format

* used dynamic strided_slice

* Fix dynamic strided_slice

* try new strided_slice

* Improve dynamic strided slice to bind data depedent shape var.

* all tests pass

* remove print

* use new strided_slice

* clean

Co-authored-by: Yao Wang <kevinthesunwy@gmail.com>
  • Loading branch information
zhiics and kevinthesun authored Dec 4, 2020
1 parent a78c695 commit 54cd235
Show file tree
Hide file tree
Showing 10 changed files with 795 additions and 36 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/vta-hw
16 changes: 1 addition & 15 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,7 @@ def compute_strided_set(attrs, inputs, output_type):
_reg.register_pattern("auto_scheduler_layout_transform", OpPattern.INJECTIVE)

# argwhere
@_reg.register_compute("argwhere")
def compute_argwhere(attrs, inputs, output_type):
"""Compute definition of argwhere"""
output_shape = []
for s in output_type.shape:
if hasattr(s, "value"):
output_shape.append(s)
else:
# see Any, replace it with a var
output_shape.append(te.var("any_dim", "int32"))
new_output_type = tvm.relay.ty.TensorType(output_shape, "int32")
return [topi.argwhere(new_output_type, inputs[0])]


_reg.register_schedule("argwhere", strategy.schedule_argwhere)
_reg.register_strategy("argwhere", strategy.argwhere_strategy)

# scatter
@_reg.register_compute("scatter")
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,3 +921,15 @@ def correlation_strategy_cuda(attrs, inputs, out_type, target):
name="correlation.cuda",
)
return strategy


@argwhere_strategy.register(["cuda", "gpu"])
def argwhere_strategy_cuda(attrs, inputs, out_type, target):
"""argwhere cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_argwhere(topi.cuda.argwhere),
wrap_topi_schedule(topi.cuda.schedule_argwhere),
name="argwhere.cuda",
)
return strategy
39 changes: 30 additions & 9 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging

import re
from tvm import topi, _ffi
from tvm import topi, _ffi, te, ir
from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple, get_float_tuple
from tvm.target import generic_func, override_native_generic_func
from .. import op as _op
Expand Down Expand Up @@ -1034,14 +1034,6 @@ def proposal_strategy(attrs, inputs, out_type, target):
return strategy


# argwhere
@generic_func
def schedule_argwhere(attrs, outs, target):
"""schedule argwhere"""
with target:
return topi.generic.schedule_argwhere(outs)


# scatter
@override_native_generic_func("scatter_strategy")
def scatter_strategy(attrs, outs, out_type, target):
Expand Down Expand Up @@ -1223,3 +1215,32 @@ def correlation_strategy(attrs, inputs, out_type, target):
name="correlation.generic",
)
return strategy


# argwhere
def wrap_compute_argwhere(topi_compute):
"""wrap argwhere topi compute"""

def _compute_argwhere(attrs, inputs, out_type):
output_shape = []
for s in out_type.shape:
if hasattr(s, "value"):
output_shape.append(s)
else:
output_shape.append(te.var("any_dim", "int32"))
new_output_type = ir.TensorType(output_shape, "int32")
return [topi_compute(new_output_type, inputs[0])]

return _compute_argwhere


@override_native_generic_func("argwhere_strategy")
def argwhere_strategy(attrs, inputs, out_type, target):
"""argwhere generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_argwhere(topi.argwhere),
wrap_topi_schedule(topi.generic.schedule_argwhere),
name="argwhere.generic",
)
return strategy
2 changes: 2 additions & 0 deletions python/tvm/topi/argwhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Argwhere operator"""
import tvm
from tvm.te import hybrid


Expand Down Expand Up @@ -169,6 +170,7 @@ def hybrid_argwhere_5d(output_shape, condition):
return a


@tvm.target.generic_func
def argwhere(output_shape, condition):
"""Find the indices of elements of a tensor that are non-zero.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,4 @@
from .conv2d_hwnc_tensorcore import *
from .correlation import *
from .sparse import *
from .argwhere import *
Loading

0 comments on commit 54cd235

Please sign in to comment.