Skip to content

Commit

Permalink
[TOPI][Relay][OP] support dynamic NMS(Non Maximum Suppression), symbo…
Browse files Browse the repository at this point in the history
…lic begin, end, and strides for strided_slice (apache#4312)

* [TOPI][Relay][OP] Dynamic NMS and strided_slice

* Incorporate comments

* fix nnvm compatibility issues

* fix InferCorrectLayout

* Minor fix

* fix for fuse

* Workaround to pass batch_size into hybrid function to handle dynamic shape

* Seperate rearrange

* fix lint

* fix ci, comments

* change attr to Optional<T>

* clang format

* remove empty lines

* partial ignore for end of strided_slice

* pylint

* add out_indices for gpu get_valid_counts

* change to slice_mode

* clang-format, fix comments

* fix comment

* change slice_mode to string

* fix CI

* update docstring

Co-authored-by: Yao Wang <kevinthesunwy@gmail.com>
  • Loading branch information
2 people authored and Trevor Morris committed Jun 18, 2020
1 parent 1c9196f commit 243e5f0
Show file tree
Hide file tree
Showing 43 changed files with 1,123 additions and 352 deletions.
18 changes: 14 additions & 4 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,24 @@ struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {

/*! \brief Attributes for StridedSlice operator */
struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
Array<Integer> begin;
Array<Integer> end;
Array<Integer> strides;
Optional<Array<Integer>> begin;
Optional<Array<Integer>> end;
Optional<Array<Integer>> strides;
std::string slice_mode;

TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive");
TVM_ATTR_FIELD(end).describe("Indices for end of slice, end index is exclusive");
TVM_ATTR_FIELD(strides).set_default(Array<Integer>({})).describe("Stride values of the slice");
TVM_ATTR_FIELD(strides).describe(
"Stride values of the slice, a stride can be negative, which causes a reverse slice.");
TVM_ATTR_FIELD(slice_mode)
.set_default("end")
.describe(
"The slice mode [end, size]."
"end - The default slice mode, ending indices for the slice."
"size - The input strides will be ignored, input end in this mode indicates the size"
"of a slice starting at the location specified by begin. If end[i] is -1,"
"all remaining elements in that dimension are included in the slice");
}
};

Expand Down
4 changes: 3 additions & 1 deletion include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA
.describe(
"Max number of output valid boxes for each instance."
"By default all valid boxes are returned.");
TVM_ATTR_FIELD(iou_threshold).set_default(0.5).describe("Non-maximum suppression threshold.");
TVM_ATTR_FIELD(iou_threshold)
.set_default(0.5)
.describe("Non-maximum suppression iou threshold.");
TVM_ATTR_FIELD(force_suppress)
.set_default(false)
.describe("Suppress all detections regardless of class_id.");
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def convert(self, v):
def __call__(self, args, attrs, type_args):
if attrs is None:
attrs = {}
if self.operator is op.reshape:
if self.operator in (op.reshape, op.strided_slice):
x = self.operator(*args)
elif self.operator in (op.zeros, op.ones, op.full, op.broadcast_to):
x = self.operator(*args, dtype=attrs["dtype"])
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,8 +611,8 @@ def _convert_cropping(inexpr, keras_layer, _):
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend Keras.'.format(crop_type))
int32_max = np.iinfo(np.int32).max
return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l], \
end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r])
return _op.strided_slice(inexpr, begin=_expr.const([0, 0, crop_t, crop_l]), \
end=_expr.const([int32_max, int32_max, in_h-crop_b, in_w-crop_r]))


def _convert_batchnorm(inexpr, keras_layer, etab):
Expand Down
25 changes: 17 additions & 8 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,16 +411,22 @@ def _mx_slice(inputs, attrs):
begin = list(attrs.get_int_tuple('begin', None))
end = list(attrs.get_int_tuple('end', None))
stride = attrs.get_int_tuple('step', None)
input_shape = _infer_type(inputs[0]).checked_type.shape
if begin is None:
raise tvm.error.OpAttributeRequired(
'Attribute "begin" not found in operator Slice.')
if end is None:
raise tvm.error.OpAttributeRequired(
'Attribute "end" not found in operator Slice.')
begin = tuple(x if x is not None else 0 for x in begin)
new_attrs = {'begin': begin, 'end': end}
begin = (x if x is not None else 0 for x in begin)
for i, ed in enumerate(end):
if ed is None:
end[i] = input_shape[i]
new_attrs = {'begin': _expr.const(list(begin), dtype="int32"),
'end': _expr.const(list(end), dtype="int32")}
if stride is not None:
new_attrs['strides'] = stride
stride = (x if x is not None else 1 for x in stride)
new_attrs['strides'] = _expr.const(list(stride), dtype="int32")
return _op.strided_slice(inputs[0], **new_attrs)


