[RELAY][TF] Support symbolic newshape for Reshape #5429

Merged: 11 commits (May 13, 2020)
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
@@ -82,7 +82,7 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {

/*! \brief Attributes used in reshape operators */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
Array<Integer> newshape;
Optional<Array<Integer>> newshape;
bool reverse;
TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
TVM_ATTR_FIELD(newshape).describe(
5 changes: 4 additions & 1 deletion python/tvm/relay/_parser.py
@@ -114,7 +114,10 @@ def convert(self, v):
def __call__(self, args, attrs, type_args):
if attrs is None:
attrs = {}
x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
if self.operator is op.reshape:
Member:
Does this mean we have to maintain a list for symbolic ops here?

Contributor:
I think for now we might want to handle each symbolic op separately, since they may have different attrs.

x = self.operator(*args)
else:
x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
if isinstance(x, expr.TupleWrapper):
x = x.astuple()
return x
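
To illustrate the point raised above: if more ops start taking symbolic attrs as inputs, the inline `is op.reshape` check could grow into an explicit dispatch table. A hypothetical sketch, not part of this PR (the helper name and registry are invented for illustration):

```python
from tvm.relay import op

def apply_operator(operator, args, attrs, convert):
    """Hypothetical helper mirroring the __call__ above."""
    # Ops whose "attrs" are really symbolic inputs; each entry may need its own
    # handling because the ops carry different attribute sets.
    symbolic_ops = {op.reshape}
    if operator in symbolic_ops:
        return operator(*args)  # e.g. reshape: newshape arrives as an input expr
    return operator(*args, **{k: convert(v) for k, v in attrs.items()})
```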
13 changes: 5 additions & 8 deletions python/tvm/relay/frontend/tensorflow.py
@@ -1155,14 +1155,11 @@ def _impl(inputs, attr, params, mod):
shape_arg = tuple(params_new.asnumpy().astype('int64').flatten())
except Exception:
# Deal with symbolic shape case.
# Currently only shape_of can be the direct ancestor.
if not isinstance(pop_node, tvm.relay.expr.Call) or \
"shape_of" not in str(pop_node.op):
raise RuntimeError("If shape operator is used in reshape to "
"express reshape_like, shape_of must be "
"the direct ancestor of reshape when input "
"shape is symbolic.")
return _op.reshape_like(inputs[0], pop_node.args[0])
if isinstance(pop_node, _expr.Call) and \
"shape_of" in str(pop_node.op):
# shape_of is the direct ancestor.
return _op.reshape_like(inputs[0], pop_node.args[0])
shape_arg = pop_node
return AttrCvt(
op_name="reshape",
extras={'newshape': shape_arg},
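
In other words, when the target shape cannot be folded into a constant the converter now degrades gracefully: a shape coming directly from shape_of becomes reshape_like, and any other symbolic shape expression is forwarded as the newshape input. A rough sketch of the corresponding Relay expressions (shapes and names are illustrative, not the converter's literal output):

```python
from tvm import relay

x = relay.var("x", shape=(relay.Any(), 4), dtype="float32")
y = relay.var("y", shape=(relay.Any(), 2, 2), dtype="float32")

# Case 1: tf.reshape(x, tf.shape(y)) -- shape_of is the direct ancestor,
# so the converter can emit reshape_like instead.
out_like = relay.reshape_like(x, y)

# Case 2: the shape is some other symbolic expression; it is passed through
# as the (tensor) newshape input of reshape.
newshape = relay.var("newshape", shape=(2,), dtype="int64")
out_sym = relay.reshape(x, newshape)
```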
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_tensor_grad.py
@@ -479,7 +479,7 @@ def dense_grad(orig, grad):
@register_gradient("reshape")
def reshape_grad(orig, grad):
"""Gradient of reshape"""
return [reshape_like(grad, orig.args[0])]
return [reshape_like(grad, orig.args[0]), orig.args[1]]


@register_gradient("cast")
81 changes: 77 additions & 4 deletions python/tvm/relay/op/_transform.py
@@ -123,7 +123,7 @@ def concatenate_shape_func(attrs, inputs, _):
return [_concatenate_shape_func(inputs, convert(axis))]

@script
def _reshape_shape_func(data_shape, newshape, ndim):
def _reshape_shape_func_input_shape(data_shape, newshape, ndim):
out = output_tensor((ndim,), "int64")
src_idx = 0
dst_idx = 0
@@ -189,10 +189,83 @@ def _reshape_shape_func(data_shape, newshape, ndim):
out[infer_idx] = old_size // new_size
return out

@_reg.register_shape_func("reshape", False)
@script
def _reshape_shape_func_input_data(data, newshape, ndim):
out = output_tensor((ndim,), "int64")
data_shape = allocate((len(data.shape),), "int64")
for x in const_range(len(data.shape)):
data_shape[x] = int64(data.shape[x])
src_idx = 0
dst_idx = 0
infer_idx = -1
copy = False
skip = 0
for i in const_range(len(newshape)):
if skip > 0:
skip -= 1
elif newshape[i] > 0:
out[dst_idx] = int64(newshape[i])
src_idx += 1
dst_idx += 1
elif newshape[i] == 0:
out[dst_idx] = data_shape[src_idx]
src_idx += 1
dst_idx += 1
elif newshape[i] == -1:
assert infer_idx < 0, "One and only one dim can be inferred"
out[dst_idx] = int64(1)
infer_idx = i
dst_idx += 1
elif newshape[i] == -2:
copy = True
elif newshape[i] == -3:
assert data_shape.shape[0] - src_idx > 1, \
"Not enough dims in input shape for -3"
out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
src_idx += 2
dst_idx += 1
elif newshape[i] == -4:
assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
if newshape[i+1] == -1:
assert newshape[i+2] != -1, "Split dims cannot both be -1."
out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2])
out[dst_idx+1] = int64(newshape[i+2])
else:
out[dst_idx] = int64(newshape[i+1])
if newshape[i+2] == -1:
out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1])
else:
out[dst_idx+1] = int64(newshape[i+2])
assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
"Product of split dims doesn't match to input dim"
src_idx += 1
dst_idx += 2
skip = 2
else:
assert False, "Invalid special values in new shape"
if len(data_shape.shape) > 0:
# if data is not constant, we can then handle -1 and -2
if copy:
for i in range(src_idx, data_shape.shape[0]):
out[dst_idx] = data_shape[i]
dst_idx += 1
if infer_idx >= 0:
old_size = int64(1)
for i in const_range(data_shape.shape[0]):
old_size *= data_shape[i]
new_size = int64(1)
for i in const_range(out.shape[0]):
new_size *= out[i]
out[infer_idx] = old_size // new_size
return out

@_reg.register_shape_func("reshape", True)
def reshape_shape_func(attrs, inputs, out_ndims):
newshape = get_const_tuple(attrs.newshape)
return [_reshape_shape_func(inputs[0], convert(newshape), out_ndims[0])]
if attrs.newshape is None:
return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
return [_reshape_shape_func_input_shape(inputs[0],
convert(attrs.newshape),
out_ndims[0])]
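
The shape-function registration flips to data dependent (True): when attrs.newshape is unset, the runtime shape tensor itself is fed to _reshape_shape_func_input_data. Both hybrid functions interpret the special values following the reshape op's existing convention; a few worked cases (the concrete shapes are illustrative):

```python
from tvm import relay

x = relay.var("x", shape=(2, 3, 4), dtype="float32")

# 0 copies the matching input dim, -1 infers the remaining size:
#   (2, 3, 4) with newshape=(0, -1)        -> (2, 12)
y0 = relay.reshape(x, newshape=(0, -1))

# -3 merges two consecutive dims, -2 copies the remaining dims:
#   (2, 3, 4) with newshape=(-3, -2)       -> (6, 4)
y1 = relay.reshape(x, newshape=(-3, -2))

# -4 splits one dim into the two values that follow it (one may be -1):
#   (2, 3, 4) with newshape=(-4, 1, 2, -2) -> (1, 2, 3, 4)
y2 = relay.reshape(x, newshape=(-4, 1, 2, -2))
```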

@script
def _take_no_axis_shape_func(indices_shape, out_ndim):
8 changes: 5 additions & 3 deletions python/tvm/relay/op/transform.py
@@ -201,7 +201,7 @@ def reshape(data, newshape):
data : relay.Expr
The input data to the operator.

newshape : Union[int, Tuple[int], List[int]]
newshape : Union[int, Tuple[int], List[int]] or relay.Expr
The new shape. Should be compatible with the original shape.

Returns
@@ -210,8 +210,10 @@ The reshaped result.
The reshaped result.
"""
if isinstance(newshape, int):
newshape = [newshape]
return _make.reshape(data, list(newshape))
newshape = const([newshape])
if isinstance(newshape, (tuple, list)):
newshape = const(list(newshape))
return _make.reshape(data, newshape)
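
With this change newshape may also be a relay.Expr, in which case it is passed as a second input rather than folded into a constant. A minimal usage sketch (shapes and variable names are illustrative):

```python
from tvm import relay

x = relay.var("x", shape=(relay.Any(), 4), dtype="float32")

# Static newshape: converted to a constant expression by relay.reshape.
y_static = relay.reshape(x, newshape=(-1, 2))

# Symbolic newshape: an integer tensor expression resolved at run time.
shape = relay.var("shape", shape=(2,), dtype="int64")
y_dynamic = relay.reshape(x, shape)
```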

def argwhere(condition):
"""Find the indices of elements of a tensor that are
40 changes: 40 additions & 0 deletions src/relay/analysis/util.cc
@@ -27,6 +27,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/pattern_functor.h>

#include "../transforms/pass_util.h"
@@ -414,5 +415,44 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
return ret;
}

struct IsDynamicVisitor : public TypeVisitor {
bool is_dyn{false};
void VisitType_(const TensorTypeNode* tt) {
for (auto dim : tt->shape) {
if (dim.as<Any>()) {
is_dyn = true;
break;
}
}
}
};

bool IsDynamic(const Type& ty) {
IsDynamicVisitor v;
v.VisitType(ty);
return v.is_dyn;
}

TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);

bool IsDataDependant(const CallNode* call) {
static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>("TShapeDataDependant");
Op op = Downcast<Op>(call->op);

if (!tshape_data_dependant.count(op)) {
return false;
}

if (op->name == "reshape") {
if (const auto* attrs = call->attrs.as<ReshapeAttrs>()) {
if (attrs->newshape) {
// If newshape attribute exists, it isn't data dependant.
return false;
}
}
}

return tshape_data_dependant[op];
}
} // namespace relay
} // namespace tvm
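
IsDynamic and IsDataDependant now live in analysis/util.cc, and IsDynamic stays exposed through the packed-function registry. A quick sketch of exercising it from Python (illustrative only):

```python
import tvm
from tvm import relay

# Look up the global registered by util.cc above.
is_dynamic = tvm.get_global_func("relay.ir.IsDynamic")

static_ty = relay.TensorType((2, 3), "float32")
dynamic_ty = relay.TensorType((relay.Any(), 3), "float32")

print(is_dynamic(static_ty))   # False
print(is_dynamic(dynamic_ty))  # True
```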
24 changes: 2 additions & 22 deletions src/relay/backend/compile_engine.cc
@@ -45,6 +45,7 @@
#include <utility>
#include <vector>

#include "../transforms/pass_util.h"
#include "utils.h"

namespace tvm {
Expand All @@ -70,27 +71,6 @@ CCacheKey::CCacheKey(Function source_func, Target target) {
data_ = std::move(n);
}

struct IsDynamicVisitor : public TypeVisitor {
bool is_dyn{false};
void VisitType_(const TensorTypeNode* tt) {
for (auto dim : tt->shape) {
if (dim.as<Any>()) {
is_dyn = true;
break;
}
}
}
};

bool IsDynamic(const Type& ty) {
IsDynamicVisitor v;
v.VisitType(ty);
return v.is_dyn;
}

// TODO(@jroesch): MOVE ME
TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);

Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
// for now, we always use int32 shape when possible
// even if the result of shape inference becomes int64.
@@ -485,7 +465,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
CHECK_GT(tshape_data_dependant.count(op), 0)
<< "Internal error, cannot find TShapeDataDependant for " << op->name;

data_dependants_.push_back(tshape_data_dependant[op]);
data_dependants_.push_back(IsDataDependant(call_node));
// Visit all inputs
Array<te::Tensor> inputs;
int count_tuple = 0;
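
With MakeShapeFunc consulting IsDataDependant, a reshape whose target shape is only known at run time can be compiled and executed with the Relay VM. A rough end-to-end sketch (the executor invocation may differ slightly across TVM versions):

```python
import numpy as np
import tvm
from tvm import relay

# Data with a dynamic leading dim; target shape supplied at run time.
x = relay.var("x", shape=(relay.Any(), 4), dtype="float32")
newshape = relay.var("newshape", shape=(2,), dtype="int64")
func = relay.Function([x, newshape], relay.reshape(x, newshape))
mod = tvm.IRModule.from_expr(func)

# The VM supports dynamic shapes; the graph executor does not.
ex = relay.create_executor("vm", mod=mod, target="llvm")
data = np.random.rand(3, 4).astype("float32")
result = ex.evaluate()(data, np.array([6, 2], dtype="int64"))
print(result.shape)  # (6, 2)
```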