[RELAY][TF] Support symbolic newshape for Reshape #5429

Merged May 13, 2020 (11 commits)
Changes from 1 commit
9 changes: 8 additions & 1 deletion include/tvm/relay/op_attr_types.h
@@ -81,10 +81,16 @@ using TOpIsStateful = bool;
*/
using TNonComputational = bool;

enum ShapeDependantKind {
kShapeDependantShape = 0,
kShapeDependantData = 1,
kShapeDependantBoth = 2,
};

/*!
* \brief Mark the operator whether output shape is data dependant.
*/
using TShapeDataDependant = bool;
using TShapeDependant = int;

/*!
* \brief Computation description interface.
@@ -236,6 +242,7 @@ using Shape = Array<IndexExpr>;
using FShapeFunc = runtime::TypedPackedFunc<
Array<te::Tensor>(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Array<te::Tensor>& data_inputs,
const Array<IndexExpr>& out_ndims)>;

} // namespace relay
5 changes: 4 additions & 1 deletion python/tvm/relay/_parser.py
@@ -114,7 +114,10 @@ def convert(self, v):
def __call__(self, args, attrs, type_args):
if attrs is None:
attrs = {}
x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
if self.operator is op.reshape:
Member: Does this mean we have to maintain a list for symbolic ops here?

Contributor: I think for now we might want to handle each symbolic op separately, since they may have different attrs.

x = self.operator(*args)
else:
x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
if isinstance(x, expr.TupleWrapper):
x = x.astuple()
return x
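
For context, a minimal usage sketch of what this enables, assuming relay.reshape accepts a Relay expression for newshape once this change is in (the shapes below are illustrative only):

from tvm import relay

x = relay.var("x", shape=(1, 2, 3, 4), dtype="float32")
y = relay.var("y", shape=(relay.Any(), 12), dtype="float32")
# newshape is a symbolic expression rather than a constant attribute, so the
# output shape is resolved at runtime by the registered shape function; the
# runtime shape of y must describe the same number of elements as x.
z = relay.reshape(x, relay.shape_of(y))
func = relay.Function([x, y], z)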
13 changes: 5 additions & 8 deletions python/tvm/relay/frontend/tensorflow.py
@@ -1155,14 +1155,11 @@ def _impl(inputs, attr, params, mod):
shape_arg = tuple(params_new.asnumpy().astype('int64').flatten())
except Exception:
# Deal with symbolic shape case.
# Currently only shape_of can be the direct ancestor.
if not isinstance(pop_node, tvm.relay.expr.Call) or \
"shape_of" not in str(pop_node.op):
raise RuntimeError("If shape operator is used in reshape to "
"express reshape_like, shape_of must be "
"the direct ancestor of reshape when input "
"shape is symbolic.")
return _op.reshape_like(inputs[0], pop_node.args[0])
if isinstance(pop_node, _expr.Call) and \
"shape_of" in str(pop_node.op):
# shape_of is the direct ancestor.
return _op.reshape_like(inputs[0], pop_node.args[0])
shape_arg = pop_node
return AttrCvt(
op_name="reshape",
extras={'newshape': shape_arg},
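
For reference, the graph pattern this branch handles is a Reshape whose shape input is produced by a Shape op; a minimal TensorFlow sketch of that pattern, written with the TF 1.x-style API purely for illustration:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.placeholder(tf.float32, shape=(None, 3, 4), name="x")
y = tf.placeholder(tf.float32, shape=(None, 12), name="y")
# The target shape is another tensor's runtime shape. When the importer sees
# that the shape argument comes directly from a Shape op it emits
# reshape_like(x, y); otherwise the symbolic shape expression is passed
# through as the newshape input.
z = tf.reshape(x, tf.shape(y), name="reshape")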
43 changes: 22 additions & 21 deletions python/tvm/relay/op/_reduce.py
@@ -14,24 +14,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import

from tvm.runtime import convert
from tvm.te.hybrid import script
from topi.util import get_const_int, get_const_tuple
from . import op as _reg
from .op import ShapeDependant, register_reduce_schedule, register_shape_func

_reg.register_reduce_schedule("argmax")
_reg.register_reduce_schedule("argmin")
_reg.register_reduce_schedule("sum")
_reg.register_reduce_schedule("all")
_reg.register_reduce_schedule("any")
_reg.register_reduce_schedule("max")
_reg.register_reduce_schedule("min")
_reg.register_reduce_schedule("prod")
_reg.register_reduce_schedule("mean")
_reg.register_reduce_schedule("variance")
register_reduce_schedule("argmax")
register_reduce_schedule("argmin")
register_reduce_schedule("sum")
register_reduce_schedule("all")
register_reduce_schedule("any")
register_reduce_schedule("max")
register_reduce_schedule("min")
register_reduce_schedule("prod")
register_reduce_schedule("mean")
register_reduce_schedule("variance")

def _create_axis_record(attrs, inputs):
axes = attrs.axis if attrs.axis is None else list(get_const_tuple(attrs.axis))
@@ -79,19 +80,19 @@ def _reduce_shape_func(data_shape, axis_record):

return out

def reduce_shape_func(attrs, inputs, _):
def reduce_shape_func(attrs, inputs, data_inputs, _):
"""
Shape function for reduce op.
"""
axis_record = _create_axis_record(attrs, inputs)
return [_reduce_shape_func(inputs[0], convert(axis_record))]

_reg.register_shape_func("argmax", False, reduce_shape_func)
_reg.register_shape_func("argmin", False, reduce_shape_func)
_reg.register_shape_func("all", False, reduce_shape_func)
_reg.register_shape_func("sum", False, reduce_shape_func)
_reg.register_shape_func("max", False, reduce_shape_func)
_reg.register_shape_func("min", False, reduce_shape_func)
_reg.register_shape_func("prod", False, reduce_shape_func)
_reg.register_shape_func("mean", False, reduce_shape_func)
_reg.register_shape_func("variance", False, reduce_shape_func)
register_shape_func("argmax", ShapeDependant.SHAPE, reduce_shape_func)
register_shape_func("argmin", ShapeDependant.SHAPE, reduce_shape_func)
register_shape_func("all", ShapeDependant.SHAPE, reduce_shape_func)
register_shape_func("sum", ShapeDependant.SHAPE, reduce_shape_func)
register_shape_func("max", ShapeDependant.SHAPE, reduce_shape_func)
register_shape_func("min", ShapeDependant.SHAPE, reduce_shape_func)
register_shape_func("prod", ShapeDependant.SHAPE, reduce_shape_func)
register_shape_func("mean", ShapeDependant.SHAPE, reduce_shape_func)
register_shape_func("variance", ShapeDependant.SHAPE, reduce_shape_func)
86 changes: 43 additions & 43 deletions python/tvm/relay/op/_tensor.py
@@ -23,7 +23,7 @@
from topi.util import get_const_tuple
from .op import register_compute, register_shape_func
from .op import register_broadcast_schedule, register_injective_schedule
from .op import register_pattern, OpPattern
from .op import register_pattern, OpPattern, ShapeDependant


register_broadcast_schedule("log")
@@ -135,7 +135,7 @@ def _cast_shape_function(x):
out[i] = x[i]
return out

def cast_shape_func(attrs, inputs, out_ndims):
def cast_shape_func(attrs, inputs, _, out_ndims):
return [_cast_shape_function(*inputs)]

@script
@@ -146,7 +146,7 @@ def _full_shape_func(shape):
out[i] = int64(shape[i])
return out

def full_shape_func(attrs, inputs, out_ndims):
def full_shape_func(attrs, inputs, _, out_ndims):
"""
Shape func for zeros, zeros_like, ones, ones_like.
"""
@@ -181,53 +181,53 @@ def _broadcast_shape_func(x, y, ndim):
out[ndim-i] = y[ndim2-i]
return out

def broadcast_shape_func(attrs, inputs, out_ndims):
def broadcast_shape_func(attrs, inputs, _, out_ndims):
"""
Shape function for broadcast op.
"""
return [_broadcast_shape_func(*inputs, out_ndims[0])]

def elemwise_shape_func(attrs, inputs, _):
def elemwise_shape_func(attrs, inputs, data_inputs, _):
"""
Shape function for elemwise op.
"""
return [topi.math.identity(inputs[0])]

register_shape_func("cast", False, cast_shape_func)
register_shape_func("zeros", False, full_shape_func)
register_shape_func("zeros_like", False, elemwise_shape_func)
register_shape_func("ones", False, full_shape_func)
register_shape_func("ones_like", False, elemwise_shape_func)
register_shape_func("full", False, full_shape_func)
register_shape_func("full_like", False, elemwise_shape_func)

register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
register_shape_func("multiply", False, broadcast_shape_func)
register_shape_func("divide", False, broadcast_shape_func)
register_shape_func("floor_divide", False, broadcast_shape_func)
register_shape_func("mod", False, broadcast_shape_func)
register_shape_func("floor_mod", False, broadcast_shape_func)
register_shape_func("logical_and", False, broadcast_shape_func)
register_shape_func("logical_or", False, broadcast_shape_func)
register_shape_func("logical_xor", False, broadcast_shape_func)
register_shape_func("bitwise_not", False, broadcast_shape_func)
register_shape_func("bitwise_and", False, broadcast_shape_func)
register_shape_func("bitwise_or", False, broadcast_shape_func)
register_shape_func("bitwise_xor", False, broadcast_shape_func)
register_shape_func("equal", False, broadcast_shape_func)
register_shape_func("not_equal", False, broadcast_shape_func)
register_shape_func("less", False, broadcast_shape_func)
register_shape_func("less_equal", False, broadcast_shape_func)
register_shape_func("greater", False, broadcast_shape_func)
register_shape_func("greater_equal", False, broadcast_shape_func)
register_shape_func("maximum", False, broadcast_shape_func)
register_shape_func("minimum", False, broadcast_shape_func)

register_shape_func("sqrt", False, elemwise_shape_func)
register_shape_func("negative", False, elemwise_shape_func)
register_shape_func("exp", False, elemwise_shape_func)
register_shape_func("tan", False, elemwise_shape_func)
register_shape_func("fast_exp", False, elemwise_shape_func)
register_shape_func("fast_tanh", False, elemwise_shape_func)
register_shape_func("fast_erf", False, elemwise_shape_func)
register_shape_func("cast", ShapeDependant.SHAPE, cast_shape_func)
register_shape_func("zeros", ShapeDependant.SHAPE, full_shape_func)
register_shape_func("zeros_like", ShapeDependant.SHAPE, elemwise_shape_func)
register_shape_func("ones", ShapeDependant.SHAPE, full_shape_func)
register_shape_func("ones_like", ShapeDependant.SHAPE, elemwise_shape_func)
register_shape_func("full", ShapeDependant.SHAPE, full_shape_func)
register_shape_func("full_like", ShapeDependant.SHAPE, elemwise_shape_func)

register_shape_func("add", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("subtract", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("multiply", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("divide", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("floor_divide", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("mod", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("floor_mod", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("logical_and", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("logical_or", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("logical_xor", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("bitwise_not", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("bitwise_and", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("bitwise_or", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("bitwise_xor", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("equal", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("not_equal", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("less", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("less_equal", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("greater", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("greater_equal", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("maximum", ShapeDependant.SHAPE, broadcast_shape_func)
register_shape_func("minimum", ShapeDependant.SHAPE, broadcast_shape_func)

register_shape_func("sqrt", ShapeDependant.SHAPE, elemwise_shape_func)
register_shape_func("negative", ShapeDependant.SHAPE, elemwise_shape_func)
register_shape_func("exp", ShapeDependant.SHAPE, elemwise_shape_func)
register_shape_func("tan", ShapeDependant.SHAPE, elemwise_shape_func)
register_shape_func("fast_exp", ShapeDependant.SHAPE, elemwise_shape_func)
register_shape_func("fast_tanh", ShapeDependant.SHAPE, elemwise_shape_func)
register_shape_func("fast_erf", ShapeDependant.SHAPE, elemwise_shape_func)
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_tensor_grad.py
@@ -436,7 +436,7 @@ def dense_grad(orig, grad):
@register_gradient("reshape")
def reshape_grad(orig, grad):
"""Gradient of reshape"""
return [reshape_like(grad, orig.args[0])]
return [reshape_like(grad, orig.args[0]), orig.args[1]]


@register_gradient("cast")
54 changes: 29 additions & 25 deletions python/tvm/relay/op/_transform.py
@@ -25,7 +25,7 @@
from topi.util import get_const_int, get_const_tuple
from . import op as _reg
from . import strategy
from .op import OpPattern
from .op import OpPattern, ShapeDependant

_reg.register_broadcast_schedule("broadcast_to")
_reg.register_broadcast_schedule("broadcast_to_like")
@@ -97,8 +97,8 @@ def _arange_shape_func(start, stop, step):
out[0] = int64(ceil_div((int64(stop[0]) - int64(start[0])), int64(step[0])))
return out

@_reg.register_shape_func("arange", True)
def arange_shape_func(attrs, inputs, _):
@_reg.register_shape_func("arange", ShapeDependant.DATA)
def arange_shape_func(attrs, shape_inputs, inputs, _):
return [_arange_shape_func(*inputs)]

@script
@@ -117,8 +117,8 @@ def _concatenate_shape_func(inputs, axis):
out[i] += inputs[j][i]
return out

@_reg.register_shape_func("concatenate", False)
def concatenate_shape_func(attrs, inputs, _):
@_reg.register_shape_func("concatenate", ShapeDependant.SHAPE)
def concatenate_shape_func(attrs, inputs, data_inputs, _):
axis = get_const_int(attrs.axis)
return [_concatenate_shape_func(inputs, convert(axis))]

@@ -189,10 +189,14 @@ def _reshape_shape_func(data_shape, newshape, ndim):
out[infer_idx] = old_size // new_size
return out

@_reg.register_shape_func("reshape", False)
def reshape_shape_func(attrs, inputs, out_ndims):
newshape = get_const_tuple(attrs.newshape)
return [_reshape_shape_func(inputs[0], convert(newshape), out_ndims[0])]
@_reg.register_shape_func("reshape", ShapeDependant.BOTH)
def reshape_shape_func(attrs, shape_inputs, data_inputs, out_ndims):
if len(attrs.newshape):
newshape = convert(get_const_tuple(attrs.newshape))
else:
newshape = data_inputs[1]

return [_reshape_shape_func(shape_inputs[0], newshape, out_ndims[0])]

@script
def _take_no_axis_shape_func(indices_shape, out_ndim):
@@ -218,7 +222,7 @@ def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim):
return out

@_reg.register_shape_func("take", False)
def take_shape_func(attrs, inputs, out_ndims):
def take_shape_func(attrs, inputs, _, out_ndims):
"""
Shape function for take op.
"""
@@ -291,8 +295,8 @@ def _argwhere_shape_func_5d(condition):
out[0] += int64(1)
return out

@_reg.register_shape_func("argwhere", True)
def argwhere_shape_func(attrs, inputs, out_ndims):
@_reg.register_shape_func("argwhere", ShapeDependant.DATA)
def argwhere_shape_func(attrs, _, inputs, out_ndims):
"""
Shape function for argwhere.
"""
@@ -333,7 +337,7 @@ def _layout_transform_shape_func(data_shape,
return out

@_reg.register_shape_func("layout_transform", False)
def layout_transform_shape_func(attrs, inputs, _):
def layout_transform_shape_func(attrs, inputs, data_inputs, _):
"""
Shape function for layout_transform op.
"""
@@ -410,8 +414,8 @@ def _expand_dim_shape_func(data_shape, ndim, axis, num_newaxis):

return out

@_reg.register_shape_func("expand_dims", False)
def expand_dim_shape_func(attrs, inputs, _):
@_reg.register_shape_func("expand_dims", ShapeDependant.SHAPE)
def expand_dim_shape_func(attrs, inputs, data_inputs, _):
"""
Shape function for expand_dim op.
"""
@@ -433,8 +437,8 @@ def _transpose_shape_func(data_shape, axes):

return out

@_reg.register_shape_func("transpose", False)
def transpose_shape_func(attrs, inputs, _):
@_reg.register_shape_func("transpose", ShapeDependant.SHAPE)
def transpose_shape_func(attrs, inputs, data_inputs, _):
"""
Shape function for transpose op.
"""
@@ -455,8 +459,8 @@ def _squeeze_shape_func(data_shape, keep_axes):

return out

@_reg.register_shape_func("squeeze", False)
def squeeze_shape_func(attrs, inputs, _):
@_reg.register_shape_func("squeeze", ShapeDependant.SHAPE)
def squeeze_shape_func(attrs, inputs, data_inputs, _):
"""
Shape function for squeeze op.
"""
@@ -486,8 +490,8 @@ def _reshape_like_shape_func(target_shape):

return out

@_reg.register_shape_func("reshape_like", False)
def reshape_like_shape_func(attrs, inputs, _):
@_reg.register_shape_func("reshape_like", ShapeDependant.SHAPE)
def reshape_like_shape_func(attrs, inputs, data_inputs, _):
"""
Shape function for reshape_like op.
"""
@@ -516,8 +520,8 @@ def _tile_shape_func(data, reps, ndim, tndim, rndim):
out[i] = int64(reps[i]) * data[i - rgap]
return out

@_reg.register_shape_func("tile", False)
def tile_shape_func(attrs, inputs, _):
@_reg.register_shape_func("tile", ShapeDependant.SHAPE)
def tile_shape_func(attrs, inputs, data_inputs, _):
"""
Shape function for tile op.
"""
@@ -551,8 +555,8 @@ def _split_shape_func(data_shape, index, indices_or_sections, axis):
out[i] = data_shape[i]
return out

@_reg.register_shape_func("split", False)
def split_shape_func(attrs, inputs, _):
@_reg.register_shape_func("split", ShapeDependant.SHAPE)
def split_shape_func(attrs, inputs, data_inputs, _):
"""
Shape function for split op.
"""