Expand Down Expand Up @@ -460,7 +466,9 @@ def _mx_slice_axis(inputs, attrs):
else:
begin.append(ax_beg)
end.append(ax_end)
return _op.strided_slice(inputs[0], begin, end)
return _op.strided_slice(inputs[0],
_expr.const(begin, dtype="int32"),
_expr.const(end, dtype="int32"))


def _mx_crop_like(inputs, attrs):
Expand All @@ -480,9 +488,9 @@ def _mx_crop_like(inputs, attrs):
return _op.slice_like(*inputs, **new_attrs)
expr = _infer_type(inputs[1])
like_shape = expr.checked_type.shape
new_attrs['begin'] = [0, 0, offset[0], offset[1]]
new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2],
offset[1]+like_shape[3]]
new_attrs['begin'] = _expr.const([0, 0, offset[0], offset[1]], dtype="int32")
new_attrs['end'] = _expr.const([like_shape[0], like_shape[1], offset[0]+like_shape[2],
offset[1]+like_shape[3]], dtype="int32")
return _op.strided_slice(inputs[0], **new_attrs)


Expand Down Expand Up @@ -656,7 +664,7 @@ def _mx_multibox_detection(inputs, attrs):

ret = _op.vision.multibox_transform_loc(inputs[0], inputs[1],
inputs[2], **new_attrs0)
return _op.vision.non_max_suppression(ret[0], ret[1], **new_attrs1)
return _op.vision.non_max_suppression(ret[0], ret[1], ret[1], **new_attrs1)


def _mx_batch_dot(inputs, attrs):
Expand Down Expand Up @@ -820,6 +828,7 @@ def _mx_box_nms(inputs, attrs):
id_index=id_index, score_index=score_index)
nms_out = _op.vision.non_max_suppression(ret[1],
ret[0],
ret[2],
iou_threshold=iou_thresh,
force_suppress=force_suppress,
top_k=top_k,
Expand Down
13 changes: 8 additions & 5 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,11 +1014,12 @@ def _impl_v1(cls, inputs, attr, params):
attr['ends'] = new_ends
except KeyError:
pass
begin = list(attr['starts'])
end = list(attr['ends'])

return AttrCvt('strided_slice',
transforms={'starts': 'begin',
'ends': 'end'},
ignores=['axes'])(inputs, attr)
return _op.strided_slice(inputs[0],
begin=_expr.const(begin, dtype="int32"),
end=_expr.const(end, dtype="int32"))

@classmethod
def _impl_v10(cls, inputs, attr, params):
Expand All @@ -1034,7 +1035,9 @@ def _impl_v10(cls, inputs, attr, params):
starts, ends, axes)
starts = new_starts
ends = new_ends
return _op.strided_slice(inputs[0], begin=starts, end=ends)
return _op.strided_slice(inputs[0],
begin=_expr.const(starts, dtype="int32"),
end=_expr.const(ends, dtype="int32"))


class Gather(OnnxOpConverter):
Expand Down
16 changes: 13 additions & 3 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,11 @@ def _impl(inputs, input_types):
end[dim] = inputs[3]

strides.append(int(inputs[4]))
return _op.transform.strided_slice(data, begin, end, strides)
return _op.transform.strided_slice(data,
begin=_expr.const(begin),
end=_expr.const(end),
strides=_expr.const(strides),
slice_mode="size")
return _impl

def _split():
Expand Down Expand Up @@ -1263,7 +1267,10 @@ def _impl(inputs, input_types):
end[axis] = i + unif_size
stride = [1] * len(shape)

chunk_out = _op.transform.strided_slice(data, begin, end, stride)
chunk_out = _op.transform.strided_slice(data,
begin=_expr.const(begin),
end=_expr.const(end),
strides=_expr.const(stride))
chunks.append(chunk_out)

if dim % num_chunks:
Expand All @@ -1273,7 +1280,10 @@ def _impl(inputs, input_types):
end[axis] = dim
stride = [1] * len(shape)

chunk_out = _op.transform.strided_slice(data, begin, end, stride)
chunk_out = _op.transform.strided_slice(data,
begin=_expr.const(begin),
end=_expr.const(end),
strides=_expr.const(stride))
chunks.append(chunk_out)

return chunks
Expand Down
84 changes: 70 additions & 14 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,62 @@ def _impl(inputs, attr, params, mod):
return out
return _impl

def _nms():
def _impl(inputs, attr, params, mod):
# Get parameter values
# TODO(yongwww) change nms in relay to support symbolic max_output_size
try:
max_output_size = int(np.atleast_1d(inputs[2].data.asnumpy()
.astype("int64"))[0])
except Exception:
try:
max_output_size = _infer_value(inputs[2], params,
mod).asnumpy().astype("int64").tolist()[0]
except Exception:
max_output_size = -1
iou_threshold = np.atleast_1d(inputs[3].data.asnumpy())[0]
# score_threshold was introduced from V3
score_threshold = np.atleast_1d(inputs[4].data.asnumpy())[0] if len(inputs) > 4 else 0.0

