[Relax][PyTorch] Support binary, statistical and search ops for ExportedProgram importer (#17424)

* support binary ops

* support mean

* support sum

* support argmax and argmin
mshr-h authored Sep 28, 2024
1 parent 176d01e commit 7c28c86
Showing 4 changed files with 599 additions and 62 deletions.
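Before the per-file diffs, here is a minimal end-to-end sketch of how the newly supported binary, statistical and search ops can be exercised through the ExportedProgram importer. The model, shapes, and the from_exported_program entry point are illustrative assumptions, not part of this commit:

# Sketch only: assumes a TVM build containing this patch and a recent PyTorch
# release that provides torch.export.
import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class SmallModel(torch.nn.Module):
    def forward(self, x, y):
        z = x + y                        # add.Tensor
        z = z * 2.0                      # mul.Tensor with a scalar operand
        m = z.mean(dim=1, keepdim=True)  # mean.dim
        s = z.sum(dim=1)                 # sum.dim_IntList
        i = z.argmax(dim=1)              # argmax.default
        return m, s, i

example_args = (torch.randn(4, 8), torch.randn(4, 8))
exported = export(SmallModel(), example_args)
mod = from_exported_program(exported)  # Relax IRModule produced by the importer
mod.show()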
62 changes: 62 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -185,6 +185,39 @@ def convert(node: fx.Node) -> relax.Var:

        return convert

    ########## Binary Ops ##########

    def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
        from torch import fx

        def convert(node: fx.Node) -> relax.Var:
            def promote_binary_op_args(lhs, rhs):
                if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
                    return lhs, rhs
                elif isinstance(lhs, relax.Expr):
                    assert isinstance(lhs.struct_info, relax.TensorStructInfo)
                    return lhs, relax.const(rhs, lhs.struct_info.dtype)
                elif isinstance(rhs, relax.Expr):
                    assert isinstance(rhs.struct_info, relax.TensorStructInfo)
                    return relax.const(lhs, rhs.struct_info.dtype), rhs
                else:
                    assert False

            def call_binary_op(op, lhs, rhs):
                lhs, rhs = promote_binary_op_args(lhs, rhs)
                return self.block_builder.emit(op(lhs, rhs))

            lhs, rhs = self.retrieve_args(node)
            if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
                return call_binary_op(relax_op, lhs, rhs)
            elif isinstance(lhs, relax.expr.Constant):
                return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype))
            elif isinstance(rhs, relax.expr.Constant):
                return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs)
            return intrinsic_op(lhs, rhs)

        return convert
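To illustrate the promotion logic above, here is a rough sketch (shapes and dtype are made up, and the snippet is not part of the diff) of what the helper effectively emits when one operand is a Relax tensor and the other a Python scalar:

# Illustrative sketch: the scalar side is wrapped into a relax.const carrying
# the tensor operand's dtype before the Relax op is emitted.
from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((4, 8), "float32"))
one = relax.const(1.0, "float32")  # promote_binary_op_args(x, 1.0) -> (x, one)
add_call = relax.op.add(x, one)    # the call handed to block_builder.emit
# If neither operand is a Relax expression (two plain Python numbers), the
# handler falls back to the intrinsic operator, e.g. operator.add(2, 3) == 5.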

    ########## Neural Network ##########

    def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
@@ -283,6 +316,35 @@ def _max_pool2d(self, node: fx.Node) -> relax.Var:

        return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)

    ########## Statistical ##########

    def _mean(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        x = args[0]
        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
        return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))

    def _sum(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
        if len(args) == 1:
            return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
        return self.block_builder.emit(relax.op.sum(args[0], args[1]))
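For reference, a small sketch (illustrative shapes and dtype, not part of the diff) of the Relax calls the two handlers above build from the FX node's dim/keepdim arguments:

# Sketch only: mirrors what _mean / _sum emit for torch.mean(x, 1, keepdim=True)
# and torch.sum(x, 1).
from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((4, 8), "float32"))
mean_call = relax.op.mean(x, axis=1, keepdims=True)  # mean.dim with keepdim=True
sum_call = relax.op.sum(x, axis=1)                   # sum.dim_IntList; keepdims left at its default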

    ########## Search ##########

    def _argmax_argmin(self, op: Callable) -> Callable:
        from torch import fx

        def convert(node: fx.Node):
            x = self.env[node.args[0]]
            dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
            keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
            return self.block_builder.emit(op(x, dim, keepdim))

        return convert
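Both the positional and the keyword calling styles on the PyTorch side reach this handler; a hedged sketch (module name and shapes are illustrative) of a model that produces argmax.default / argmin.default nodes:

# Sketch only: the handler reads dim/keepdim from node.args when they are given
# positionally, otherwise from node.kwargs.
import torch
from torch.export import export

class ArgReduce(torch.nn.Module):
    def forward(self, x):
        a = torch.argmax(x, 1, True)         # dim, keepdim passed positionally
        b = x.argmin(dim=-1, keepdim=False)  # same handler, keyword style
        return a, b

exported = export(ArgReduce(), (torch.randn(2, 3),))
print(exported.graph)  # shows the argmax.default / argmin.default call nodes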

    ########## Manipulation ##########

    def _reshape(self, node: fx.Node) -> relax.Var:
25 changes: 25 additions & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -19,6 +19,7 @@
# pylint: disable=import-outside-toplevel
"""PyTorch ExportedProgram of Relax."""
from collections import ChainMap, OrderedDict
from functools import partial
from typing import Callable, Dict, List, Tuple

import torch
@@ -76,6 +77,8 @@ def _hardtanh(self, node: fx.Node) -> relax.Expr:
    def create_convert_map(
        self,
    ) -> Dict[str, Callable[[fx.Node], relax.Var]]:
        import operator

        return {
            # unary
            "acos.default": self._unary_op(relax.op.acos),
@@ -109,11 +112,33 @@ def create_convert_map(
"tanh.default": self._unary_op(relax.op.tanh),
"tril.default": self._tril_triu(relax.op.tril),
"triu.default": self._tril_triu(relax.op.triu),
# binary
"add.Tensor": self._binary_op(relax.op.add, operator.add),
"div.Tensor": self._binary_op(relax.op.divide, operator.truediv),
"eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
"eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
"floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv),
"lt.Scalar": self._binary_op(relax.op.less, operator.lt),
"lt.Tensor": self._binary_op(relax.op.less, operator.lt),
"matmul.default": self._binary_op(
partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
),
"max.other": self._binary_op(relax.op.maximum, max),
"mul.Tensor": self._binary_op(relax.op.multiply, operator.mul),
"pow.Tensor_Scalar": self._binary_op(relax.op.power, operator.pow),
"pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
"sub.Tensor": self._binary_op(relax.op.subtract, operator.sub),
# neural network
"adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
"conv2d.default": self._conv2d,
"linear.default": self._linear,
"max_pool2d.default": self._max_pool2d,
# statistical
"mean.dim": self._mean,
"sum.dim_IntList": self._sum,
# search
"argmax.default": self._argmax_argmin(relax.op.argmax),
"argmin.default": self._argmax_argmin(relax.op.argmin),
# tensor manipulation
"view.default": self._reshape,
}
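The keys in this table are ATen operator overload names as they appear on call_function nodes of an ExportedProgram. A quick, illustrative way (not part of the patch) to see which keys a given model will need:

# Sketch only: prints overload names such as "add.Tensor", "pow.Tensor_Scalar"
# and "mean.dim", i.e. the strings the convert map above is keyed on.
import torch
from torch.export import export

class M(torch.nn.Module):
    def forward(self, x):
        return ((x + 1.0) ** 2).mean(dim=-1)

exported = export(M(), (torch.randn(3, 5),))
for node in exported.graph.nodes:
    if node.op == "call_function":
        # node.target is an OpOverload; fall back to str() if __name__ is absent
        print(getattr(node.target, "__name__", str(node.target)))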
62 changes: 0 additions & 62 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -96,39 +96,6 @@ def convert(node: fx.Node) -> relax.Var:

        return convert

    ########## Binary Ops ##########

    def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
        from torch import fx

        def convert(node: fx.Node) -> relax.Var:
            def promote_binary_op_args(lhs, rhs):
                if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
                    return lhs, rhs
                elif isinstance(lhs, relax.Expr):
                    assert isinstance(lhs.struct_info, relax.TensorStructInfo)
                    return lhs, relax.const(rhs, lhs.struct_info.dtype)
                elif isinstance(rhs, relax.Expr):
                    assert isinstance(rhs.struct_info, relax.TensorStructInfo)
                    return relax.const(lhs, rhs.struct_info.dtype), rhs
                else:
                    assert False

            def call_binary_op(op, lhs, rhs):
                lhs, rhs = promote_binary_op_args(lhs, rhs)
                return self.block_builder.emit(op(lhs, rhs))

            lhs, rhs = self.retrieve_args(node)
            if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
                return call_binary_op(relax_op, lhs, rhs)
            elif isinstance(lhs, relax.expr.Constant):
                return call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype))
            elif isinstance(rhs, relax.expr.Constant):
                return call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs)
            return intrinsic_op(lhs, rhs)

        return convert

    ########## Neural Network ##########

    def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var:
@@ -794,35 +761,6 @@ def _unbind(self, node: fx.Node) -> relax.Var:
            ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim)))
        return self.block_builder.emit(relax.Tuple(ret))

    ########## Statistical ##########

    def _mean(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        x = args[0]
        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
        return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))

    def _sum(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
        if len(args) == 1:
            return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
        return self.block_builder.emit(relax.op.sum(args[0], args[1]))

    ########## Search ##########

    def _argmax_argmin(self, op: Callable) -> Callable:
        from torch import fx

        def convert(node: fx.Node):
            x = self.env[node.args[0]]
            dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
            keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
            return self.block_builder.emit(op(x, dim, keepdim))

        return convert

    ########## Manipulation ##########

    def _cat(self, node: fx.Node) -> relax.Var:
