[RELAY][TF] Support symbolic newshape for Reshape (#5429)
* [RELAY][TF] Support symbolic newshape for Reshape

* Only need to pass data

* Use MakeReshape() in Reshape()

* Change newshape to Expr

* Create a template for Array<T>

* Fuse reshape when newshape is constant

* Make newshape Optional

* Use bool() of Optional

Co-authored-by: Li Xiaoquan <xiaoquan.li@denglin.ai>
lixiaoquan and Li Xiaoquan authored May 13, 2020
1 parent fc230e0 commit aa42f97
Showing 18 changed files with 312 additions and 110 deletions.
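
As a quick orientation, here is a minimal sketch of what the commit enables at the Relay level (hypothetical usage, not part of the diff; variable names are illustrative): relay.reshape now also accepts a relay expression as newshape, while a constant newshape is still folded into the op attributes so the reshape can be fused.

# Hypothetical sketch of constant vs. symbolic newshape after this commit.
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 8), dtype="float32")
y = relay.var("y", shape=(relay.Any(), 2, 4), dtype="float32")

# Constant newshape: folded into ReshapeAttrs ("Fuse reshape when newshape is constant").
static = relay.reshape(x, (2, 4))

# Symbolic newshape: any 1-D integer expression, e.g. another tensor's runtime shape.
dynamic = relay.reshape(x, relay.shape_of(y))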
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
@@ -82,7 +82,7 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
 
 /*! \brief Attributes used in reshape operators */
 struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
-  Array<Integer> newshape;
+  Optional<Array<Integer>> newshape;
   bool reverse;
   TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
     TVM_ATTR_FIELD(newshape).describe(
5 changes: 4 additions & 1 deletion python/tvm/relay/_parser.py
@@ -114,7 +114,10 @@ def convert(self, v):
     def __call__(self, args, attrs, type_args):
         if attrs is None:
             attrs = {}
-        x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
+        if self.operator is op.reshape:
+            x = self.operator(*args)
+        else:
+            x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()})
         if isinstance(x, expr.TupleWrapper):
             x = x.astuple()
         return x
13 changes: 5 additions & 8 deletions python/tvm/relay/frontend/tensorflow.py
@@ -1155,14 +1155,11 @@ def _impl(inputs, attr, params, mod):
                 shape_arg = tuple(params_new.asnumpy().astype('int64').flatten())
             except Exception:
                 # Deal with symbolic shape case.
-                # Currently only shape_of can be the direct ancestor.
-                if not isinstance(pop_node, tvm.relay.expr.Call) or \
-                        "shape_of" not in str(pop_node.op):
-                    raise RuntimeError("If shape operator is used in reshape to "
-                                       "express reshape_like, shape_of must be "
-                                       "the direct ancestor of reshape when input "
-                                       "shape is symbolic.")
-                return _op.reshape_like(inputs[0], pop_node.args[0])
+                if isinstance(pop_node, _expr.Call) and \
+                        "shape_of" in str(pop_node.op):
+                    # shape_of is the direct ancestor.
+                    return _op.reshape_like(inputs[0], pop_node.args[0])
+                shape_arg = pop_node
         return AttrCvt(
             op_name="reshape",
             extras={'newshape': shape_arg},
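
To illustrate the frontend change above, a hedged sketch of a TF 1.x graph that exercises the symbolic-shape path (assumed API and names, not taken from this commit): when the Reshape target cannot be folded to constants, a direct Shape ancestor is converted to reshape_like, and any other symbolic producer is now forwarded as the newshape input instead of raising a RuntimeError.

# Hypothetical TF 1.x graph; placeholder names are illustrative only.
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 4), name="x")
y = tf.placeholder(tf.float32, shape=(None, 2, 2), name="y")

# Direct Shape ancestor: converted to relay reshape_like.
z = tf.reshape(x, tf.shape(y))

# Derived (non-Shape) symbolic shape: previously a RuntimeError, now passed
# through to Relay as a symbolic newshape.
w = tf.reshape(x, tf.shape(y) * tf.constant([1, 1, 1]))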
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_tensor_grad.py
@@ -479,7 +479,7 @@ def dense_grad(orig, grad):
 @register_gradient("reshape")
 def reshape_grad(orig, grad):
     """Gradient of reshape"""
-    return [reshape_like(grad, orig.args[0])]
+    return [reshape_like(grad, orig.args[0]), orig.args[1]]
 
 
 @register_gradient("cast")
81 changes: 77 additions & 4 deletions python/tvm/relay/op/_transform.py
@@ -123,7 +123,7 @@ def concatenate_shape_func(attrs, inputs, _):
     return [_concatenate_shape_func(inputs, convert(axis))]
 
 @script
-def _reshape_shape_func(data_shape, newshape, ndim):
+def _reshape_shape_func_input_shape(data_shape, newshape, ndim):
     out = output_tensor((ndim,), "int64")
     src_idx = 0
     dst_idx = 0
@@ -189,10 +189,83 @@ def _reshape_shape_func(data_shape, newshape, ndim):
             out[infer_idx] = old_size // new_size
     return out
 
-@_reg.register_shape_func("reshape", False)
+@script
+def _reshape_shape_func_input_data(data, newshape, ndim):
+    out = output_tensor((ndim,), "int64")
+    data_shape = allocate((len(data.shape),), "int64")
+    for x in const_range(len(data.shape)):
+        data_shape[x] = int64(data.shape[x])
+    src_idx = 0
+    dst_idx = 0
+    infer_idx = -1
+    copy = False
+    skip = 0
+    for i in const_range(len(newshape)):
+        if skip > 0:
+            skip -= 1
+        elif newshape[i] > 0:
+            out[dst_idx] = int64(newshape[i])
+            src_idx += 1
+            dst_idx += 1
+        elif newshape[i] == 0:
+            out[dst_idx] = data_shape[src_idx]
+            src_idx += 1
+            dst_idx += 1
+        elif newshape[i] == -1:
+            assert infer_idx < 0, "One and only one dim can be inferred"
+            out[dst_idx] = int64(1)
+            infer_idx = i
+            dst_idx += 1
+        elif newshape[i] == -2:
+            copy = True
+        elif newshape[i] == -3:
+            assert data_shape.shape[0] - src_idx > 1, \
+                "Not enough dims in input shape for -3"
+            out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
+            src_idx += 2
+            dst_idx += 1
+        elif newshape[i] == -4:
+            assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
+            if newshape[i+1] == -1:
+                assert newshape[i+2] != -1, "Split dims cannot both be -1."
+                out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2])
+                out[dst_idx+1] = int64(newshape[i+2])
+            else:
+                out[dst_idx] = int64(newshape[i+1])
+                if newshape[i+2] == -1:
+                    out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1])
+                else:
+                    out[dst_idx+1] = int64(newshape[i+2])
+            assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
+                "Product of split dims doesn't match to input dim"
+            src_idx += 1
+            dst_idx += 2
+            skip = 2
+        else:
+            assert False, "Invalid special values in new shape"
+    if len(data_shape.shape) > 0:
+        # if data is not constant, we can then handle -1 and -2
+        if copy:
+            for i in range(src_idx, data_shape.shape[0]):
+                out[dst_idx] = data_shape[i]
+                dst_idx += 1
+        if infer_idx >= 0:
+            old_size = int64(1)
+            for i in const_range(data_shape.shape[0]):
+                old_size *= data_shape[i]
+            new_size = int64(1)
+            for i in const_range(out.shape[0]):
+                new_size *= out[i]
+            out[infer_idx] = old_size // new_size
+    return out
+
+@_reg.register_shape_func("reshape", True)
 def reshape_shape_func(attrs, inputs, out_ndims):
-    newshape = get_const_tuple(attrs.newshape)
-    return [_reshape_shape_func(inputs[0], convert(newshape), out_ndims[0])]
+    if attrs.newshape is None:
+        return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
+    return [_reshape_shape_func_input_shape(inputs[0],
+                                            convert(attrs.newshape),
+                                            out_ndims[0])]
 
 @script
 def _take_no_axis_shape_func(indices_shape, out_ndim):
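
For readers unfamiliar with the MXNet-style special values interpreted by the shape functions above (0, -1, -2, -3, -4), a hedged sketch; the expected shapes in the comments are inferred from the hybrid-script logic, not output captured from this commit.

# Hypothetical check of the special newshape values; shapes are illustrative.
import tvm
from tvm import relay

# 0: copy the matching input dim; -1: infer one dim; -2: copy all remaining dims;
# -3: merge the next two input dims; -4: split one input dim into the next two values.
x = relay.var("x", shape=(2, 3, 4), dtype="float32")
for newshape in [(0, -1),             # expect (2, 12)
                 (-3, 4),             # expect (6, 4)
                 (-4, 1, 2, 0, 0)]:   # expect (1, 2, 3, 4)
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.reshape(x, newshape)))
    mod = relay.transform.InferType()(mod)
    print(newshape, "->", mod["main"].ret_type)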
8 changes: 5 additions & 3 deletions python/tvm/relay/op/transform.py
@@ -201,7 +201,7 @@ def reshape(data, newshape):
     data : relay.Expr
         The input data to the operator.
 
-    newshape : Union[int, Tuple[int], List[int]]
+    newshape : Union[int, Tuple[int], List[int]] or relay.Expr
         The new shape. Should be compatible with the original shape.
 
     Returns
@@ -210,8 +210,10 @@
         The reshaped result.
     """
     if isinstance(newshape, int):
-        newshape = [newshape]
-    return _make.reshape(data, list(newshape))
+        newshape = const([newshape])
+    if isinstance(newshape, (tuple, list)):
+        newshape = const(list(newshape))
+    return _make.reshape(data, newshape)
 
 def argwhere(condition):
     """Find the indices of elements of a tensor that are
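
A brief hypothetical REPL-style check of the wrapper logic above (attribute names follow the diff; printed values are expectations, not captured output): ints and tuples/lists are wrapped in a relay const, and a constant newshape is folded back into the op attributes by MakeReshape.

# Hypothetical sketch; assumes a constant newshape is fused into ReshapeAttrs.
from tvm import relay

x = relay.var("x", shape=(6,), dtype="float32")
a = relay.reshape(x, 6)                             # int
b = relay.reshape(x, (2, 3))                        # tuple
c = relay.reshape(x, relay.const([2, 3], "int64"))  # already an Expr
print(a.attrs.newshape, b.attrs.newshape)           # expect [6] and [2, 3]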
40 changes: 40 additions & 0 deletions src/relay/analysis/util.cc
@@ -27,6 +27,7 @@
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/pattern_functor.h>
 
 #include "../transforms/pass_util.h"
@@ -414,5 +415,44 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
   return ret;
 }
 
+struct IsDynamicVisitor : public TypeVisitor {
+  bool is_dyn{false};
+  void VisitType_(const TensorTypeNode* tt) {
+    for (auto dim : tt->shape) {
+      if (dim.as<Any>()) {
+        is_dyn = true;
+        break;
+      }
+    }
+  }
+};
+
+bool IsDynamic(const Type& ty) {
+  IsDynamicVisitor v;
+  v.VisitType(ty);
+  return v.is_dyn;
+}
+
+TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);
+
+bool IsDataDependant(const CallNode* call) {
+  static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>("TShapeDataDependant");
+  Op op = Downcast<Op>(call->op);
+
+  if (!tshape_data_dependant.count(op)) {
+    return false;
+  }
+
+  if (op->name == "reshape") {
+    if (const auto* attrs = call->attrs.as<ReshapeAttrs>()) {
+      if (attrs->newshape) {
+        // If newshape attribute exists, it isn't data dependant.
+        return false;
+      }
+    }
+  }
+
+  return tshape_data_dependant[op];
+}
 }  // namespace relay
 }  // namespace tvm
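
Since IsDynamic now lives here (moved from compile_engine.cc) and stays registered as a packed function, here is a small hypothetical probe from Python, assuming only the registered name shown in the hunk above.

# Hypothetical probe of the relocated IsDynamic helper via its packed-func name.
import tvm
from tvm import relay

is_dynamic = tvm.get_global_func("relay.ir.IsDynamic")
static_ty = relay.TensorType((2, 3), "float32")
dynamic_ty = relay.TensorType((relay.Any(), 3), "float32")
print(bool(is_dynamic(static_ty)))   # expect False
print(bool(is_dynamic(dynamic_ty)))  # expect True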
24 changes: 2 additions & 22 deletions src/relay/backend/compile_engine.cc
@@ -45,6 +45,7 @@
 #include <utility>
 #include <vector>
 
+#include "../transforms/pass_util.h"
 #include "utils.h"
 
 namespace tvm {
@@ -70,27 +71,6 @@ CCacheKey::CCacheKey(Function source_func, Target target) {
   data_ = std::move(n);
 }
 
-struct IsDynamicVisitor : public TypeVisitor {
-  bool is_dyn{false};
-  void VisitType_(const TensorTypeNode* tt) {
-    for (auto dim : tt->shape) {
-      if (dim.as<Any>()) {
-        is_dyn = true;
-        break;
-      }
-    }
-  }
-};
-
-bool IsDynamic(const Type& ty) {
-  IsDynamicVisitor v;
-  v.VisitType(ty);
-  return v.is_dyn;
-}
-
-// TODO(@jroesch): MOVE ME
-TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);
-
 Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
   // for now, we always use int32 shape when possible
   // even if the result of shape inference becomes int64.
@@ -485,7 +465,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
     CHECK_GT(tshape_data_dependant.count(op), 0)
         << "Internal error, cannot find TShapeDataDependant for " << op->name;
 
-    data_dependants_.push_back(tshape_data_dependant[op]);
+    data_dependants_.push_back(IsDataDependant(call_node));
     // Visit all inputs
     Array<te::Tensor> inputs;
     int count_tuple = 0;