Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,28 @@ def _fill(self, node: fx.Node) -> relax.Var:
value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))

def _full(self, node: fx.Node) -> relax.Var:
import torch

args = self.retrieve_args(node)
size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
dtype = self._convert_data_type(
node.kwargs.get("dtype", torch.get_default_dtype()), self.env
)
value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype)
return self.block_builder.emit(
relax.op.full(
size,
value,
dtype,
)
)

def _full_like(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
fill_value = relax.const(node.args[1])
return self.block_builder.emit(relax.op.full_like(x, fill_value))

def _index_select(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dim = node.args[1]
Expand All @@ -1292,6 +1314,22 @@ def _new_ones(self, node: fx.Node) -> relax.Var:
)
)

def _ones(self, node: fx.Node) -> relax.Var:
import torch

args = self.retrieve_args(node)
size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
dtype = self._convert_data_type(
node.kwargs.get("dtype", torch.get_default_dtype()), self.env
)
return self.block_builder.emit(
relax.op.full(
size,
relax.const(1, dtype),
dtype,
)
)

########## DataType ##########

def _to(self, node: fx.Node) -> relax.Var:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,10 +433,13 @@ def create_convert_map(
"empty.memory_format": self._empty,
"empty_like.default": self._empty_like,
"fill.Scalar": self._fill,
"full.default": self._full,
"full_like.default": self._full_like,
"index_select.default": self._index_select,
"lift_fresh_copy.default": self._to_copy,
"new_ones.default": self._new_ones,
"one_hot.default": self._one_hot,
"ones.default": self._ones,
# datatype
"to.dtype": self._to,
"to.dtype_layout": self._to,
Expand Down
33 changes: 0 additions & 33 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,23 +468,6 @@ def _inplace_fill(self, node: fx.Node) -> relax.Var:
self.env[node.args[0]] = filled
return filled

def _full(self, node: fx.Node) -> relax.Var:
import torch

args = self.retrieve_args(node)
size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
dtype = self._convert_data_type(
node.kwargs.get("dtype", torch.get_default_dtype()), self.env
)
value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype)
return self.block_builder.emit(
relax.op.full(
size,
value,
dtype,
)
)

def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
mask = self.env[node.args[1]]
Expand Down Expand Up @@ -527,22 +510,6 @@ def _masked_scatter(self, node: fx.Node) -> relax.Var:
mask = self.block_builder.emit(relax.op.broadcast_to(mask, x.struct_info.shape))
return self.block_builder.emit(relax.op.where(mask, gathered_source, x))

def _ones(self, node: fx.Node) -> relax.Var:
import torch

args = self.retrieve_args(node)
size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
dtype = self._convert_data_type(
node.kwargs.get("dtype", torch.get_default_dtype()), self.env
)
return self.block_builder.emit(
relax.op.full(
size,
relax.const(1, dtype),
dtype,
)
)

def _one_hot(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes")
Expand Down
48 changes: 48 additions & 0 deletions tests/python/relax/test_from_exported_to_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,54 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)


@tvm.testing.parametrize_targets("cuda")
def test_full(target, dev):
class FullModel(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.full((2, 3), 3.141592)

torch_module = FullModel().eval()

raw_data = np.random.rand(3, 3).astype("float32")

assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_full_like(target, dev):
class FullLike(nn.Module):
def __init__(self):
super().__init__()
self.fill_value = 7.0

def forward(self, x):
return torch.full_like(x, self.fill_value)

torch_module = FullLike().eval()
raw_data = np.random.rand(2, 3).astype("float32")

assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_ones(target, dev):
class FullModel(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.ones((2, 3))

torch_module = FullModel().eval()

raw_data = np.random.rand(1, 1).astype("float32")

assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_tensor_clamp(target, dev):
class ClampBothTensor(torch.nn.Module):
Expand Down
Loading