From 76b3cca5cc9a18a56db8107d2f6c8e94851bb85c Mon Sep 17 00:00:00 2001
From: Kevin Stephano
Date: Fri, 1 Jul 2022 20:31:52 -0700
Subject: [PATCH] Add parsing support for `_to_copy` to handle AMP casts. (#1756)

1. Add parsing support for `_to_copy()` to handle AMP casts.
2. Refactor the cast parsing into a shared helper that accepts `None` for dtype.
3. Add Python tests.

Co-authored-by: jjsjann123
---
 test/test_jit_cuda_fuser.py                  |  27 +++++
 torch/csrc/jit/codegen/cuda/parser.cpp       | 106 +++++++++++++++---
 .../csrc/jit/codegen/cuda/type_inference.cpp |  16 ++-
 3 files changed, 130 insertions(+), 19 deletions(-)

diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py
index 5d3ffe7b5207c..dee87fae0935f 100644
--- a/test/test_jit_cuda_fuser.py
+++ b/test/test_jit_cuda_fuser.py
@@ -4390,6 +4390,33 @@ def t(x):
         t_jit = torch.jit.script(t)
         self._run_helper(t_jit, t, x)
 
+
+    @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
+                     "Requires fusion optimization pass to be effective")
+    def test_to_copy(self):
+        x = torch.randn(4, 2, device="cuda")
+
+        with nvfuser_singleton_fusion(True):
+            def t(x, dtype : torch.dtype):
+                o = torch.ops.aten._to_copy(x, dtype=dtype)
+                return o
+
+            t.__disable_jit_function_caching__ = True
+
+            t_jit = torch.jit.script(t)
+            for dtype in [torch.float16, torch.bool, torch.float64]:
+                self._run_helper(t_jit, t, x, dtype)
+
+            def t_none(x):
+                with torch.jit.strict_fusion():
+                    o = torch.ops.aten._to_copy(x, dtype=None)
+                return o
+
+            t_jit_none = torch.jit.script(t_none)
+            self._run_helper(t_jit_none, t_none, x)
+
+
     @unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since reshape is disabled now")
     @unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp
index 5d05f37ad3a36..f857d3109f20a 100644
--- a/torch/csrc/jit/codegen/cuda/parser.cpp
+++ b/torch/csrc/jit/codegen/cuda/parser.cpp
@@ -72,6 +72,32 @@ const auto& profileFailedAttr = Symbol::attr("profile_failed");
 typedef Val* CgValue;
 typedef Expr* CgOp;
 
+Val* castTensoToDtype(CgValue self, JitValue* cast_val) {
+  auto cast_ival = toIValue(cast_val);
+  // we need a static type for the cast
+  TORCH_INTERNAL_ASSERT(cast_ival.has_value());
+  if (cast_ival->isInt()) {
+    auto dtype = cast_ival->toScalarType();
+
+    // We want to keep our internal fusion math in FP32.
+    // Shape inference will continue to propagate the right
+    // type to outputs unchanged.
+    if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16) {
+      dtype = at::ScalarType::Float;
+    }
+
+    return castOp(aten_to_data_type(dtype), self);
+  } else {
+    TORCH_INTERNAL_ASSERT(
+        cast_ival->isNone(),
+        "unrecognized dtype option, expect 'int' but got: ",
+        cast_ival->tagKind());
+
+    // return a copy if dtype is `None`
+    return set(self);
+  }
+}
+
 bool isReductionNonCompatibleTensor(
     const std::shared_ptr<c10::TensorType>& tensor_type) {
   return is_zero_dim_tensor(tensor_type) || is_zero_sized_tensor(tensor_type);
@@ -2704,10 +2730,9 @@ class IrParser {
       }
     }
 
