Skip to content

Commit

Permalink
[VTA] Make vta graph_pack compatible with latest TVM, and bring back
Browse files Browse the repository at this point in the history
object detection tutorials.
  • Loading branch information
huajsj committed Aug 12, 2021
1 parent 1abd248 commit fe60bf8
Show file tree
Hide file tree
Showing 2 changed files with 351 additions and 6 deletions.
33 changes: 27 additions & 6 deletions vta/python/vta/top/graphpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,24 @@ def _pack_batch_channel(data, dshape, bfactor, cfactor):
return data


def _unpack_batch_channel(data, old_shape):
def _unpack_batch_channel(data, old_shape, unpack_transpose=False):
"""Unpack the data channel dimension."""
data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
if unpack_transpose:
data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
data = op.reshape(data, newshape=old_shape)
return data


def _channel_const_match(channel_length, cfactor_out):
"""Round the chanel const variant if the value not divisible by cfactor_out"""
diff = int(channel_length) % cfactor_out
if diff != 0:
diff = cfactor_out - diff
channel_length = channel_length + diff

return diff, channel_length


def _const_shape_match(data, dshape, cfactor_out):
"""Pad the constant if the shape[0] not divisible by cfactor_out."""
assert len(dshape) == 3
Expand Down Expand Up @@ -299,6 +310,7 @@ def __init__(self, bfactor, cfactor, weight_bits):
self.upsampling = op.op.get("nn.upsampling")
self.reshape = op.op.get("reshape")
self.number_of_conv2d = 0
self.unpack_transpose = True
super().__init__()

def visit_call(self, call):
Expand All @@ -319,7 +331,7 @@ def visit_call(self, call):
self.start_pack = False
data = args[0]
data_shape = _get_tensor_shape(call.args[0])
return _unpack_batch_channel(data, data_shape)
return _unpack_batch_channel(data, data_shape, self.unpack_transpose)
if self.start_pack:
# Operator cases
if call.op == self.conv2d and odtype == "int32":
Expand Down Expand Up @@ -429,12 +441,12 @@ def visit_call(self, call):
if len(pad_width) == 6:
pass
elif len(pad_width) == 4:
(data,) = args
(data, pad_value) = args
new_pad_width = []
new_pad_width.extend(pad_width)
for _ in range(2):
new_pad_width.append([0, 0])
return op.nn.pad(data, pad_value=call.attrs.pad_value, pad_width=new_pad_width)
return op.nn.pad(data, pad_value=pad_value, pad_width=new_pad_width)
elif call.op == self.upsampling:
(data,) = args
scale_h = call.attrs.scale_h
Expand All @@ -445,8 +457,17 @@ def visit_call(self, call):
return op.nn.upsampling(data, scale_h, scale_w, data_layout, method, align_corners)
elif call.op == self.reshape and len(input_types[0].shape) == 4:
(data,) = args
self.unpack_transpose = False
data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3))
return op.reshape(data, [int(x) for x in input_types[0].shape])
new_shape = [int(x) for x in input_types[0].shape]
# Check if the reshape match with such shape after pad
pad, new_shape[1] = _channel_const_match(new_shape[1], self.cfactor)
data = op.reshape(data, new_shape)
# remove pad data
if pad != 0:
new_pad_width = [[0, 0], [0, -pad], [0, 0], [0, 0]]
data = op.nn.pad(data, pad_width=new_pad_width)
return data

return relay.Call(self.visit(call.op), args, call.attrs)

Expand Down
Loading

0 comments on commit fe60bf8

Please sign in to comment.