diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 2dda675c74f5..74d9c78d0e3d 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1889,9 +1889,10 @@ def _get_default_vm_targets(): return [tgt for (tgt, _) in tvm.testing.enabled_targets()] -def verify_script_model(pt_model, ishapes, targets): +def verify_script_model(pt_model, ishapes, targets, idtype=None): script_module = torch.jit.script(pt_model) - verify_model_vm(script_module, ishapes, targets=targets) + + verify_model_vm(script_module, ishapes, idtype=idtype, targets=targets) def verify_trace_model(pt_model, idata, targets): @@ -1900,10 +1901,60 @@ def verify_trace_model(pt_model, idata, targets): verify_model_vm(traced_model, ishapes, idata=idata, targets=targets) -def verify_model_vm(input_model, ishapes, idtype=torch.float, idata=None, targets=["llvm"]): +def convert_pt_to_tvm_type(idtype): + """ Accepts a pytorch dtype and returns string TVM dtype.""" + # TVM does not support PyTorch complex dtypes + if idtype == torch.float64: + curr_dtype = "float64" + elif idtype == torch.float32: + curr_dtype = "float32" + elif idtype == torch.float16: + curr_dtype = "float16" + elif idtype == torch.bfloat16: + curr_dtype = "bfloat16" + elif idtype == torch.int64: + curr_dtype = "int64" + elif idtype == torch.int32: + curr_dtype = "int32" + elif idtype == torch.int16: + curr_dtype = "int16" + elif idtype == torch.int8: + curr_dtype = "int8" + elif idtype == torch.uint8: + curr_dtype = "uint8" + elif idtype == torch.bool: + curr_dtype = "bool" + else: + raise NotImplementedError("Unsupported dtype: {}".format(idtype)) + return curr_dtype + + +def verify_model_vm(input_model, ishapes, idtype=None, idata=None, targets=["llvm"]): + if not idtype: + idtype = torch.float + input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)] - input_shapes = list(zip(input_names, ishapes)) - input_data = idata if idata else [torch.randn(shape, dtype=idtype) for shape in ishapes] + tvm_dtype = convert_pt_to_tvm_type(idtype) + input_dtypes = [tvm_dtype] * len(input_names) + input_shapes = list(zip(input_names, list(zip(ishapes, input_dtypes)))) + + if idata: + input_data = idata + # If no input_data provided, generate random data of specified dtype + else: + if idtype == torch.bool: + input_data = [ + torch.Tensor.bool(torch.randint(low=0, high=2, size=shape)) for shape in ishapes + ] + # Torch dtype can be float, complex, int, or Bool. Complex not supported, so if not float or Bool, + # dtype must be int! + elif not idtype.is_floating_point: + input_data = [ + torch.randint(low=0, high=10, size=shape, dtype=idtype) for shape in ishapes + ] + else: + input_data = [torch.randn(shape, dtype=idtype) for shape in ishapes] + # Compile via VM mod, params = relay.frontend.from_pytorch(input_model, input_shapes) @@ -2950,6 +3001,29 @@ def forward(self, *args): ) +@tvm.testing.uses_gpu +def test_forward_is_floating_point(): + torch.set_grad_enabled(False) + + class IsFloatingPoint(Module): + def forward(self, arg): + # `torch.jit.trace` cannot accept something that outputs + # a Bool, so `torch.jit.script` will be used instead + return torch.is_floating_point(arg) + + targets = _get_default_vm_targets() + verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.float64) + verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.float32) + verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.float16) + # todo(dvisnty): Run the test for bfloat16 when full bfloat16 support is implemented + # verify_script_model(IsFloatingPoint(), [(1,1)], targets, idtype=torch.bfloat16) + verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int64) + verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int32) + verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int16) + verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int8) + verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.uint8) + + @tvm.testing.uses_gpu def test_forward_traced_function(): def fn(t1, t2): @@ -3425,6 +3499,7 @@ def test_fn(x, weights=None): test_forward_addcdiv() test_forward_addcmul() test_forward_true_divide() + test_forward_is_floating_point() test_forward_clone() test_forward_softplus() test_forward_softsign()