Skip to content

Commit

Permalink
[ONNX] [Relay] Dynamic squeeze (apache#9095)
Browse files Browse the repository at this point in the history
* adding dynamic squeeze first steps

* Matt B. implementing shape

* squeeze implemented, dynamic_to_static and onnx importer next

* add Squeeze op convert to onnx.py

* dynamic to static

* removed comments

* removed comments

* added comment

* adjusted comment

* black and lint

* ran make format in root directory

Co-authored-by: CircleSpin <jocelyn@pop-os.localdomain>
  • Loading branch information
CircleSpin and CircleSpin authored Sep 28, 2021
1 parent 9e47b43 commit 5e46e75
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 5 deletions.
22 changes: 20 additions & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,23 @@ def _impl_v12(cls, inputs, attr, params):
return result


class Squeeze(OnnxOpConverter):
"""Operator converter for Squeeze."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
axis = attr.get("axes", None)
return _op.squeeze(*inputs, axis)

@classmethod
def _impl_v13(cls, inputs, attr, params):
axis = inputs[1]
dtype = infer_type(axis).checked_type.dtype
rank = _op.shape_of(_op.shape_of(inputs[0], dtype), dtype)
axis = _op.where(axis < _op.const(0, dtype), axis + rank, axis)
return _op.squeeze(inputs[0], fold_constant(axis))


class Split(OnnxOpConverter):
"""Operator converter for Split."""

Expand Down Expand Up @@ -2818,7 +2835,8 @@ def _impl_v12(cls, inputs, attr, params):
alpha = _op.const(attr.get("alpha", 1.0), dtype)
zero = _op.const(0, dtype)
one = _op.const(1, dtype)
return _op.maximum(zero, x) + _op.minimum(zero, alpha * (_op.exp(x / alpha) - one))
out = _op.maximum(zero, x) + _op.minimum(zero, alpha * (_op.exp(x / alpha) - one))
return out


class MaxRoiPool(OnnxOpConverter):
Expand Down Expand Up @@ -4149,7 +4167,7 @@ def _get_convert_map(opset):
"ScatterElements": Scatter.get_converter(opset),
"ScatterND": ScatterND.get_converter(opset),
"EyeLike": EyeLike.get_converter(opset),
"Squeeze": AttrCvt("squeeze", {"axes": "axis"}),
"Squeeze": Squeeze.get_converter(opset),
"Unsqueeze": Unsqueeze.get_converter(opset),
"Pad": Pad.get_converter(opset),
"Shape": Shape.get_converter(opset),
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/relay/op/dyn/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_reg.register_broadcast_schedule("dyn.broadcast_to")
_reg.register_injective_schedule("dyn.reshape")
_reg.register_injective_schedule("dyn.expand_dims")
_reg.register_injective_schedule("dyn.squeeze")
_reg.register_broadcast_schedule("dyn.tile")
_reg.register_injective_schedule("dyn.one_hot")
_reg.register_injective_schedule("dyn.full")
Expand Down Expand Up @@ -258,3 +259,24 @@ def _sparse_to_dense_shape_func(output_shape, ndim):
@_reg.register_shape_func("dyn.sparse_to_dense", True)
def sparse_to_dense_shape_func(attrs, inputs, out_ndims):
return [_sparse_to_dense_shape_func(inputs[3], out_ndims[0])]


@script
def _squeeze_shape_func_input_data(data, axis, ndims):
out = output_tensor((ndims,), "int64")
out_i = 0
for i in const_range(data.shape[0]):
not_in_axis = True
for j in const_range(axis.shape[0]):
if i == axis[j]:
not_in_axis = False
if not_in_axis:
out[out_i] = int64(data[i])
out_i += 1

return out


@_reg.register_shape_func("dyn.squeeze", [False, True])
def dynamic_squeeze_shape_func(attrs, inputs, out_ndims):
return [_squeeze_shape_func_input_data(inputs[0], inputs[1], out_ndims[0])]
6 changes: 5 additions & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def squeeze(data, axis=None):
data : tvm.relay.Expr
The input data to the operator.
axis : None or List[int]
axis : None or List[int] or Expr
The set of axes to remove.
If axis = None, remove all axis of dimensions 1.
If any specified axis has dimension that does not equal 1, it is an error.
Expand All @@ -159,6 +159,10 @@ def squeeze(data, axis=None):
result : tvm.relay.Expr
The squeezed result.
"""
if isinstance(axis, Constant):
axis = list(axis.data.numpy())
if isinstance(axis, Expr):
return _dyn_make.squeeze(data, axis)
return _make.squeeze(data, axis)


Expand Down
57 changes: 57 additions & 0 deletions src/relay/op/dyn/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,63 @@ RELAY_REGISTER_OP("dyn.expand_dims")
.set_attr<FTVMCompute>("FTVMCompute", ExpandDimsCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

bool DynSqueezeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// [data, axes, output]
ICHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
return false;
}
const auto* axes = types[1].as<TensorTypeNode>();
if (axes == nullptr) {
return false;
}

ICHECK_EQ(axes->shape.size(), 1) << "Got" << axes->shape.size() << "expected 1";
ICHECK(axes->shape[0].as<IntImmNode>()) << "axes expected to be static rank";
size_t output_rank = data->shape.size() - axes->shape[0].as<IntImmNode>()->value;
std::vector<IndexExpr> result_shape(output_rank, Any());
reporter->Assign(types[2], TensorType(result_shape, data->dtype));
return true;
}

