Commit 15cdf69

[DYNAMIC] Add Dynamic reshape to a dynamic namespace and add DynamicToStatic Pass (apache#5826)

* Dynamic reshape passing tests

* Add Dynamic to Static Pass

* rename test file to prevent pytest conflicts

* fix clang build

* add nested dynamic shape test

* remove cuda tests until VM supports dynamic shapes

* rename namespace from dynamic to dyn

* fix lint

* fix lint again

* Remove incorrect doc strings

* remove dynamic behavior from standard reshape

* fix some tests

* merge dynamic and static interfaces in python

* fix missing import

* missed a reference to relay.dyn.reshape

* fix vta example

* respond to review comments
Matthew Brookhart authored and trevor-m committed Jul 14, 2020
1 parent 459a2ed commit 15cdf69
Showing 22 changed files with 625 additions and 132 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
@@ -82,7 +82,7 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {

 /*! \brief Attributes used in reshape operators */
 struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
-  Optional<Array<Integer>> newshape;
+  Array<Integer> newshape;
   bool reverse;
   TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
     TVM_ATTR_FIELD(newshape).describe(
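With this change, newshape is once again a required compile-time attribute on the static reshape; the data-dependent case moves to the new dyn.reshape op added below, rather than being signalled by an empty Optional. A minimal sketch of what the static attribute looks like from Python after this commit (shapes illustrative):

    from tvm import relay

    x = relay.var("x", shape=(2, 3, 4), dtype="float32")
    y = relay.reshape(x, newshape=(6, 4))  # static form: newshape baked into attrs
    assert y.op.name == "reshape"
    assert [int(d) for d in y.attrs.newshape] == [6, 4]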
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
@@ -45,6 +45,7 @@
 from .op import annotation
 from .op import vision
 from .op import contrib
+from .op import dyn
 from .op.reduce import *
 from .op.tensor import *
 from .op.transform import *
2 changes: 1 addition & 1 deletion python/tvm/relay/_parser.py
@@ -114,7 +114,7 @@ def convert(self, v):
     def __call__(self, args, attrs, type_args):
         if attrs is None:
             attrs = {}
-        if self.operator in (op.reshape, op.strided_slice):
+        if self.operator in (op.strided_slice,):
             x = self.operator(*args)
         elif self.operator in (op.zeros, op.ones, op.full, op.broadcast_to):
             x = self.operator(*args, dtype=attrs["dtype"])
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_tensor_grad.py
@@ -511,7 +511,7 @@ def batch_matmul_grad(orig, grad):
 @register_gradient("reshape")
 def reshape_grad(orig, grad):
     """Gradient of reshape"""
-    return [reshape_like(grad, orig.args[0]), orig.args[1]]
+    return [reshape_like(grad, orig.args[0])]


 @register_gradient("cast")
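With newshape back to a static attribute, reshape has a single tensor input again, so its gradient returns exactly one entry; reshape_like recovers the input's shape without consulting newshape. A quick first-order AD sketch (illustrative, assumes this commit):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 6), dtype="float32")
    fn = relay.Function([x], relay.reshape(x, newshape=(3, 4)))
    fn = relay.transform.InferType()(tvm.IRModule.from_expr(fn))["main"]
    # the backward function now yields one gradient, matching reshape's one input
    bwd = relay.transform.gradient(fn, mode="first_order")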
77 changes: 3 additions & 74 deletions python/tvm/relay/op/_transform.py
@@ -273,82 +273,11 @@ def _reshape_shape_func_input_shape(data_shape, newshape, ndim):
     out[infer_idx] = old_size // new_size
     return out

-@script
-def _reshape_shape_func_input_data(data, newshape, ndim):
-    out = output_tensor((ndim,), "int64")
-    data_shape = allocate((len(data.shape),), "int64")
-    for x in const_range(len(data.shape)):
-        data_shape[x] = int64(data.shape[x])
-    src_idx = 0
-    dst_idx = 0
-    infer_idx = -1
-    copy = False
-    skip = 0
-    for i in const_range(len(newshape)):
-        if skip > 0:
-            skip -= 1
-        elif newshape[i] > 0:
-            out[dst_idx] = int64(newshape[i])
-            src_idx += 1
-            dst_idx += 1
-        elif newshape[i] == 0:
-            out[dst_idx] = data_shape[src_idx]
-            src_idx += 1
-            dst_idx += 1
-        elif newshape[i] == -1:
-            assert infer_idx < 0, "One and only one dim can be inferred"
-            out[dst_idx] = int64(1)
-            infer_idx = i
-            dst_idx += 1
-        elif newshape[i] == -2:
-            copy = True
-        elif newshape[i] == -3:
-            assert data_shape.shape[0] - src_idx > 1, \
-                "Not enough dims in input shape for -3"
-            out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
-            src_idx += 2
-            dst_idx += 1
-        elif newshape[i] == -4:
-            assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
-            if newshape[i+1] == -1:
-                assert newshape[i+2] != -1, "Split dims cannot both be -1."
-                out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2])
-                out[dst_idx+1] = int64(newshape[i+2])
-            else:
-                out[dst_idx] = int64(newshape[i+1])
-                if newshape[i+2] == -1:
-                    out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1])
-                else:
-                    out[dst_idx+1] = int64(newshape[i+2])
-                assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
-                    "Product of split dims doesn't match to input dim"
-            src_idx += 1
-            dst_idx += 2
-            skip = 2
-        else:
-            assert False, "Invalid special values in new shape"
-    if len(data_shape.shape) > 0:
-        # if data is not constant, we can then handle -1 and -2
-        if copy:
-            for i in range(src_idx, data_shape.shape[0]):
-                out[dst_idx] = data_shape[i]
-                dst_idx += 1
-        if infer_idx >= 0:
-            old_size = int64(1)
-            for i in const_range(data_shape.shape[0]):
-                old_size *= data_shape[i]
-            new_size = int64(1)
-            for i in const_range(out.shape[0]):
-                new_size *= out[i]
-            out[infer_idx] = old_size // new_size
-    return out
-
-@_reg.register_shape_func("reshape", True)
+@_reg.register_shape_func("reshape", False)
 def reshape_shape_func(attrs, inputs, out_ndims):
-    if attrs.newshape is None:
-        return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
+    newshape = get_const_tuple(attrs.newshape)
     return [_reshape_shape_func_input_shape(inputs[0],
-                                            convert(attrs.newshape),
+                                            convert(newshape),
                                             out_ndims[0])]

 @script
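The static shape function now assumes newshape is a compile-time constant, so only the attribute path survives here (and the op is re-registered with data_dependant=False). As a reminder of the special codes it interprets (0 copies an input dim, -1 is inferred, -2 copies all remaining dims, -3 fuses the next two input dims, -4 splits one dim in two), a small type-inference check (shapes illustrative):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 3, 4), dtype="float32")
    cases = [((0, -1), (2, 12)),                  # 0 copies dim 0, -1 infers 12
             ((-3, 4), (6, 4)),                   # -3 fuses 2*3 into 6
             ((2, 3, -4, 2, 2), (2, 3, 2, 2))]    # -4 splits the last dim into 2x2
    for newshape, expected in cases:
        mod = tvm.IRModule.from_expr(relay.Function([x], relay.reshape(x, newshape)))
        mod = relay.transform.InferType()(mod)
        assert tuple(int(d) for d in mod["main"].ret_type.shape) == expected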
20 changes: 20 additions & 0 deletions python/tvm/relay/op/dyn/__init__.py
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay namespace containing dynamic ops."""

from . import _transform
20 changes: 20 additions & 0 deletions python/tvm/relay/op/dyn/_make.py
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
import tvm._ffi

tvm._ffi._init_api("relay.op.dyn._make", __name__)
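_init_api hooks every packed function that C++ registers under relay.op.dyn._make into this module, so the Python side can build the call node directly. A sketch of what that exposes once the dyn.reshape op is registered (not part of the diff):

    from tvm import relay
    from tvm.relay.op.dyn import _make

    data = relay.var("x", shape=(2, 3, 4), dtype="float32")
    shape = relay.var("s", shape=(2,), dtype="int64")
    call = _make.reshape(data, shape)  # a CallNode to the dyn.reshape op
    assert call.op.name == "dyn.reshape"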
83 changes: 83 additions & 0 deletions python/tvm/relay/op/dyn/_transform.py
@@ -0,0 +1,83 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Backend compiler related feature registration"""
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments
from __future__ import absolute_import
from tvm.te.hybrid import script
from .. import op as _reg

_reg.register_injective_schedule("dyn.reshape")

@script
def _reshape_shape_func_input_data(data, newshape, ndim):
    out = output_tensor((ndim,), "int64")
    data_shape = allocate((len(data.shape),), "int64")
    for x in const_range(len(data.shape)):
        data_shape[x] = int64(data.shape[x])
    src_idx = 0
    dst_idx = 0
    infer_idx = -1
    copy = False
    skip = 0
    for i in const_range(len(newshape)):
        if skip > 0:
            skip -= 1
        elif newshape[i] > 0:
            out[dst_idx] = int64(newshape[i])
            src_idx += 1
            dst_idx += 1
        elif newshape[i] == 0:
            out[dst_idx] = data_shape[src_idx]
            src_idx += 1
            dst_idx += 1
        elif newshape[i] == -1:
            assert infer_idx < 0, "One and only one dim can be inferred"
            out[dst_idx] = int64(1)
            infer_idx = i
            src_idx += 1
            dst_idx += 1
        elif newshape[i] == -2:
            assert False, "Value -2 is not valid in newshape argument of dynamic reshape"
        elif newshape[i] == -3:
            assert data_shape.shape[0] - src_idx > 1, \
                "Not enough dims in input shape for -3"
            out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
            src_idx += 2
            dst_idx += 1
        elif newshape[i] == -4:
            assert False, "Value -4 is not valid in newshape argument of dynamic reshape"
        else:
            assert False, "Invalid special values in new shape"
    if len(data_shape.shape) > 0:
        # if data is not constant, we can then infer the -1 dimension
        # (-2 is rejected above, so the `copy` branch kept from the static
        # version is never taken here)
        if copy:
            for i in range(src_idx, data_shape.shape[0]):
                out[dst_idx] = data_shape[i]
                dst_idx += 1
        if infer_idx >= 0:
            old_size = int64(1)
            for i in const_range(data_shape.shape[0]):
                old_size *= data_shape[i]
            new_size = int64(1)
            for i in const_range(out.shape[0]):
                new_size *= out[i]
            out[infer_idx] = old_size // new_size
    return out

@_reg.register_shape_func("dyn.reshape", True)
def dynamic_reshape_shape_func(attrs, inputs, out_ndims):
    return [_reshape_shape_func_input_data(*inputs, out_ndims[0])]
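Because the output shape only becomes known when the values in newshape arrive, dyn.reshape registers its shape function with data_dependant=True and must run on an executor that can allocate at run time. A minimal end-to-end sketch with the VM (assumes an LLVM-enabled build; shapes illustrative):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 3, 4), dtype="float32")
    s = relay.var("s", shape=(2,), dtype="int64")
    mod = tvm.IRModule.from_expr(relay.Function([x, s], relay.reshape(x, s)))

    ex = relay.create_executor("vm", mod=mod, target="llvm")
    out = ex.evaluate()(np.zeros((2, 3, 4), "float32"), np.array([6, 4], "int64"))
    assert out.shape == (6, 4)  # resolved at run time, not compile time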
11 changes: 7 additions & 4 deletions python/tvm/relay/op/transform.py
@@ -19,7 +19,8 @@
 """Transform operators."""

 from . import _make
-from ..expr import TupleWrapper, const
+from .dyn import _make as _dyn_make
+from ..expr import TupleWrapper, const, Expr
 from ...tir import expr as _expr

@@ -210,8 +211,10 @@ def reshape(data, newshape):
     result : relay.Expr
         The reshaped result.
     """
+    if isinstance(newshape, Expr):
+        return _dyn_make.reshape(data, newshape)
     if isinstance(newshape, int):
-        newshape = const([newshape])
+        newshape = [newshape]
     if isinstance(newshape, (tuple, list)):
         tempshape = []
         for shape in newshape:
@@ -222,8 +225,8 @@
                     tempshape.append(int(shape))
                 except ValueError as err:
                     raise RuntimeError('Unrecognized shape type: %s' % err)
-        newshape = const(tempshape)
-    return _make.reshape(data, newshape)
+        newshape = tempshape
+    return _make.reshape(data, list(newshape))

 def argwhere(condition):
     """Find the indices of elements of a tensor that are
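relay.reshape is now a single entry point that dispatches on the type of newshape: a relay.Expr builds dyn.reshape, while ints, tuples, and lists are normalized to a plain list of ints for the static op (no more wrapping in a Constant). For example (illustrative):

    from tvm import relay

    x = relay.var("x", shape=(2, 12), dtype="float32")
    static = relay.reshape(x, (3, 8))
    dynamic = relay.reshape(x, relay.var("s", shape=(2,), dtype="int64"))
    assert static.op.name == "reshape"        # newshape stored as an attribute
    assert dynamic.op.name == "dyn.reshape"   # newshape passed as a tensor input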
11 changes: 11 additions & 0 deletions python/tvm/relay/transform/transform.py
@@ -625,6 +625,17 @@ def AnnotateTarget(targets):
     return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets])


+def DynamicToStatic():
+    """If possible, convert the dynamic (relay.dyn) ops to their static versions.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass for dynamic-to-static conversion.
+    """
+    return _ffi_api.DynamicToStatic()
+
+
 def Inline():
     """Perform inlining on the given Relay IR module. The global functions that
     are marked as `inline` should be always inlined. A cost model will be
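The pass pairs with the new namespace: when a dyn op's shape inputs turn out to be constants (for example after constant folding), DynamicToStatic rewrites the call back to the static op so later passes see fixed shapes. A sketch (illustrative):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 3, 4), dtype="float32")
    y = relay.reshape(x, relay.const(np.array([6, 4]), "int64"))  # builds dyn.reshape
    mod = tvm.IRModule.from_expr(relay.Function([x], y))
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.DynamicToStatic()(mod)
    print(mod["main"])  # the body is now static: reshape(%x, newshape=[6, 4])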
9 changes: 1 addition & 8 deletions src/relay/analysis/util.cc
@@ -448,14 +448,7 @@ bool IsDataDependant(const CallNode* call) {
     return false;
   }

-  if (op->name == "reshape") {
-    if (const auto* attrs = call->attrs.as<ReshapeAttrs>()) {
-      if (attrs->newshape) {
-        // If newshape attribute exists, it isn't data dependant.
-        return false;
-      }
-    }
-  } else if (op->name == "topk") {
+  if (op->name == "topk") {
     if (const auto* attrs = call->attrs.as<TopKAttrs>()) {
       if (attrs->k) {
         // If k attribute exists, it isn't data dependant.
