forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TOP] Add dense, batchnorm (apache#22)
* [TOP] Add dense, batchnorm * update tvm
- Loading branch information
Showing
14 changed files
with
401 additions
and
213 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ | |
from .attr_dict import AttrDict | ||
from . import tensor | ||
from . import nn | ||
from . import transform |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# pylint: disable=invalid-name, unused-argument | ||
"""Tensor transformation ops""" | ||
from __future__ import absolute_import | ||
|
||
import tvm | ||
from .tensor import _fschedule_broadcast | ||
from ..compiler import registry as reg | ||
from ..compiler import OpPattern | ||
|
||
# Need add reshape, transpose | ||
|
||
def _flatten_index(indices, shape): | ||
"""flatten the index to 1D""" | ||
idx = 0 | ||
for i, value in enumerate(shape): | ||
if i != 0: | ||
idx *= value | ||
idx = idx + indices[i] | ||
return idx | ||
|
||
# reshape | ||
@reg.register_compute("reshape") | ||
def compute_reshape(attrs, inputs, out_info): | ||
"""Compute definition of softmax""" | ||
# TODO(sxj) add support for general reshape | ||
assert len(inputs[0].shape) == 1, "Only support 1d input for now" | ||
oshape = out_info[0].shape | ||
x = inputs[0] | ||
return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape))) | ||
reg.register_pattern("reshape", OpPattern.COMPLEX) | ||
reg.register_schedule("reshape", _fschedule_broadcast) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.