[PYTORCH]where, addcdiv, addcmul op support (apache#5383)
* [PYTORCH]Where, addcdiv, addcmul op support

* Review comments fixed
siju-samuel authored and trevor-m committed Jun 18, 2020
1 parent ebb7eec commit 3eb6009
Showing 2 changed files with 110 additions and 31 deletions.
python/tvm/relay/frontend/pytorch.py (72 changes: 41 additions & 31 deletions)
@@ -279,15 +279,7 @@ def _impl(inputs, input_types):
 def _take():
     def _impl(inputs, input_types):
         data = inputs[0]
-        import torch
-
-        if isinstance(inputs[1], _expr.Var):
-            indices = _op.cast(inputs[1], "int32")
-        elif isinstance(inputs[1], torch.Tensor):
-            indices = _wrap_const(inputs[1].numpy())
-        else:
-            msg = "Data type %s could not be parsed in take operator." % (type(inputs[1]))
-            raise AssertionError(msg)
+        indices = _op.cast(inputs[1], "int32")
 
         return _op.transform.take(data, indices=indices)
     return _impl
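
Note on the simplification above (and the matching ones in _bitwise_xor and _logical_xor further down): the per-operator torch.Tensor handling is no longer needed because _get_constant now wraps tensor attributes with _wrap_const (see the @@ -1838 hunk below), so converters only ever see Relay expressions. A minimal illustrative sketch, not part of the diff, assuming a TVM build from this era:

    import numpy as np
    from tvm import relay

    # A constant tensor input now reaches the converter as a Relay constant
    # (roughly what _wrap_const produces), so a single cast is sufficient.
    indices = relay.const(np.array([0, 2]))
    indices = relay.cast(indices, "int32")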
@@ -337,6 +329,40 @@ def _impl(inputs, input_types):
         return _op.transform.repeat(data, repeats=repeats, axis=axis)
     return _impl
 
+
+def _addcdiv():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        c = _expr.const(inputs[3])
+        t1 = inputs[1]
+        t2 = inputs[2]
+
+        return data + (c * (t1 / t2))
+    return _impl
+
+
+def _addcmul():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        c = _expr.const(inputs[3])
+        t1 = inputs[1]
+        t2 = inputs[2]
+
+        return data + (c * (t1 * t2))
+    return _impl
+
+
+def _where():
+    def _impl(inputs, input_types):
+        cond = inputs[0]
+        x = inputs[1]
+        y = inputs[2]
+
+        return _op.where(cond, x, y)
+
+    return _impl
+
+
 def _ones():
     def _impl(inputs, input_types):
         data = inputs[0]
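
For reference, these converters implement the same arithmetic as the PyTorch ops: addcdiv computes input + value * (tensor1 / tensor2), addcmul computes input + value * (tensor1 * tensor2), and where is an elementwise select. A quick sanity check in plain PyTorch (using the positional value argument accepted by the PyTorch releases this commit targets); this snippet is illustrative only and not part of the change:

    import torch

    x = torch.rand(1, 3)
    t1, t2 = torch.rand(3, 1), torch.rand(1, 3)

    # addcdiv / addcmul: add a scaled elementwise quotient / product.
    assert torch.allclose(torch.addcdiv(x, 0.5, t1, t2), x + 0.5 * (t1 / t2))
    assert torch.allclose(torch.addcmul(x, 0.5, t1, t2), x + 0.5 * (t1 * t2))

    # where: pick from x where the condition holds, otherwise from t2.
    cond = x > 0.5
    picked = torch.where(cond, x, t2)
    assert torch.equal(picked[cond], x[cond]) and torch.equal(picked[~cond], t2[~cond])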
@@ -1388,16 +1414,7 @@ def _impl(inputs, input_types):
 def _bitwise_xor():
     def _impl(inputs, input_types):
         lhs = inputs[0]
-
-        import torch
-        if isinstance(inputs[1], _expr.Var):
-            rhs = inputs[1]
-        elif isinstance(inputs[1], torch.Tensor):
-            rhs = _wrap_const(inputs[1].numpy())
-        else:
-            msg = "Data type %s could not be parsed in bitwise_xor operator." % (type(inputs[1]))
-            raise AssertionError(msg)
-
+        rhs = inputs[1]
         lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int")
         rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int")
 
@@ -1416,17 +1433,7 @@ def _impl(inputs, input_types):
 def _logical_xor():
     def _impl(inputs, input_types):
         lhs = _op.cast(inputs[0], "bool")
-
-        import torch
-        if isinstance(inputs[1], _expr.Var):
-            rhs = inputs[1]
-        elif isinstance(inputs[1], torch.Tensor):
-            rhs = _wrap_const(inputs[1].numpy())
-        else:
-            msg = "Data type %s could not be parsed in logical_xor operator." % (type(inputs[1]))
-            raise AssertionError(msg)
-
-        rhs = _op.cast(rhs, "bool")
+        rhs = _op.cast(inputs[1], "bool")
 
         return _op.logical_xor(lhs, rhs)
     return _impl
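
The "bool" casts in _logical_xor (and the bool/int casts in _bitwise_xor above) are there so that Relay's logical and bitwise ops receive operands of a suitable type. A small, hedged illustration, assuming a working TVM install:

    import numpy as np
    from tvm import relay

    a = relay.const(np.array([True, False, True]))
    b = relay.const(np.array([True, True, False]))
    out = relay.logical_xor(a, b)  # boolean operands, as produced by the casts above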
@@ -1557,6 +1564,8 @@ def _get_convert_map(prelude):
"aten::arange" : _arange(),
"aten::div" : _elemwise("divide"),
"aten::div_" : _elemwise("divide"),
"aten::addcdiv" : _addcdiv(),
"aten::addcmul" : _addcmul(),
"aten::ones" : _ones(),
"aten::ones_like" : _ones_like(),
"aten::zeros" : _zeros(),
@@ -1576,6 +1585,7 @@ def _get_convert_map(prelude):
"aten::split_with_sizes" : _split_with_sizes(),
"aten::select" : _select(),
"aten::take" : _take(),
"aten::where" : _where(),
"aten::topk" : _topk(),
"aten::relu" : _relu(),
"aten::relu_" : _relu(),
@@ -1838,7 +1848,7 @@ def _get_constant(node):
             tensor = node.t(attr_name)
             if len(tensor.shape) == 0:  # tensor(0.1)
                 return float(tensor)
-            return tensor
+            return _wrap_const(tensor.numpy())
         elif ty == "DeviceObjType":
             return node.s(attr_name)
         elif ty == "FunctionType":
tests/python/frontend/pytorch/test_forward.py (69 changes: 69 additions & 0 deletions)
@@ -1899,6 +1899,72 @@ def forward(self, *args):
     verify_model(Neg1().float().eval(), input_data=input_data)
 
 
+def test_forward_where():
+    torch.set_grad_enabled(False)
+
+    class Where1(Module):
+        def forward(self, *args):
+            y = torch.ones([3, 2])
+            if torch.cuda.is_available():
+                y = y.cuda()
+            return torch.where(args[0] > 0, args[0], y)
+
+    class Where2(Module):
+        def forward(self, *args):
+            return torch.where(args[0] > 0, args[0], args[1])
+
+    x = torch.rand([3, 2]).float()
+    verify_model(Where1().float().eval(), input_data=[x])
+    y = torch.rand([3, 2])
+    verify_model(Where2().float().eval(), input_data=[x, y])
+
+
+def test_forward_addcdiv():
+    torch.set_grad_enabled(False)
+
+    class Addcdiv1(Module):
+        def forward(self, *args):
+            t1 = torch.ones([3, 1])
+            t2 = torch.ones([1, 3])
+            if torch.cuda.is_available():
+                t1 = t1.cuda()
+                t2 = t2.cuda()
+            return torch.addcdiv(args[0], 0.1, t1, t2)
+
+    class Addcdiv2(Module):
+        def forward(self, *args):
+            return torch.addcdiv(args[0], 0.5, args[1], args[2])
+
+    input_data = torch.rand([1, 3]).float()
+    verify_model(Addcdiv1().float().eval(), input_data=input_data)
+    t1 = torch.rand([3, 1]).float()
+    t2 = torch.rand([1, 3]).float()
+    verify_model(Addcdiv2().float().eval(), input_data=[input_data, t1, t2])
+
+
+def test_forward_addcmul():
+    torch.set_grad_enabled(False)
+
+    class Addcmul1(Module):
+        def forward(self, *args):
+            t1 = torch.ones([3, 1])
+            t2 = torch.ones([1, 3])
+            if torch.cuda.is_available():
+                t1 = t1.cuda()
+                t2 = t2.cuda()
+            return torch.addcmul(args[0], 0.1, t1, t2)
+
+    class Addcmul2(Module):
+        def forward(self, *args):
+            return torch.addcmul(args[0], 0.5, args[1], args[2])
+
+    input_data = torch.rand([1, 3]).float()
+    verify_model(Addcmul1().float().eval(), input_data=input_data)
+    t1 = torch.rand([3, 1]).float()
+    t2 = torch.rand([1, 3]).float()
+    verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2])
+
+
 if __name__ == "__main__":
     # Single operator tests
     test_forward_add()
@@ -1944,6 +2010,9 @@ def forward(self, *args):
     test_forward_select()
     test_forward_take()
     test_forward_topk()
+    test_forward_where()
+    test_forward_addcdiv()
+    test_forward_addcmul()
     test_forward_clone()
     test_forward_softplus()
     test_forward_softsign()
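
Besides the direct calls added to the __main__ block above, the new tests can also be run selectively, for example via pytest. A hedged usage note (path relative to the repository root):

    import pytest

    pytest.main(["tests/python/frontend/pytorch/test_forward.py",
                 "-k", "where or addcdiv or addcmul"])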