Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DYNAMIC] Add Dynamic reshape to a dynamic namespace and add DynamicToStatic Pass #5826

Merged
merged 17 commits into from
Jul 1, 2020
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {

/*! \brief Attributes used in reshape operators */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
Optional<Array<Integer>> newshape;
Array<Integer> newshape;
bool reverse;
TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
TVM_ATTR_FIELD(newshape).describe(
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from .op import annotation
from .op import vision
from .op import contrib
from .op import dyn
from .op.reduce import *
from .op.tensor import *
from .op.transform import *
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def convert(self, v):
def __call__(self, args, attrs, type_args):
if attrs is None:
attrs = {}
if self.operator in (op.reshape, op.strided_slice):
if self.operator in (op.strided_slice,):
x = self.operator(*args)
elif self.operator in (op.zeros, op.ones, op.full, op.broadcast_to):
x = self.operator(*args, dtype=attrs["dtype"])
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def batch_matmul_grad(orig, grad):
@register_gradient("reshape")
def reshape_grad(orig, grad):
"""Gradient of reshape"""
return [reshape_like(grad, orig.args[0]), orig.args[1]]
return [reshape_like(grad, orig.args[0])]
icemelon marked this conversation as resolved.
Show resolved Hide resolved


@register_gradient("cast")
Expand Down
77 changes: 3 additions & 74 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,82 +273,11 @@ def _reshape_shape_func_input_shape(data_shape, newshape, ndim):
out[infer_idx] = old_size // new_size
return out

@script
def _reshape_shape_func_input_data(data, newshape, ndim):
out = output_tensor((ndim,), "int64")
data_shape = allocate((len(data.shape),), "int64")
for x in const_range(len(data.shape)):
data_shape[x] = int64(data.shape[x])
src_idx = 0
dst_idx = 0
infer_idx = -1
copy = False
skip = 0
for i in const_range(len(newshape)):
if skip > 0:
skip -= 1
elif newshape[i] > 0:
out[dst_idx] = int64(newshape[i])
src_idx += 1
dst_idx += 1
elif newshape[i] == 0:
out[dst_idx] = data_shape[src_idx]
src_idx += 1
dst_idx += 1
elif newshape[i] == -1:
assert infer_idx < 0, "One and only one dim can be inferred"
out[dst_idx] = int64(1)
infer_idx = i
dst_idx += 1
elif newshape[i] == -2:
copy = True
elif newshape[i] == -3:
assert data_shape.shape[0] - src_idx > 1, \
"Not enough dims in input shape for -3"
out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
src_idx += 2
dst_idx += 1
elif newshape[i] == -4:
assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
if newshape[i+1] == -1:
assert newshape[i+2] != -1, "Split dims cannot both be -1."
out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2])
out[dst_idx+1] = int64(newshape[i+2])
else:
out[dst_idx] = int64(newshape[i+1])
if newshape[i+2] == -1:
out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1])
else:
out[dst_idx+1] = int64(newshape[i+2])
assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
"Product of split dims doesn't match to input dim"
src_idx += 1
dst_idx += 2
skip = 2
else:
assert False, "Invalid special values in new shape"
if len(data_shape.shape) > 0:
# if data is not constant, we can then handle -1 and -2
if copy:
for i in range(src_idx, data_shape.shape[0]):
out[dst_idx] = data_shape[i]
dst_idx += 1
if infer_idx >= 0:
old_size = int64(1)
for i in const_range(data_shape.shape[0]):
old_size *= data_shape[i]
new_size = int64(1)
for i in const_range(out.shape[0]):
new_size *= out[i]
out[infer_idx] = old_size // new_size
return out

@_reg.register_shape_func("reshape", True)
@_reg.register_shape_func("reshape", False)
def reshape_shape_func(attrs, inputs, out_ndims):
if attrs.newshape is None:
return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
newshape = get_const_tuple(attrs.newshape)
return [_reshape_shape_func_input_shape(inputs[0],
convert(attrs.newshape),
convert(newshape),
out_ndims[0])]

@script
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relay/op/dyn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay namespace containing dynamic ops."""

from . import _transform
20 changes: 20 additions & 0 deletions python/tvm/relay/op/dyn/_make.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
import tvm._ffi

tvm._ffi._init_api("relay.op.dyn._make", __name__)
83 changes: 83 additions & 0 deletions python/tvm/relay/op/dyn/_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Backend compiler related feature registration"""
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments
from __future__ import absolute_import
from tvm.te.hybrid import script
from .. import op as _reg

