-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[DYNAMIC] Add Dynamic reshape to a dynamic namespace and add DynamicT…
…oStatic Pass (#5826) * Dynamic reshape passing tests * Add Dynamic to Static Pass * rename test file to prevent pytest conflicts * fix clang build * add nested dynamic shape test * remove cuda tests until VM supports dynamic shapes * rename namespace from dynamic to dyn * fix lint * fix lint again * Remove incorrect doc strings * remove dynamic behavior from standard reshape * fix some tests * merge dynamic and static interfaces in python * fix missing import * missed a reference to relay.dyn.reshape * fix vta example * respond to review comments
- Loading branch information
Matthew Brookhart
authored
Jul 1, 2020
1 parent
5c1bf98
commit b979bf6
Showing
22 changed files
with
625 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# pylint: disable=wildcard-import, redefined-builtin, invalid-name | ||
"""The Relay namespace containing dynamic ops.""" | ||
|
||
from . import _transform |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Constructor APIs""" | ||
import tvm._ffi | ||
|
||
tvm._ffi._init_api("relay.op.dyn._make", __name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Backend compiler related feature registration""" | ||
# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments | ||
from __future__ import absolute_import | ||
from tvm.te.hybrid import script | ||
from .. import op as _reg | ||
|
||
_reg.register_injective_schedule("dyn.reshape") | ||
|
||
@script | ||
def _reshape_shape_func_input_data(data, newshape, ndim): | ||
out = output_tensor((ndim,), "int64") | ||
data_shape = allocate((len(data.shape),), "int64") | ||
for x in const_range(len(data.shape)): | ||
data_shape[x] = int64(data.shape[x]) | ||
src_idx = 0 | ||
dst_idx = 0 | ||
infer_idx = -1 | ||
copy = False | ||
skip = 0 | ||
for i in const_range(len(newshape)): | ||
if skip > 0: | ||
skip -= 1 | ||
elif newshape[i] > 0: | ||
out[dst_idx] = int64(newshape[i]) | ||
src_idx += 1 | ||
dst_idx += 1 | ||
elif newshape[i] == 0: | ||
out[dst_idx] = data_shape[src_idx] | ||
src_idx += 1 | ||
dst_idx += 1 | ||
elif newshape[i] == -1: | ||
assert infer_idx < 0, "One and only one dim can be inferred" | ||
out[dst_idx] = int64(1) | ||
infer_idx = i | ||
src_idx += 1 | ||
dst_idx += 1 | ||
elif newshape[i] == -2: | ||
assert False, "Value -2 is not valid in newshape argument of dynamic reshape" | ||
elif newshape[i] == -3: | ||
assert data_shape.shape[0] - src_idx > 1, \ | ||
"Not enough dims in input shape for -3" | ||
out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1] | ||
src_idx += 2 | ||
dst_idx += 1 | ||
elif newshape[i] == -4: | ||
assert False, "Value -4 is not valid in newshape argument of dynamic reshape" | ||
else: | ||
assert False, "Invalid special values in new shape" | ||
if len(data_shape.shape) > 0: | ||
# if data is not constant, we can then handle -1 and -2 | ||
if copy: | ||
for i in range(src_idx, data_shape.shape[0]): | ||
out[dst_idx] = data_shape[i] | ||
dst_idx += 1 | ||
if infer_idx >= 0: | ||
old_size = int64(1) | ||
for i in const_range(data_shape.shape[0]): | ||
old_size *= data_shape[i] | ||
new_size = int64(1) | ||
for i in const_range(out.shape[0]): | ||
new_size *= out[i] | ||
out[infer_idx] = old_size // new_size | ||
return out | ||
|
||
@_reg.register_shape_func("dyn.reshape", True) | ||
def dynamic_reshape_shape_func(attrs, inputs, out_ndims): | ||
return [_reshape_shape_func_input_data(*inputs, out_ndims[0])] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.