# Generate data with shape (1, num_anchors, 5)
scores = AttrCvt(op_name="expand_dims",
ignores=['T_threshold'],
extras={'axis': -1, 'num_newaxis': 1})([inputs[1]], attr)
data = get_relay_op('concatenate')([scores, inputs[0]], -1)
data = get_relay_op('expand_dims')(data, 0, 1)

# reason why using get_valid_counts is for inference performance
ct, data, indices = get_relay_op('get_valid_counts')(data,
score_threshold=score_threshold,
id_index=-1,
score_index=0)
# TensorFlow NMS doesn't have parameter top_k
top_k = -1
# TF doesn't have class id for nms input
score_index = 0
nms_ret = get_relay_op('non_max_suppression')(data=data,
valid_count=ct,
indices=indices,
max_output_size=max_output_size,
iou_threshold=iou_threshold,
force_suppress=True,
top_k=top_k,
coord_start=1,
score_index=score_index,
id_index=-1,
return_indices=True,
invalid_to_bottom=False)

# squeeze it, TF NMS is not batched
size = get_relay_op("squeeze")(nms_ret[1], axis=[1])
data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])

# slice to get the dynamic result
ret = get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]),
end=size, slice_mode="size")
return ret
return _impl

def _decode_image():
def _impl(inputs, attr, params, mod):
# Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
Expand Down Expand Up @@ -1119,25 +1175,20 @@ def _impl(inputs, attr, params, mod):
try:
begin = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
begin = _infer_value(inputs[1], params).asnumpy().tolist()[0]
# Handle symbolic begin
try:
begin = _infer_value(inputs[1], params).asnumpy().tolist()
except Exception:
begin = inputs[1]
try:
size = _get_list_param(params, inputs[2])
except (IndexError, KeyError, AttributeError):
# Handle symbolic size
try:
size = _infer_value(inputs[2], params).asnumpy().tolist()[0]
size = _infer_value(inputs[2], params).asnumpy().tolist()
except Exception:
size = inputs[2]
data_shape = _infer_shape(inputs[0], mod)
data_dim = len(data_shape)
end = size
if not isinstance(end, (_expr.Call, _expr.Var)):
for i in range(data_dim):
if size[i] == -1:
end[i] = data_shape[i]
else:
end[i] += begin[i]
return _op.strided_slice(inputs[0], begin=begin, end=end)
return _op.strided_slice(inputs[0], begin=begin, end=size, slice_mode="size")
return _impl


Expand Down Expand Up @@ -1466,8 +1517,11 @@ def _transform_mask(stride_dim, ellipsis_mask):
fshape_indices = None
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
out_shape = _infer_shape(out, mod)
out = _op.strided_slice(inputs[0],
begin=begin,
end=end,
strides=stride)
out_shape = _infer_shape(out, mod=mod)
if not fshape_indices:
fshape_indices = range(len(out_shape))

Expand Down Expand Up @@ -2026,6 +2080,8 @@ def _impl(inputs, attr, params, mod):
'Mod' : _elemwise('mod'),
'Mul' : _elemwise('multiply'),
'Neg' : AttrCvt('negative'),
'NonMaxSuppressionV2' : _nms(),
'NonMaxSuppressionV3' : _nms(),
'NoOp' : _no_op(),
'NotEqual' : _broadcast('not_equal'),
'OneHot' : _one_hot(),
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2544,7 +2544,7 @@ def convert_detection_postprocess(self, op):

ret = _op.vision.multibox_transform_loc(cls_pred, loc_prob,
anchor_expr, **multibox_transform_loc_attrs)
ret = _op.vision.non_max_suppression(ret[0], ret[1], **non_max_suppression_attrs)
ret = _op.vision.non_max_suppression(ret[0], ret[1], ret[1], **non_max_suppression_attrs)
ret = _op.vision.get_valid_counts(ret, 0)
valid_count = ret[0]
# keep only the top 'max_detections' rows
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,10 @@ def conv2d_grad(orig, grad):
assert padded_weight_grad_h >= filter_h
assert padded_weight_grad_w >= filter_w
if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
backward_weight = strided_slice(backward_weight, begin=[0, 0, 0, 0],
end=[None, None, filter_h, filter_w])
backward_weight = strided_slice(backward_weight,
begin=const([0, 0, 0, 0], dtype="int64"),
end=const([out_channel, in_channel // attrs.groups,
filter_h, filter_w], dtype="int64"))

return [backward_data, backward_weight]

Expand Down
Loading

0 comments on commit 243e5f0

Please sign in to comment.