-    // Limiting aten::to implementation to only change the dtype of a tensor
     {
       auto ptr_op = getOperatorForLiteral(
-          "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor");
+          "aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor");
memory_format=None) -> Tensor"); REGISTER_PARSE_RULE( ptr_op, { @@ -2718,22 +2743,59 @@ class IrParser { auto self = list_val.front(); list_val.pop_front(); - // we need static type for cast - TORCH_INTERNAL_ASSERT( - node->input(1)->node()->kind() == prim::Constant); - auto dtype = toIValue(node->input(1))->toScalarType(); - - // We want to keep our internal fusion math in FP32 - // Shape Inference will continue to propagate the right - // type to outputs unchanged. - if (dtype == at::ScalarType::Half) { - dtype = at::ScalarType::Float; + auto out = castTensoToDtype(self, node->input(1)); + + value_map.emplace( + node->output()->unique(), ValueHolder(out, format)); + }, + [](const Node* node) -> bool { + if (!isInputNonSizeZeroTensor(node)) { + return false; + } + if (node->inputs()[1]->node()->kind() != prim::Constant) { + return false; + } + // we do not support explicit memory_format on output + if (!node->inputs()[2]->type()->isSubtypeOf( + static_cast(NoneType::get()))) { + return false; + } + // we do not support explicit memory_format on output + if (!node->inputs()[3]->type()->isSubtypeOf( + static_cast(NoneType::get()))) { + return false; + } + // we do not support explicit memory_format on output + if (!node->inputs()[4]->type()->isSubtypeOf( + static_cast(NoneType::get()))) { + return false; } - if (dtype == at::ScalarType::BFloat16) { - dtype = at::ScalarType::Float; + // we do not support explicit memory_format on output + if (!node->inputs()[6]->type()->isSubtypeOf( + static_cast(NoneType::get()))) { + return false; } + return true; + }, + nullptr); + } + + // Limiting aten::to implementation to only change the dtype of a tensor + { + auto ptr_op = getOperatorForLiteral( + "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor"); + REGISTER_PARSE_RULE( + ptr_op, + { + MemoryFormat format; + std::list list_val; + std::tie(format, list_val) = getConsistentValues( + c10::nullopt, value_map[node->inputs()[0]->unique()]); + auto self = list_val.front(); + list_val.pop_front(); + + auto out = castTensoToDtype(self, node->input(1)); - auto out = castOp(aten_to_data_type(dtype), self); value_map.emplace( node->output()->unique(), ValueHolder(out, format)); }, @@ -4186,6 +4248,20 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) { return true; } + static auto to_copy_schema = + getOperatorForLiteral( + "aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor") + ->schema(); + if (node->matches(to_copy_schema)) { + switch (offset) { + case 1: + profileInt(pr, node, offset); + return true; + default: + return false; + } + } + static auto to_dtype_schema = getOperatorForLiteral( "aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? 
memory_format=None) -> Tensor") diff --git a/torch/csrc/jit/codegen/cuda/type_inference.cpp b/torch/csrc/jit/codegen/cuda/type_inference.cpp index 58f2187ea1cc9..0bc821e024c3d 100644 --- a/torch/csrc/jit/codegen/cuda/type_inference.cpp +++ b/torch/csrc/jit/codegen/cuda/type_inference.cpp @@ -462,12 +462,20 @@ class NaiveTypePropagator { type0->withScalarType(type1->scalarType()), node); break; } - case aten::to: { + case aten::to: + case aten::_to_copy: { const auto type0 = getInputTensorType(node, 0); const auto out_dtype = toIValue(node->input(1)); - TORCH_CHECK(out_dtype, "No output type specified"); - copyScalarTypeAndDeviceToOutput( - type0->withScalarType(out_dtype->toScalarType()), node); + if (out_dtype.has_value() && out_dtype->isInt()) { + copyScalarTypeAndDeviceToOutput( + type0->withScalarType(out_dtype->toScalarType()), node); + } else { + TORCH_CHECK( + !out_dtype.has_value() || out_dtype->isNone(), + "dtype for cast unrecognized ", + out_dtype->tagKind()); + copyScalarTypeAndDeviceToOutput(type0, node); + } break; } case prim::add_optional: {