Add is_floating_point() test and better type support in verify_model_vm() #7134

Merged: 7 commits, Dec 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
tests/python/frontend/pytorch/test_forward.py (85 changes: 80 additions & 5 deletions)
@@ -1889,9 +1889,10 @@ def _get_default_vm_targets():
     return [tgt for (tgt, _) in tvm.testing.enabled_targets()]
 
 
-def verify_script_model(pt_model, ishapes, targets):
+def verify_script_model(pt_model, ishapes, targets, idtype=None):
     script_module = torch.jit.script(pt_model)
-    verify_model_vm(script_module, ishapes, targets=targets)
+
+    verify_model_vm(script_module, ishapes, idtype=idtype, targets=targets)
Review comment from a project member:

    I don't think you need this if/else; just verify_model_vm(script_module, ishapes, idtype=idtype, targets=targets) should work even if idtype is None.

Reply from @TylerADavis (contributor, author), Dec 21, 2020:

    Done. I've cleaned up verify_script_model() and changed the handling of default arguments in verify_model_vm() so that passing in None results in torch.float being used. Let me know if the new approach looks good.
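The pattern described in the reply, as a minimal sketch (illustrative only; example_helper is a hypothetical name, but the None-to-torch.float fallback mirrors what verify_model_vm now does):

import torch

def example_helper(ishapes, idtype=None):
    # Falling back inside the function lets callers pass idtype=None
    # explicitly (as verify_script_model does) or omit the argument entirely.
    if idtype is None:
        idtype = torch.float
    return [torch.randn(shape, dtype=idtype) for shape in ishapes]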


@@ -1900,10 +1901,60 @@ def verify_trace_model(pt_model, idata, targets):
     verify_model_vm(traced_model, ishapes, idata=idata, targets=targets)
 
 
-def verify_model_vm(input_model, ishapes, idtype=torch.float, idata=None, targets=["llvm"]):
+def convert_pt_to_tvm_type(idtype):
+    """Accepts a PyTorch dtype and returns the corresponding TVM dtype string."""
+    # TVM does not support PyTorch complex dtypes
+    if idtype == torch.float64:
+        curr_dtype = "float64"
+    elif idtype == torch.float32:
+        curr_dtype = "float32"
+    elif idtype == torch.float16:
+        curr_dtype = "float16"
+    elif idtype == torch.bfloat16:
+        curr_dtype = "bfloat16"
+    elif idtype == torch.int64:
+        curr_dtype = "int64"
+    elif idtype == torch.int32:
+        curr_dtype = "int32"
+    elif idtype == torch.int16:
+        curr_dtype = "int16"
+    elif idtype == torch.int8:
+        curr_dtype = "int8"
+    elif idtype == torch.uint8:
+        curr_dtype = "uint8"
+    elif idtype == torch.bool:
+        curr_dtype = "bool"
+    else:
+        raise NotImplementedError("Unsupported dtype: {}".format(idtype))
+    return curr_dtype
+
+
+def verify_model_vm(input_model, ishapes, idtype=None, idata=None, targets=["llvm"]):
+    if not idtype:
+        idtype = torch.float
+
     input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
-    input_shapes = list(zip(input_names, ishapes))
-    input_data = idata if idata else [torch.randn(shape, dtype=idtype) for shape in ishapes]
+    tvm_dtype = convert_pt_to_tvm_type(idtype)
+    input_dtypes = [tvm_dtype] * len(input_names)
+    input_shapes = list(zip(input_names, list(zip(ishapes, input_dtypes))))
+
+    if idata:
+        input_data = idata
+    # If no input data is provided, generate random data of the specified dtype
+    else:
+        if idtype == torch.bool:
+            input_data = [
+                torch.Tensor.bool(torch.randint(low=0, high=2, size=shape)) for shape in ishapes
+            ]
+        # A torch dtype can be float, complex, int, or bool. Complex is not supported,
+        # so a dtype that is neither float nor bool must be int.
+        elif not idtype.is_floating_point:
+            input_data = [
+                torch.randint(low=0, high=10, size=shape, dtype=idtype) for shape in ishapes
+            ]
+        else:
+            input_data = [torch.randn(shape, dtype=idtype) for shape in ishapes]
 
     # Compile via VM
     mod, params = relay.frontend.from_pytorch(input_model, input_shapes)
@@ -2950,6 +3001,29 @@ def forward(self, *args):
     )
 
 
+@tvm.testing.uses_gpu
+def test_forward_is_floating_point():
+    torch.set_grad_enabled(False)
+
+    class IsFloatingPoint(Module):
+        def forward(self, arg):
+            # `torch.jit.trace` cannot accept something that outputs
+            # a Bool, so `torch.jit.script` will be used instead
+            return torch.is_floating_point(arg)
+
+    targets = _get_default_vm_targets()
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.float64)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.float32)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.float16)
+    # todo(dvisnty): Run the test for bfloat16 when full bfloat16 support is implemented
+    # verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.bfloat16)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int64)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int32)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int16)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int8)
+    verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.uint8)
+
+
 @tvm.testing.uses_gpu
 def test_forward_traced_function():
     def fn(t1, t2):
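For context, torch.is_floating_point simply reports whether a tensor's dtype is a floating-point type, which the new test exercises across dtypes. A standalone illustration, independent of the test harness:

import torch

print(torch.is_floating_point(torch.zeros(1, 1, dtype=torch.float16)))  # True
print(torch.is_floating_point(torch.zeros(1, 1, dtype=torch.int64)))    # False
print(torch.is_floating_point(torch.zeros(1, 1, dtype=torch.bool)))     # False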
Expand Down Expand Up @@ -3425,6 +3499,7 @@ def test_fn(x, weights=None):
test_forward_addcdiv()
test_forward_addcmul()
test_forward_true_divide()
test_forward_is_floating_point()
test_forward_clone()
test_forward_softplus()
test_forward_softsign()
Expand Down