Array<te::Tensor> SqueezeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
ICHECK(out_ttype != nullptr);
Array<IndexExpr> newshape;
for (auto val : out_ttype->shape) {
newshape.push_back(val.as<tir::AnyNode>()->ToVar());
}
return {topi::reshape(inputs[0], newshape)};
}

Expr MakeDynSqueeze(Expr data, Expr axes) {
auto attrs = make_object<SqueezeAttrs>();
static const Op& op = Op::Get("dyn.squeeze");
return Call(op, {data, axes}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.dyn._make.squeeze").set_body_typed(MakeDynSqueeze);

RELAY_REGISTER_OP("dyn.squeeze")
.describe(R"code(Remove axes of value 1 in input tensor at the dimensions given by axes
- **data**: The input data to the operator.
- **axes**: The axes to squeeze.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.set_attrs_type<SqueezeAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("axes", "Tensor", "The axes to squeeze.")
.set_support_level(3)
.add_type_rel("DynSqueeze", DynSqueezeRel)
.set_attr<FTVMCompute>("FTVMCompute", SqueezeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<TReshapeOp>("TReshapeOp", true);

} // namespace dyn
} // namespace relay
} // namespace tvm
9 changes: 9 additions & 0 deletions src/relay/transforms/dynamic_to_static.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ class DynamicToStaticMutator : public MixedModeMutator {
}
return Expr(nullptr);
}},
{Op::Get("dyn.squeeze"),
[this](const CallNode* call_node) {
auto args = PrepareArgs(call_node);
if (const ConstantNode* axis = args[1].as<ConstantNode>()) {
ICHECK_EQ(axis->data->ndim, 1);
return MakeSqueeze(call_node->args[0], ToVector(axis->data));
}
return Expr(nullptr);
}},
{Op::Get("dyn.tile"),
[this](const CallNode* call_node) {
auto args = PrepareArgs(call_node);
Expand Down
2 changes: 0 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4993,8 +4993,6 @@ def verify_eyelike(indata):
"test_split_variable_parts_2d",
"test_split_variable_parts_default_axis",
"test_split_zero_size_splits",
"test_squeeze",
"test_squeeze_negative_axes",
"test_strnormalizer_export_monday_casesensintive_lower",
"test_strnormalizer_export_monday_casesensintive_nochangecase",
"test_strnormalizer_export_monday_casesensintive_upper",
Expand Down
16 changes: 16 additions & 0 deletions tests/python/relay/dyn/test_dynamic_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,22 @@ def verify_reshape(shape, newshape, oshape):
verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))


def test_squeeze():
def verify_squeeze(shape, dtype, axis):
x = relay.var("x", relay.TensorType(shape, dtype))
assert axis is not None
np_axis = tuple(axis)
axis = relay.var("axis", relay.TensorType([len(axis)], "int64"))
squeeze = relay.squeeze(x, axis=axis)
func = relay.Function([x, axis], squeeze)
x_data = np.random.random_sample(shape).astype(dtype)
ref_res = np.squeeze(x_data, axis=np_axis)
verify_func(func, [x_data, np.array(np_axis).astype("int64")], ref_res)

verify_squeeze((1, 3, 1), "float32", [0])
verify_squeeze((1, 2, 1, 2, 1), "float32", [0, 2])


@tvm.testing.uses_gpu
def test_dyn_expand_dims():
def verify_expand_dims(
Expand Down
25 changes: 25 additions & 0 deletions tests/python/relay/test_pass_dynamic_to_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,31 @@ def verify_reshape(shape, newshape, oshape):
verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))


@tvm.testing.uses_gpu
def test_dynamic_to_static_squeeze():
def verify_squeeze(shape, axis, oshape):
x = relay.var("x", relay.TensorType(shape, "float32"))
y = relay.var("y", relay.TensorType(axis, "float32"))
z = relay.squeeze(x, relay.shape_of(y))
func = run_infer_type(relay.Function([x, y], z))
func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())

zz = func2.body
assert isinstance(zz, relay.Call)
assert zz.op == relay.op.get("squeeze")
assert "axis=" in zz.astext()
assert zz.checked_type == relay.ty.TensorType(oshape, "float32")

x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
y_data = np.random.uniform(low=-1, high=1, size=axis).astype("float32")
ref_res = np.squeeze(x_data, axis)
verify_func(func2, [x_data, y_data], ref_res)

verify_squeeze((1, 3, 4, 1), (0,), (3, 4, 1))
verify_squeeze((1, 3, 4, 1), (3,), (1, 3, 4))
verify_squeeze((1, 3, 4, 1), (0, 3), (3, 4))


@tvm.testing.uses_gpu
def test_dynamic_to_static_double_reshape():
def verify_reshape(shape, newshape):
Expand Down

0 comments on commit 5e46e75

Please sign in to comment.