_reg.register_injective_schedule("dyn.reshape")

@script
def _reshape_shape_func_input_data(data, newshape, ndim):
out = output_tensor((ndim,), "int64")
data_shape = allocate((len(data.shape),), "int64")
for x in const_range(len(data.shape)):
data_shape[x] = int64(data.shape[x])
src_idx = 0
dst_idx = 0
infer_idx = -1
copy = False
skip = 0
for i in const_range(len(newshape)):
if skip > 0:
skip -= 1
elif newshape[i] > 0:
out[dst_idx] = int64(newshape[i])
src_idx += 1
dst_idx += 1
elif newshape[i] == 0:
out[dst_idx] = data_shape[src_idx]
src_idx += 1
dst_idx += 1
elif newshape[i] == -1:
assert infer_idx < 0, "One and only one dim can be inferred"
out[dst_idx] = int64(1)
infer_idx = i
src_idx += 1
dst_idx += 1
elif newshape[i] == -2:
assert False, "Value -2 is not valid in newshape argument of dynamic reshape"
elif newshape[i] == -3:
assert data_shape.shape[0] - src_idx > 1, \
"Not enough dims in input shape for -3"
out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
src_idx += 2
dst_idx += 1
elif newshape[i] == -4:
assert False, "Value -4 is not valid in newshape argument of dynamic reshape"
else:
assert False, "Invalid special values in new shape"
if len(data_shape.shape) > 0:
# if data is not constant, we can then handle -1 and -2
if copy:
for i in range(src_idx, data_shape.shape[0]):
out[dst_idx] = data_shape[i]
dst_idx += 1
if infer_idx >= 0:
old_size = int64(1)
for i in const_range(data_shape.shape[0]):
old_size *= data_shape[i]
new_size = int64(1)
for i in const_range(out.shape[0]):
new_size *= out[i]
out[infer_idx] = old_size // new_size
return out

@_reg.register_shape_func("dyn.reshape", True)
def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
11 changes: 7 additions & 4 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
"""Transform operators."""

from . import _make
from ..expr import TupleWrapper, const
from .dyn import _make as _dyn_make
from ..expr import TupleWrapper, const, Expr
from ...tir import expr as _expr


Expand Down Expand Up @@ -210,8 +211,10 @@ def reshape(data, newshape):
result : relay.Expr
The reshaped result.
"""
if isinstance(newshape, Expr):
return _dyn_make.reshape(data, newshape)
if isinstance(newshape, int):
newshape = const([newshape])
newshape = [newshape]
if isinstance(newshape, (tuple, list)):
tempshape = []
for shape in newshape:
Expand All @@ -222,8 +225,8 @@ def reshape(data, newshape):
tempshape.append(int(shape))
except ValueError as err:
raise RuntimeError('Unrecognized shape type: %s' % err)
newshape = const(tempshape)
return _make.reshape(data, newshape)
newshape = tempshape
return _make.reshape(data, list(newshape))

def argwhere(condition):
"""Find the indices of elements of a tensor that are
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,17 @@ def AnnotateTarget(targets):
return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets])


def DynamicToStatic():
"""If possible, convert tvm.relay.dynamic* ops to static versions

Returns
-------
ret : tvm.transform.Pass
The registered pass for dynamic->static conversion.
"""
return _ffi_api.DynamicToStatic()


def Inline():
"""Perform inlining on the given Relay IR module. The global functions that
are marked as `inline` should be always inlined. A cost model will be
Expand Down
9 changes: 1 addition & 8 deletions src/relay/analysis/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -448,14 +448,7 @@ bool IsDataDependant(const CallNode* call) {
return false;
}

if (op->name == "reshape") {
if (const auto* attrs = call->attrs.as<ReshapeAttrs>()) {
if (attrs->newshape) {
// If newshape attribute exists, it isn't data dependant.
return false;
}
}
} else if (op->name == "topk") {
if (op->name == "topk") {
if (const auto* attrs = call->attrs.as<TopKAttrs>()) {
if (attrs->k) {
// If k attribute exists, it isn't data dependant.
Expand Down
Loading