[VTA][OpenCL] add device_annot support in graphpack #6125

Merged (6 commits) on Dec 11, 2020
4 changes: 4 additions & 0 deletions python/tvm/relay/op/_tensor.py
@@ -89,6 +89,10 @@
register_broadcast_schedule("fast_exp")
register_broadcast_schedule("fast_tanh")
register_broadcast_schedule("fast_erf")
# a fake on_device schedule;
# it will not be used in actual computation,
# as on_device is removed during the DeviceAnnotation pass
register_injective_schedule("on_device")

Contributor:
@zhanghaohit what happens if we don't register this annotation op schedule?

Contributor Author:
It will raise `AssertionError: on_device doesn't have FTVMStrategy registered` during build_module.cc::Optimize, before we reach the RunDeviceAnnotationPass.
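
For context, a minimal sketch (illustrative only, not part of this PR) of the kind of heterogeneously annotated Relay program in which on_device appears, and where the assertion above would fire during build without this registration:

import tvm
from tvm import relay

# Build a tiny function and mark one call to run on the external device context.
x = relay.var("x", shape=(1, 16), dtype="float32")
y = relay.add(x, relay.const(1.0))
y = relay.annotation.on_device(y, tvm.context("ext_dev"))
func = relay.Function([x], relay.sqrt(y))
mod = tvm.IRModule.from_expr(func)

# Before this change, optimization in build_module.cc could reach on_device ahead of
# RunDeviceAnnotationPass and hit the FTVMStrategy assertion; registering the injective
# schedule (plus the FTVMCompute below) makes on_device schedulable like any other op.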



# zeros
7 changes: 6 additions & 1 deletion src/relay/op/annotation/annotation.cc
@@ -54,7 +54,12 @@ RELAY_REGISTER_OP("on_device")
    .add_type_rel("Identity", IdentityRel)
    .set_attr<TOpPattern>("TOpPattern", kOpaque)
    .set_attr<TOpIsStateful>("TOpIsStateful", false)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
    .set_attr<FTVMCompute>("FTVMCompute",
                           [](const Attrs& attrs, const Array<te::Tensor>& inputs,
                              const Type& out_type) -> Array<te::Tensor> {
                             return {topi::identity(inputs[0])};
                           });

Expr StopFusion(Expr data) {
  static const Op& op = Op::Get("annotation.stop_fusion");
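
A quick way to sanity-check both registrations from the Python side (a sketch based on the tvm.ir.Op attribute API; not taken from this PR):

from tvm import relay

op = relay.op.get("on_device")
# Both attributes should be present after this change: the FTVMCompute added here
# (topi::identity) and the strategy derived from the injective schedule in _tensor.py.
assert op.get_attr("FTVMCompute") is not None
assert op.get_attr("FTVMStrategy") is not None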
132 changes: 130 additions & 2 deletions vta/python/vta/top/graphpack.py
@@ -185,6 +185,100 @@ def _operator_idx_inc(expr, count_meta, operator_current_idx):
return operator_current_idx


class ExprDeviceAnnot(ExprMutator):
    """Visitor to perform device annotation on an AST.

    Parameters
    ----------
    start: int
        the start location (inclusive) from which ops are marked to run on VTA
    end: int
        the end location (exclusive) at which the marking stops

    Returns
    -------
    None
    """

    def __init__(self, start=-1, end=-1):
        self.ext_ctx = tvm.context("ext_dev")
        self.cpu_ctx = tvm.context("cpu")
        self.cast = op.op.get("cast")
        self.counter = -1
        self.start = start
        self.end = end
        super().__init__()

    def visit_call(self, call):
        """ Visit the children. """
        # First visit the children.
        args = [self.visit(arg) for arg in call.args]

        self.counter += 1
        if self.counter == self.start:
            ret = relay.Call(call.op, args, call.attrs)
            ret = relay.annotation.on_device(ret, self.ext_ctx)
            return ret

        if self.counter == self.end:
            ret = relay.Call(call.op, args, call.attrs)
            ret = relay.annotation.on_device(ret, self.cpu_ctx)
            return ret

        if self.counter > self.start and self.counter < self.end:
            ret = relay.Call(call.op, args, call.attrs)

            # skip the float op, i.e., float->int cast
            if self.is_float_op(call):
                return ret

            return relay.annotation.on_device(ret, self.ext_ctx)

        return relay.Call(self.visit(call.op), args, call.attrs)

    def is_float_op(self, call):
        """Check whether this call is a float op.

        In general, a float op's output dtype is float; a special case is the
        float->int cast, which follows this op sequence:
        multiply(float) -> round(float) -> clip(float) -> cast(int).
        """
        args = call.args
        odtype = _get_tensor_type(call)

        if odtype == "float32":
            return True

        if call.op == self.cast:
            idtype = _get_tensor_type(args[0])
            if idtype == "float32":
                return True

        return False


class ExprLocator(ExprMutator):
    """Visitor to locate op on an AST."""

    def __init__(self):
        self.counter = -1
        self.op2nodes = {}
        super().__init__()

    def visit_call(self, call):
        """ Visit the children. """
        # First visit the children.
        args = [self.visit(arg) for arg in call.args]

        odtype = _get_tensor_type(call)
        self.counter += 1
        if (call.op, odtype) in self.op2nodes:
            self.op2nodes[(call.op, odtype)].append(self.counter)
        else:
            self.op2nodes[(call.op, odtype)] = [self.counter]

        return relay.Call(self.visit(call.op), args, call.attrs)


class ExprPack(ExprMutator):
"""Visitor to perform graph packing on an AST."""

@@ -427,6 +521,9 @@ def graph_pack(
    start_name_idx=None,
    stop_name_idx=None,
    count_meta=False,
    device_annot=False,
    annot_start_name="nn.conv2d",
    annot_end_name="annotation.stop_fusion",
):
"""Pack the graph into batch&channel packed format.

@@ -464,16 +561,47 @@
        'expr.astext(show_meta_data=False)'. When count_meta is True, the operator-increment
        logic counts the meta node as well.

    device_annot: boolean, optional
        whether to annotate each node with the device type it should run on

    annot_start_name: str, optional
        device annotation start node, from which nodes are marked as `ext_dev`

    annot_end_name: str, optional
        device annotation end node, after which nodes are marked as `cpu`

    Returns
    -------
    expr : Expr
        The transformed expression.
    """
    assert isinstance(expr, relay.Function)
    assert (start_name != stop_name) or (start_name_idx < stop_name_idx)
    assert (
        (start_name != stop_name)
        or ((start_name_idx is None) != (stop_name_idx is None))
        or (not (start_name_idx is None and stop_name_idx is None))
        or (start_name_idx < stop_name_idx)
    )
    expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
    expr = run_opt_pass(expr, transform.InferType())
    packer = ExprPack(bfactor, cfactor, weight_bits)
    expr = packer.visit(expr)
    assert not packer.start_pack
    return run_opt_pass(expr, transform.InferType())
    expr = run_opt_pass(expr, transform.InferType())

    if device_annot:
        expr_locator = ExprLocator()
        expr_locator.visit(expr)

        annot_start = op.op.get(annot_start_name)
        start = expr_locator.op2nodes[(annot_start, "int32")][0]

        annot_end = op.op.get(annot_end_name)
        # mark the op right after the last stop_fusion to run on the cpu device
        end = expr_locator.op2nodes[(annot_end, "int8")][-1] + 1

        device_annot = ExprDeviceAnnot(start=start, end=end)
        expr = device_annot.visit(expr)
        return run_opt_pass(expr, transform.InferType())

    return expr
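
Finally, a usage sketch of the new device_annot path (illustrative only; `mod` is assumed to be a quantized Relay module produced by the usual VTA flow, and the start/stop op names depend on the network):

import vta
from vta.top import graph_pack

env = vta.get_env()

relay_prog = graph_pack(
    mod["main"],
    env.BATCH,
    env.BLOCK_OUT,
    env.WGT_WIDTH,
    start_name="nn.max_pool2d",
    stop_name="nn.global_avg_pool2d",
    device_annot=True,                        # annotate nodes for heterogeneous execution
    annot_start_name="nn.conv2d",             # first op marked to run on ext_dev
    annot_end_name="annotation.stop_fusion",  # op after the last stop_fusion runs on cpu
)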