[PYTORCH] Rsub, Embedding, OneHot ops support (apache#5434)
siju-samuel authored and trevor-m committed Jun 18, 2020
1 parent f782a62 commit 349efab
Showing 2 changed files with 100 additions and 0 deletions.
47 changes: 47 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -1483,6 +1483,50 @@ def _impl(inputs, input_types):
    return _impl


def _rsub():
    def _impl(inputs, input_types):
        # TODO: Figure out a better way to get typing to work for tensor + scalar
        type0 = input_types[0]
        if isinstance(inputs[1], _expr.Expr):
            type0 = input_types[1]

        type1 = input_types[1]
        if isinstance(inputs[0], _expr.Expr):
            type1 = input_types[0]

        data1 = _convert_elemwise_input(inputs[0], type0)
        data0 = _convert_elemwise_input(inputs[1], type1)
        alpha = _expr.const(float(inputs[2]))

        return get_relay_op("subtract")(data0, alpha * data1)
    return _impl

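A minimal sketch (not part of the commit) of the semantics the converter above targets: torch.rsub(input, other, alpha) computes other - alpha * input, which is why the Relay expression swaps the operands into subtract(data0, alpha * data1).

import torch

# Illustration only: rsub is the reversed, alpha-scaled subtraction.
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([10.0, 20.0, 30.0])
out = torch.rsub(x, y, alpha=0.5)   # equals y - 0.5 * x
assert torch.allclose(out, y - 0.5 * x)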

def _embedding():
    def _impl(inputs, input_types):
        weight = inputs[0]
        indices = inputs[1]

        return _op.take(weight, indices.astype('int32'), axis=0)
    return _impl

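A minimal sketch (not part of the commit) of why aten::embedding lowers to a take along axis 0: an embedding lookup is just a row gather from the weight matrix.

import torch

# Illustration only: embedding(idx) gathers rows of the weight matrix.
emb = torch.nn.Embedding(10, 3)
idx = torch.tensor([[1, 4], [7, 2]])
assert torch.equal(emb(idx), emb.weight[idx])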

def _one_hot():
    def _impl(inputs, input_types):
        indices = inputs[0].astype('int32')
        num_classes = inputs[1]
        if num_classes == -1:
            msg = "Inferring the number of classes is not yet supported."
            raise NotImplementedError(msg)

        dtype = 'int32'
        on_value = tvm.relay.const(1.0, dtype)
        off_value = tvm.relay.const(0.0, dtype)

        return _op.one_hot(indices, on_value, off_value, num_classes, -1, dtype)
    return _impl

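A minimal sketch (not part of the commit) of the case this converter supports: num_classes must be given explicitly, since the inferred form (num_classes=-1) raises NotImplementedError above.

import torch

# Illustration only: each index becomes a row with a single 1.
idx = torch.arange(0, 5) % 3                          # tensor([0, 1, 2, 0, 1])
oh = torch.nn.functional.one_hot(idx, num_classes=3)
assert oh.shape == (5, 3)
assert bool(oh.sum(dim=-1).eq(1).all())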

# Helper functions for operator implementation
def _convert_dtype_value(val):
    convert_torch_dtype_map = {7:"torch.float64",
@@ -1696,6 +1740,9 @@ def _get_convert_map(prelude):
"aten::Float" : _Float(),
"aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(),
"aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(),
"aten::rsub" : _rsub(),
"aten::embedding" : _embedding(),
"aten::one_hot" : _one_hot(),
"aten::mm" : _matmul(),
"relay::tensor_array_stack" : _tensor_array_stack(prelude),
"aten::add" : _add(prelude),
53 changes: 53 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -1474,6 +1474,56 @@ def forward(self, *args):
    verify_model(Variance5().float().eval(), input_data=input_data)


def test_forward_rsub():
    torch.set_grad_enabled(False)

    class Rsub1(Module):
        def forward(self, *args):
            return torch.rsub(args[0], args[1])

    class Rsub2(Module):
        def forward(self, *args):
            return torch.rsub(args[0], args[1], alpha=0.5)

    d1 = torch.rand([1, 3]).float()
    d2 = torch.rand([1, 3]).float()
    d3 = torch.rand([1, 3]).int()
    verify_model(Rsub1().float().eval(), input_data=[d1, d2])
    verify_model(Rsub1().float().eval(), input_data=[d1, d3])
    verify_model(Rsub2().float().eval(), input_data=[d1, d2])
    verify_model(Rsub2().float().eval(), input_data=[d1, d3])


def test_forward_embedding():
    torch.set_grad_enabled(False)

    input_data = torch.randint(0, 10, [2, 4]).long()
    verify_model(torch.nn.Embedding(10, 3).float().eval(), input_data=input_data)

    input_data = torch.randint(0, 4, [2, 3, 4]).long()
    verify_model(torch.nn.Embedding(4, 5, sparse=False).float().eval(), input_data=input_data)

    input_data = torch.randint(0, 4, [2, 3, 4]).long()
    verify_model(torch.nn.Embedding(4, 5, sparse=True).float().eval(), input_data=input_data)


def test_forward_onehot():
    torch.set_grad_enabled(False)

    class OneHot1(Module):
        def forward(self, *args):
            return torch.nn.functional.one_hot(args[0], num_classes=3)

    class OneHot2(Module):
        def forward(self, *args):
            return torch.nn.functional.one_hot(args[0], num_classes=5)

    input_data = torch.arange(0, 5) % 3
    verify_model(OneHot1().float().eval(), input_data=input_data)

    input_data = torch.arange(0, 5) % 4
    verify_model(OneHot2().float().eval(), input_data=input_data)


def test_forward_isfinite():
    torch.set_grad_enabled(False)
@@ -1995,6 +2045,9 @@ def forward(self, *args):
    test_forward_add()
    test_forward_subtract()
    test_forward_multiply()
    test_forward_rsub()
    test_forward_onehot()
    test_forward_embedding()
    test_forward_reshape()
    test_forward_reciprocal()
    test_forward_repeat()
