Skip to content

Commit

Permalink
[Relax] Static memory planning (tlc-pack#23)
Browse files Browse the repository at this point in the history
* [BugFix] Normalize in mutator when visiting VarBinding and MatchShape

* Graph memory planning

* Fixes

* Translate memory op to vm builtin op

* Naive test

* Fix importer, use no `iterable`
  • Loading branch information
MasterJH5574 committed Dec 14, 2022
1 parent d7060c9 commit e9895cb
Show file tree
Hide file tree
Showing 8 changed files with 879 additions and 19 deletions.
5 changes: 2 additions & 3 deletions python/tvm/relax/frontend/pytorch_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from numpy import iterable
import torch
import tvm

Expand Down Expand Up @@ -293,7 +292,7 @@ def _getattr(self, node: fx.node.Node) -> relax.Var:

def _getitem(self, node: fx.node.Node) -> relax.Var:
x = self.env[node.args[0]]
if iterable(x):
if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)):
return x[node.args[1]]
elif isinstance(x, relax.Var):
if isinstance(x.shape, relax.Tuple):
Expand Down Expand Up @@ -453,7 +452,7 @@ def _new_ones(self, node: fx.node.Node) -> relax.Var:
args = self.retrive_args(node)
self_var = args[0]
size = args[1:]
if not iterable(size):
if not isinstance(size, (list, tuple)):
size = (size,)
return self.bb.emit(
relax.op.full(
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def CallTIRRewrite() -> tvm.ir.transform.Pass:
return _ffi_api.CallTIRRewrite() # type: ignore


def VMGraphMemoryPlan() -> tvm.ir.transform.Pass:
return _ffi_api.VMGraphMemoryPlan()


def VMMemoryLower() -> tvm.ir.transform.Pass:
"""Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to VM intrinsics.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")):

passes = [relax.transform.ToNonDataflow()]
passes.append(relax.transform.CallTIRRewrite())
passes.append(relax.transform.VMGraphMemoryPlan())
passes.append(relax.transform.VMMemoryLower())
passes.append(relax.transform.VMShapeLower())
passes.append(relax.transform.AttachGlobalSymbol())
Expand Down
Loading

0 comments on commit e9895cb

Please sign in to comment.