Skip to content

Commit

Permalink
add device_annot support in graphpack
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghaohit committed Jul 23, 2020
1 parent 5046ff2 commit b70a5fe
Showing 1 changed file with 133 additions and 3 deletions.
136 changes: 133 additions & 3 deletions vta/python/vta/top/graphpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,104 @@ def _operator_idx_inc(expr, count_meta, operator_current_idx):
operator_current_idx = operator_current_idx + 1
return operator_current_idx


class ExprDeviceAnnot(ExprMutator):
    """Visitor to perform device annotation on an AST.

    Call nodes are numbered in visitation order; nodes with index in
    [start, end) are annotated to run on the VTA device (``ext_dev``),
    except float ops which are left unannotated, and the node at index
    ``end`` is annotated ``cpu`` so execution returns to the host.

    Parameters
    ----------
    start: int
        the start location to mark run on vta (inclusive)
    end: int
        the end location to mark run on vta (exclusive)

    Returns
    ---------
    None
    """
    def __init__(self, start=-1, end=-1):
        self.ext_ctx = tvm.context("ext_dev")
        self.cpu_ctx = tvm.context("cpu")
        self.cast = op.op.get("cast")
        self.counter = -1  # index of the call node currently being visited
        self.start = start
        self.end = end
        super().__init__()

    def visit_call(self, call):
        """Visit the children first, then annotate this call with a device."""
        args = [self.visit(arg) for arg in call.args]

        self.counter += 1
        if self.counter == self.start:
            # First node of the VTA region.
            ret = relay.Call(call.op, args, call.attrs)
            return relay.annotation.on_device(ret, self.ext_ctx)
        if self.counter == self.end:
            # First node after the VTA region: hand execution back to the CPU.
            ret = relay.Call(call.op, args, call.attrs)
            return relay.annotation.on_device(ret, self.cpu_ctx)
        if self.start < self.counter < self.end:
            ret = relay.Call(call.op, args, call.attrs)

            # skip the float op, i.e., float->int cast
            if self.is_float_op(call):
                return ret

            return relay.annotation.on_device(ret, self.ext_ctx)

        return relay.Call(self.visit(call.op), args, call.attrs)

    def is_float_op(self, call):
        """Check if this op belongs to a float op.

        In general, a float op's odtype is float; a special case is the
        float->int cast, which follows this op sequence:
        multiply(float) -> round(float) -> clip(float) -> cast(int).
        """
        odtype = _get_tensor_type(call)
        if odtype == "float32":
            return True
        if call.op == self.cast:
            # A cast whose input is float is the tail of a float->int chain.
            idtype = _get_tensor_type(call.args[0])
            return idtype == "float32"
        return False


class ExprLocater(ExprMutator):
    """Visitor to locate ops on an AST.

    After visiting, ``op2nodes`` maps each ``(op, output_dtype)`` pair to
    the list of visitation-order indices of the call nodes producing it.
    """
    def __init__(self):
        self.counter = -1   # running call-node index, in visitation order
        self.op2nodes = {}  # (op, odtype) -> list of call indices
        super().__init__()

    def visit_call(self, call):
        """Visit the children first, then record this call's index."""
        new_args = [self.visit(arg) for arg in call.args]

        self.counter += 1
        key = (call.op, _get_tensor_type(call))
        self.op2nodes.setdefault(key, []).append(self.counter)

        return relay.Call(self.visit(call.op), new_args, call.attrs)


class ExprPack(ExprMutator):
"""Visitor to perform graph packing on an AST.
"""
Expand Down Expand Up @@ -412,7 +510,10 @@ def graph_pack(expr,
stop_name="nn.global_avg_pool2d",
start_name_idx=None,
stop_name_idx=None,
count_meta=False):
count_meta=False,
device_annot=False,
annot_start_name="nn.conv2d",
annot_end_name="annotation.stop_fusion"):
"""Pack the graph into batch&channel packed format.
Parameters
Expand Down Expand Up @@ -449,18 +550,47 @@ def graph_pack(expr,
'expr.astext(show_meta_data=False)'. When count_meta is True, the operator increase
logic would count the meta.
device_annot: boolean, optional
if we want to annotate the device_type
annot_start_name: str, optional
device annotation start node, from which we mark the nodes as `ext_dev`
annot_end_name: str, optional
device annotation end node, after which we mark the nodes as 'cpu'
Returns
-------
expr : Expr
The transformed expression.
"""
assert isinstance(expr, relay.Function)
assert ((start_name != stop_name) or (start_name_idx < stop_name_idx))
assert ((start_name != stop_name) or (start_name_idx is None != stop_name_idx is None) or \
(not (start_name_idx is None and stop_name_idx is None)) or (start_name_idx < stop_name_idx))
expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
expr = run_opt_pass(expr, transform.InferType())
packer = ExprPack(
bfactor, cfactor,
weight_bits)
expr = packer.visit(expr)
assert not packer.start_pack
return run_opt_pass(expr, transform.InferType())
expr = run_opt_pass(expr, transform.InferType())

if device_annot:
expr_locator = ExprLocater()
expr_locator.visit(expr)

annot_start = op.op.get(annot_start_name)
start = expr_locator.op2nodes[(annot_start, "int32")][0]

annot_end = op.op.get(annot_end_name)
# we mark the next op to the last stop_fusion on cpu device
end = expr_locator.op2nodes[(annot_end, "int8")][-1] + 1

device_annot = ExprDeviceAnnot(start=start, end=end)
expr = device_annot.visit(expr)
ret = run_opt_pass(expr, transform.InferType())

return ret
else:
return expr

0 comments on commit b70a5fe

Please sign in to comment.