Add parsing support for _to_copy to handle AMP casts. (#1756)
1. Add support for _to_copy() to handle AMP casts (a short illustration follows the changed-files summary below).
2. Refactor the cast lowering into a shared helper that also accepts None for dtype.
3. Add Python tests.

Co-authored-by: jjsjann123 <jiej@nvidia.com>
kevinstephano and jjsjann123 authored Jul 2, 2022
1 parent ef04f6c commit 76b3cca
Showing 3 changed files with 130 additions and 19 deletions.
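
For context (per the commit message; not part of the diff below): under AMP, autocast's dtype conversions reach the fuser as aten::_to_copy nodes that only change dtype. The explicit form the new parser rule targets looks roughly like this:

    import torch

    x = torch.randn(4, 2, device="cuda")

    # The cast the parser now understands: only dtype changes; layout, device,
    # pin_memory and memory_format are left at their None defaults.
    y = torch.ops.aten._to_copy(x, dtype=torch.float16)

    # dtype=None degenerates to a plain copy (lowered to set(self) in the parser).
    z = torch.ops.aten._to_copy(x, dtype=None)
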
27 changes: 27 additions & 0 deletions test/test_jit_cuda_fuser.py
@@ -4390,6 +4390,33 @@ def t(x):
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)


@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_to_copy(self):
x = torch.randn(4, 2, device="cuda")

with nvfuser_singleton_fusion(True):
def t(x, dtype : torch.dtype):
o = torch.ops.aten._to_copy(x, dtype=dtype)
return o

t.__disable_jit_function_caching__ = True

t_jit = torch.jit.script(t)
for dtype in [torch.float16, torch.bool, torch.float64]:
self._run_helper(t_jit, t, x, dtype)

def t_none(x):
with torch.jit.strict_fusion():
o = torch.ops.aten._to_copy(x, dtype=None)
return o

t_jit_none = torch.jit.script(t_none)
self._run_helper(t_jit_none, t_none, x)


@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since reshape is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
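
Beyond self._run_helper, a rough way to confirm by hand that the cast actually lands in a fusion group (a sketch, not part of this commit; assumes a CUDA build with nvfuser available and the profiling executor active):

    import torch

    torch._C._jit_set_nvfuser_enabled(True)  # route fusions to nvfuser

    @torch.jit.script
    def cast_half(x):
        return torch.ops.aten._to_copy(x, dtype=torch.float16)

    x = torch.randn(4, 2, device="cuda")
    for _ in range(3):  # profiling runs so the fusion pass can kick in
        cast_half(x)

    # The optimized graph should now contain a prim::CudaFusionGroup wrapping the cast.
    print(torch.jit.last_executed_optimized_graph())
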
106 changes: 91 additions & 15 deletions torch/csrc/jit/codegen/cuda/parser.cpp
@@ -72,6 +72,32 @@ const auto& profileFailedAttr = Symbol::attr("profile_failed");
typedef Val* CgValue;
typedef Expr* CgOp;

Val* castTensoToDtype(CgValue self, JitValue* cast_val) {
auto cast_ival = toIValue(cast_val);
// we need a statically known dtype for the cast
TORCH_INTERNAL_ASSERT(cast_ival.has_value());
if (cast_ival->isInt()) {
auto dtype = cast_ival->toScalarType();

// We want to keep our internal fusion math in FP32
// Shape Inference will continue to propagate the right
// type to outputs unchanged.
if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16) {
dtype = at::ScalarType::Float;
}

return castOp(aten_to_data_type(dtype), self);
} else {
TORCH_INTERNAL_ASSERT(
cast_ival->isNone(),
"unrecognized dtype option, expected 'int' but got: ",
cast_ival->tagKind());

// return a copy if dtype is `None`
return set(self);
}
}
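
The helper above carries the AMP-friendly policy: half-precision cast targets are computed in FP32 inside the fusion (type inference restores the requested output dtype), and dtype=None degenerates to a plain copy. In Python terms the dtype decision is roughly (a sketch only, not part of the commit):

    import torch

    def resolve_internal_cast_dtype(requested):
        # None means "just copy"; castTensoToDtype returns set(self) in that case.
        if requested is None:
            return None
        # Half/BFloat16 casts run in FP32 inside the fusion; the half-precision
        # output dtype is propagated back onto the fusion outputs downstream.
        if requested in (torch.float16, torch.bfloat16):
            return torch.float32
        return requested
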

bool isReductionNonCompatibleTensor(
const std::shared_ptr<c10::TensorType>& tensor_type) {
return is_zero_dim_tensor(tensor_type) || is_zero_sized_tensor(tensor_type);
@@ -2704,10 +2730,9 @@ class IrParser {
}
}

// Limiting aten::to implementation to only change the dtype of a tensor
{
auto ptr_op = getOperatorForLiteral(
"aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor");
"aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor");
REGISTER_PARSE_RULE(
ptr_op,
{
@@ -2718,22 +2743,59 @@
auto self = list_val.front();
list_val.pop_front();

// we need static type for cast
TORCH_INTERNAL_ASSERT(
node->input(1)->node()->kind() == prim::Constant);
auto dtype = toIValue(node->input(1))->toScalarType();

// We want to keep our internal fusion math in FP32
// Shape Inference will continue to propagate the right
// type to outputs unchanged.
if (dtype == at::ScalarType::Half) {
dtype = at::ScalarType::Float;
auto out = castTensoToDtype(self, node->input(1));

value_map.emplace(
node->output()->unique(), ValueHolder(out, format));
},
[](const Node* node) -> bool {
if (!isInputNonSizeZeroTensor(node)) {
return false;
}
if (node->inputs()[1]->node()->kind() != prim::Constant) {
return false;
}
// we do not support explicit layout on output
if (!node->inputs()[2]->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
return false;
}
// we do not support explicit device on output
if (!node->inputs()[3]->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
return false;
}
// we do not support explicit pin_memory on output
if (!node->inputs()[4]->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
return false;
}
if (dtype == at::ScalarType::BFloat16) {
dtype = at::ScalarType::Float;
// we do not support explicit memory_format on output
if (!node->inputs()[6]->type()->isSubtypeOf(
static_cast<c10::TypePtr>(NoneType::get()))) {
return false;
}
return true;
},
nullptr);
}

// Limiting aten::to implementation to only change the dtype of a tensor
{
auto ptr_op = getOperatorForLiteral(
"aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor");
REGISTER_PARSE_RULE(
ptr_op,
{
MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
c10::nullopt, value_map[node->inputs()[0]->unique()]);
auto self = list_val.front();
list_val.pop_front();

auto out = castTensoToDtype(self, node->input(1));

auto out = castOp(aten_to_data_type(dtype), self);
value_map.emplace(
node->output()->unique(), ValueHolder(out, format));
},
@@ -4186,6 +4248,20 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
return true;
}

static auto to_copy_schema =
getOperatorForLiteral(
"aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor")
->schema();
if (node->matches(to_copy_schema)) {
switch (offset) {
case 1:
profileInt(pr, node, offset);
return true;
default:
return false;
}
}

static auto to_dtype_schema =
getOperatorForLiteral(
"aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor")
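
Putting the can-parse checks above together: the rule only fuses _to_copy calls whose dtype is a profiled constant and whose remaining options stay at None. Illustrative call patterns (Python, not part of the commit; eager calls shown only to name the arguments):

    import torch

    x = torch.randn(4, 2, device="cuda")

    # Accepted once scripted: dtype constant, everything else left at None.
    ok = torch.ops.aten._to_copy(x, dtype=torch.float16)

    # Rejected by the can-parse lambda: an explicit memory_format (input 6).
    rejected = torch.ops.aten._to_copy(
        x, dtype=torch.float16, memory_format=torch.contiguous_format)

    # Explicit layout, device, or pin_memory (inputs 2-4) would be rejected the same way.
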
16 changes: 12 additions & 4 deletions torch/csrc/jit/codegen/cuda/type_inference.cpp
@@ -462,12 +462,20 @@ class NaiveTypePropagator {
type0->withScalarType(type1->scalarType()), node);
break;
}
case aten::to: {
case aten::to:
case aten::_to_copy: {
const auto type0 = getInputTensorType(node, 0);
const auto out_dtype = toIValue(node->input(1));
TORCH_CHECK(out_dtype, "No output type specified");
copyScalarTypeAndDeviceToOutput(
type0->withScalarType(out_dtype->toScalarType()), node);
if (out_dtype.has_value() && out_dtype->isInt()) {
copyScalarTypeAndDeviceToOutput(
type0->withScalarType(out_dtype->toScalarType()), node);
} else {
TORCH_CHECK(
!out_dtype.has_value() || out_dtype->isNone(),
"unrecognized dtype for cast: ",
out_dtype->tagKind());
copyScalarTypeAndDeviceToOutput(type0, node);
}
break;
}
case prim::add_optional: {
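
The propagation added above mirrors _to_copy's eager behaviour: a concrete dtype overrides the input's scalar type, while None keeps it. A quick sanity check (illustrative, not part of the commit):

    import torch

    x = torch.randn(4, 2, device="cuda")  # float32

    print(torch.ops.aten._to_copy(x, dtype=torch.float16).dtype)  # torch.float16
    print(torch.ops.aten._to_copy(x, dtype=None).dtype)           # torch.float32 (unchanged)
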
