diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index e01eb703cb992..aa5c6d2a22564 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -897,8 +897,14 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, return compute( condition->shape, [&](const Array& indices) { - Array condition_idx{indices[0]}; - return tvm::tir::Select(condition(condition_idx) != 0, x(), y()); + PrimExpr cond; + if (condition->shape.size() == 0) { + cond = condition(); + } else { + Array condition_idx{indices[0]}; + cond = condition(condition_idx); + } + return tvm::tir::Select(cond != 0, x(), y()); }, name, tag); } else if (condition->shape.size() != 1) { @@ -913,9 +919,13 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, }, name, tag); } else { - CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0])) - << "If condition is 1-D, the first dimension must be the same as x: " << condition->shape[0] - << " vs " << x->shape[0]; + int64_t cond_first_dim = topi::GetConstInt(condition->shape[0]); + int64_t x_first_dim = topi::GetConstInt(x->shape[0]); + if (cond_first_dim > 0 && x_first_dim > 0) { + CHECK_EQ(cond_first_dim, x_first_dim) + << "If condition is 1-D, the first dimension must be the same as x: " << cond_first_dim + << " vs " << x_first_dim; + } return compute( x->shape, [&](const Array& indices) { diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 3df582a0c76aa..9671e45a59a3d 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1549,7 +1549,7 @@ def _impl(inputs, attr, params, mod): idx += st # Only return when in_shape is fully static in the range from begin to end. - if idx >= st: + if idx >= ed: ret = _expr.const(out_data, dtype) if shrink_axis_mask: ret = _op.squeeze(ret) @@ -1659,14 +1659,26 @@ def _transform_mask(stride_dim, ellipsis_mask): def _pad(name): def _impl(inputs, attr, params, mod): - padlist = _get_param(params, inputs[1]) - paddings = tuple(tuple(l) for l in padlist) + try: + padlist = _get_param(params, inputs[1]) + except (IndexError, KeyError, AttributeError): + try: + padlist = _infer_value(inputs[1], params, mod).asnumpy().tolist() + except Exception: + padlist = inputs[1] + + if isinstance(padlist, _expr.Expr): + paddings = padlist + else: + paddings = tuple(tuple(l) for l in padlist) attr["pad_width"] = paddings attr["pad_value"] = 0 new_inputs = [inputs[0]] if name == "PadV2": - constant_values = _get_num_param(params, inputs[2]) - attr["pad_value"] = constant_values + try: + attr["pad_value"] = _get_num_param(params, inputs[2]) + except (IndexError, KeyError, AttributeError): + attr["pad_value"] = inputs[2] return AttrCvt( op_name="pad", ignores=["Tpaddings"], diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 8f32d2c6f6526..415529fdcb9a8 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -778,6 +778,9 @@ def repeat_shape_func(attrs, inputs, _): @_reg.register_shape_func("broadcast_to_like", False) def broadcast_to_like_shape_func(attrs, inputs, _): + """ + Shape func for broadcast_to_like. + """ return [topi.math.identity(inputs[1])] @@ -798,7 +801,22 @@ def _stack_shape_func(data_shape, axis, num_inputs): @_reg.register_shape_func("stack", False) def stack_shape_func(attrs, inputs, _): + """ + Shape func for stack. + """ axis = get_const_int(attrs.axis) if axis < 0: axis += inputs[0].shape[0] + 1 return [_stack_shape_func(inputs[0], convert(axis), convert(len(inputs)))] + + +@_reg.register_shape_func("where", False) +def where_shape_func(attrs, inputs, _): + """ + Shape func for where. + """ + cond_shape = inputs[0] + x_shape = inputs[1] + out_shape = x_shape if x_shape.shape else cond_shape + + return [topi.math.identity(out_shape)] diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index f18b5397eefe8..cdf0b83190879 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -839,6 +839,7 @@ def test_reshape(): @tvm.testing.uses_gpu def test_where(): + verify_where(()) verify_where((1, 2, 3, 4))