rebase for autocast updates to include device_type and dtype flags (#61002)

Summary:
Fixes #55374

Pull Request resolved: #61002

Reviewed By: malfet, mruberry

Differential Revision: D30016812

Pulled By: ngimel

fbshipit-source-id: 6e09a29f539d28e9aea5cd9489b1e633cc588033
puririshi98 authored and facebook-github-bot committed Aug 11, 2021
1 parent a55cae3 commit 324673a
Showing 12 changed files with 310 additions and 183 deletions.
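Before the per-file diffs, a minimal usage sketch of the device_type and dtype flags this change adds to the generic torch.autocast context manager (hedged: assumes a CUDA-capable build; the defaults shown match the diff below, float16 for CUDA and bfloat16 for CPU):

    import torch

    model = torch.nn.Linear(8, 8).cuda()
    inp = torch.randn(4, 8, device="cuda")

    # device_type selects the backend; dtype overrides its lower-precision default.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out = model(inp)  # linear layers run in the lower-precision dtype under autocast

    print(out.dtype)  # torch.float16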
11 changes: 11 additions & 0 deletions aten/src/ATen/autocast_mode.cpp
@@ -57,6 +57,9 @@ thread_local int nesting = 0;

// autocast_cpu_dtype is the lower_precision_fp used by AutocastCPU.
thread_local at::ScalarType autocast_cpu_dtype = at::kBFloat16;

// autocast_gpu_dtype is the lower_precision_fp used by AutocastGPU.
at::ScalarType autocast_gpu_dtype = at::kHalf;
}

void clear_cache() {
@@ -71,6 +74,10 @@ int decrement_nesting() {
return --nesting;
}

at::ScalarType get_autocast_gpu_dtype() {
return autocast_gpu_dtype;
}

at::ScalarType get_autocast_cpu_dtype() {
return autocast_cpu_dtype;
}
@@ -82,6 +89,10 @@ void set_autocast_cpu_dtype(at::ScalarType dtype) {
autocast_cpu_dtype = dtype;
}

void set_autocast_gpu_dtype(at::ScalarType dtype) {
autocast_gpu_dtype = dtype;
}

// Overload to catch Tensor args
// TODO (possible optimization):
// Move cast_cache to an inline function in a header with cached_casts declared as
5 changes: 4 additions & 1 deletion aten/src/ATen/autocast_mode.h
@@ -10,9 +10,12 @@ TORCH_API int increment_nesting();
TORCH_API int decrement_nesting();
TORCH_API bool is_cpu_enabled();
TORCH_API void set_cpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_gpu_dtype();
TORCH_API at::ScalarType get_autocast_cpu_dtype();
TORCH_API void set_autocast_gpu_dtype(at::ScalarType dtype);
TORCH_API void set_autocast_cpu_dtype(at::ScalarType dtype);


namespace {
bool is_autocast_eligible(const Tensor& tensor, DeviceType device_type) {
return device_type == DeviceType::CUDA
@@ -38,7 +41,7 @@ inline at::ScalarType get_lower_precision_fp_from_device_type(
DeviceType device_type) {
switch (device_type) {
case DeviceType::CUDA:
return at::kHalf;
return get_autocast_gpu_dtype();
case DeviceType::CPU:
return get_autocast_cpu_dtype();
default:
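A hedged Python-level sketch of what this C++ helper enables: each device type now resolves to its own configurable lower-precision dtype instead of a hard-coded at::kHalf (assumes a CUDA build; the dtypes shown are the defaults from this diff):

    import torch

    # CUDA regions default to the GPU autocast dtype (float16 unless overridden).
    with torch.autocast(device_type="cuda"):
        assert torch.get_autocast_gpu_dtype() is torch.float16

    # CPU regions default to the CPU autocast dtype (bfloat16 unless overridden).
    with torch.autocast(device_type="cpu"):
        assert torch.get_autocast_cpu_dtype() is torch.bfloat16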
12 changes: 9 additions & 3 deletions docs/source/amp.rst
@@ -7,23 +7,29 @@ Automatic Mixed Precision package - torch.cuda.amp
.. automodule:: torch.cuda.amp
.. currentmodule:: torch.cuda.amp

``torch.cuda.amp`` provides convenience methods for mixed precision,
:class:`torch.cuda.amp` and :class:`torch` provide convenience methods for mixed precision,
where some operations use the ``torch.float32`` (``float``) datatype and other operations
use ``torch.float16`` (``half``). Some ops, like linear layers and convolutions,
are much faster in ``float16``. Other ops, like reductions, often require the dynamic
range of ``float32``. Mixed precision tries to match each op to its appropriate datatype.

Ordinarily, "automatic mixed precision training" uses :class:`torch.cuda.amp.autocast` and
Ordinarily, "automatic mixed precision training" uses :class:`torch.autocast` and
:class:`torch.cuda.amp.GradScaler` together, as shown in the :ref:`Automatic Mixed Precision examples<amp-examples>`
and `Automatic Mixed Precision recipe <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_.
However, :class:`autocast` and :class:`GradScaler` are modular, and may be used separately if desired.
However, :class:`torch.autocast` and :class:`GradScaler` are modular, and may be used separately if desired.

.. contents:: :local:

.. _autocasting:

Autocasting
^^^^^^^^^^^
.. currentmodule:: torch

.. autoclass:: autocast
:members:

.. currentmodule:: torch.cuda.amp

.. autoclass:: autocast
:members:
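To make the docs paragraph above concrete, here is a minimal AMP training-loop sketch using torch.autocast and torch.cuda.amp.GradScaler together (hedged: assumes a CUDA device; the model, optimizer, and data are illustrative):

    import torch

    model = torch.nn.Linear(16, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()
    loss_fn = torch.nn.MSELoss()

    for _ in range(3):
        inp = torch.randn(8, 16, device="cuda")
        target = torch.randn(8, 4, device="cuda")
        optimizer.zero_grad()
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = loss_fn(model(inp), target)
        scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
        scaler.step(optimizer)
        scaler.update()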
28 changes: 14 additions & 14 deletions test/test_cuda.py
@@ -2292,7 +2292,7 @@ def test_grad_scaling_autocast(self):
def run(data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api):
for i, (input, target) in enumerate(data):
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=try_scaling_api):
with torch.autocast('cuda', enabled=try_scaling_api):
output = model(input)
loss = loss_fn(output, target)
if try_scaling_api:
@@ -2709,7 +2709,7 @@ def cast(val, to_type):
add_kwargs = {}

self.assertFalse(torch.is_autocast_enabled())
with torch.cuda.amp.autocast():
with torch.autocast('cuda'):
self.assertTrue(torch.is_autocast_enabled())

out_type = out_type if out_type is not None else run_as_type
@@ -2754,7 +2754,7 @@ def compare(first, second):
# Compare numerics to Python-side "autocasting" that (we expect) does the same thing
# as the C++-side autocasting, and should be bitwise accurate.
output_to_compare = output if output is not None else output_method
with torch.cuda.amp.autocast(enabled=False):
with torch.autocast('cuda', enabled=False):
self.assertFalse(torch.is_autocast_enabled())

if module is not None and hasattr(module, op):
@@ -2834,13 +2834,13 @@ def test_autocast_methods_expect_builtin_promote(self):
self._run_autocast_outofplace(op, args, torch.float32, module=None, out_type=out_type)

def test_autocast_banned(self):
with torch.cuda.amp.autocast():
with torch.autocast('cuda'):
for op, args, module in self.autocast_lists.banned:
with self.assertRaises(RuntimeError):
getattr(module, op)(*args)

def test_autocast_ignored_types(self):
with torch.cuda.amp.autocast():
with torch.autocast('cuda'):
for ignore_type in (torch.double, torch.int32):
a_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
b_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0")
@@ -2851,24 +2851,24 @@ def test_autocast_ignored_types(self):
if ignore_type is torch.double:
with self.assertRaises(RuntimeError):
torch.mm(a_ignore, c_16)
with torch.cuda.amp.autocast(enabled=False):
with torch.autocast('cuda', enabled=False):
type_no_autocast = torch.mm(a_ignore, b_ignore).dtype
self.assertTrue(torch.mm(a_ignore, b_ignore).dtype is type_no_autocast)

# Tests if CastPolicy::fp32 ops ignore double and int
with torch.cuda.amp.autocast(enabled=False):
with torch.autocast('cuda', enabled=False):
type_no_autocast = torch.pow(a_ignore, 2.0).dtype
self.assertTrue(torch.pow(a_ignore, 2.0).dtype is type_no_autocast)

# Tests if CastPolicy::fp32_set_opt_dtype ops ignore double and int
with torch.cuda.amp.autocast(enabled=False):
with torch.autocast('cuda', enabled=False):
type_no_autocast = torch.sum(a_ignore).dtype
self.assertTrue(torch.sum(a_ignore).dtype is type_no_autocast)

# Tests if CastPolicy::fp32_append_dtype ops ignore double and int
# Currently, no ops belonging to this policy support integer inputs.
if ignore_type is torch.double:
with torch.cuda.amp.autocast(enabled=False):
with torch.autocast('cuda', enabled=False):
type_no_autocast = torch.norm(a_ignore).dtype
self.assertTrue(torch.norm(a_ignore).dtype is type_no_autocast)

@@ -2928,7 +2928,7 @@ def backward(ctx, grad):
# Sets requires_grad=False explicitly so we don't lie about expecting a gradient.
y = (0, {0: torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=False)})

with torch.cuda.amp.autocast():
with torch.autocast('cuda'):
output = mymm(x, y, torch.float32)
self.assertTrue(output.dtype is torch.float32)
loss = output.sum()
@@ -2956,7 +2956,7 @@ def forward(self):
model = Model()
model_jit_script = torch.jit.script(model)

with torch.cuda.amp.autocast(True):
with torch.autocast('cuda', enabled=True):
model()
model_jit_script()

@@ -3006,7 +3006,7 @@ def test_autocast_rnn(self):
device="cuda", dtype=hidden_dtype)
h = (h, c)

with torch.cuda.amp.autocast():
with torch.autocast('cuda'):
out, h_out = rnn(x, h)
out = out.data if input_layout == "packed" else out
self.assertEqual(out.dtype, torch.float16)
@@ -3049,7 +3049,7 @@ def test_autocast_cache_leak(self):
linear = torch.nn.Linear(10, 10).to('cuda')
data = torch.randn(1, 10, device='cuda')

with torch.cuda.amp.autocast():
with torch.autocast('cuda'):
with torch.no_grad():
out = linear(data)
first_iter_mem = torch.cuda.memory_allocated()
@@ -3062,7 +3062,7 @@ def test_autocast_checkpointing(self):
torch.nn.Linear(8, 8),
torch.nn.Linear(8, 8)).cuda()
input = torch.rand((8, 8), device="cuda", dtype=torch.float16, requires_grad=True)
with torch.cuda.amp.autocast():
with torch.autocast('cuda'):
output = checkpoint_sequential(model, 2, input)
self.assertTrue(output.requires_grad)
self.assertTrue(output.dtype is torch.float16)
42 changes: 42 additions & 0 deletions test/test_public_bindings.py
@@ -212,7 +212,49 @@ def test_no_new_bindings(self):
"unify_type_list",
"Use",
"Value",
"autocast_decrement_nesting",
"autocast_increment_nesting",
"clear_autocast_cache",
"cpp",
"default_generator",
"device",
"dtype",
"finfo",
"fork",
"get_default_dtype",
"get_num_interop_threads",
"get_num_threads",
"has_cuda",
"has_cudnn",
"has_lapack",
"has_mkl",
"has_mkldnn",
"has_mlc",
"has_openmp",
"iinfo",
"import_ir_module",
"import_ir_module_from_buffer",
"init_num_threads",
"is_anomaly_enabled",
"is_autocast_enabled",
"is_grad_enabled",
"layout",
"memory_format",
"merge_type_from_type_comment",
"parse_ir",
"parse_schema",
"parse_type_comment",
"qscheme",
"set_anomaly_enabled",
"set_autocast_enabled",
"set_autocast_gpu_dtype",
"get_autocast_gpu_dtype",
"set_flush_denormal",
"set_num_interop_threads",
"set_num_threads",
"unify_type_list",
"vitals_enabled",

"wait",
}
torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}
2 changes: 2 additions & 0 deletions torch/_C/__init__.pyi.in
@@ -635,7 +635,9 @@ def clear_autocast_cache() -> None: ...
def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
def is_autocast_cpu_enabled() -> _bool: ...
def set_autocast_cpu_dtype(dtype: _dtype) -> None: ...
def set_autocast_gpu_dtype(dtype: _dtype) -> None: ...
def get_autocast_cpu_dtype() -> _dtype: ...
def get_autocast_gpu_dtype() -> _dtype: ...
def autocast_increment_nesting() -> _int: ...
def autocast_decrement_nesting() -> _int: ...
def set_anomaly_enabled(enabled: _bool) -> None: ...
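A hedged Python sketch exercising the two new bindings declared above, as surfaced under the top-level torch namespace (the torch.float16 default mirrors at::kHalf in the C++ diff):

    import torch

    default_dtype = torch.get_autocast_gpu_dtype()   # torch.float16 by default
    torch.set_autocast_gpu_dtype(torch.bfloat16)     # make GPU autocast regions use bfloat16
    assert torch.get_autocast_gpu_dtype() is torch.bfloat16
    torch.set_autocast_gpu_dtype(default_dtype)      # restore the default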
2 changes: 1 addition & 1 deletion torch/__init__.py
@@ -15,7 +15,7 @@
import textwrap
import ctypes
import warnings

from .autocast_mode import autocast
if sys.version_info < (3,):
raise Exception("Python 2 has reached end-of-life and is no longer supported by PyTorch.")
