
Commit 57bffc3

malfet and ptrblck authored
Disable autocast cache for tensor views as fix for pytorch#48049 (pytorch#48696) (pytorch#48936)
Co-authored-by: pbialecki <pbialecki@nvidia.com>
1 parent 661d1a0 commit 57bffc3
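Background (a reading of the fix below, not stated in the commit message itself): under `torch.no_grad()`, a view of a parameter, such as the transposed weight that `torch.nn.functional.linear` takes via `weight.t()`, is still a leaf with `requires_grad=True`, and every call produces a distinct tensor object. Each such view therefore passed the old caching test and deposited a fresh fp16 copy in the autocast cache on every forward pass. A minimal sketch of that view behavior (the underscore-prefixed `Tensor._is_view()` is an internal helper, used here only for illustration):

import torch

w = torch.randn(10, 10, requires_grad=True)   # stand-in for a model weight (a leaf)

with torch.no_grad():
    v1 = w.t()   # the kind of view F.linear takes of its weight
    v2 = w.t()

# Views taken under no_grad still report requires_grad=True and are leaves,
# so they passed the old can_try_cache condition; each call is a new object,
# so each forward pass added a new entry to the cast cache.
print(v1.requires_grad, v1.is_leaf, v1._is_view())  # expected: True True True
print(v1 is v2)                                     # False: distinct view tensors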

File tree

2 files changed: +17 -1 lines changed

aten/src/ATen/autocast_mode.cpp

+1 -1
@@ -68,7 +68,7 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg) {
   if (is_eligible(arg) && (arg.scalar_type() != to_type)) {
     // Heuristic: Do what Apex does, and cache fp16 casts of fp32 model weights (leaves).
     // See cached_casts declaration above for detailed strategy.
-    bool can_try_cache = (to_type == at::kHalf && arg.scalar_type() == at::kFloat && arg.requires_grad() && arg.is_leaf());
+    bool can_try_cache = (to_type == at::kHalf && arg.scalar_type() == at::kFloat && arg.requires_grad() && arg.is_leaf() && !arg.is_view());
     if (can_try_cache) {
       auto it = cached_casts.find(arg.unsafeGetTensorImpl());
       if (it != cached_casts.end()) {
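In Python terms, the updated heuristic reads roughly as below. This is a paraphrase for illustration, not PyTorch's actual code path; `Tensor._is_view()` is assumed here as the Python-side counterpart of the C++ `arg.is_view()`:

import torch

def can_try_cache(arg: torch.Tensor, to_type: torch.dtype) -> bool:
    # Cache only fp16 casts of fp32 leaf tensors that require grad
    # (i.e. model weights) and, after this commit, only if they are
    # not views of another tensor.
    return (to_type == torch.half
            and arg.dtype == torch.float
            and arg.requires_grad
            and arg.is_leaf
            and not arg._is_view())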

test/test_cuda.py

+16
@@ -2706,6 +2706,22 @@ def test_autocast_rnn(self):
                 for grad, grad_control in zip(grads, grads_control):
                     self.assertEqual(grad.half(), grad_control)
 
+    def test_autocast_cache_leak(self):
+        # Reported at https://github.com/pytorch/pytorch/issues/48049
+        # Checks that autocast does not re-cache the same parameters
+        # when executed in a `torch.no_grad()` block.
+
+        linear = torch.nn.Linear(10, 10).to('cuda')
+        data = torch.randn(1, 10, device='cuda')
+
+        with torch.cuda.amp.autocast():
+            with torch.no_grad():
+                out = linear(data)
+                first_iter_mem = torch.cuda.memory_allocated()
+                for _ in range(3):
+                    out = linear(data)
+                self.assertTrue(first_iter_mem == torch.cuda.memory_allocated())
+
     @slowTest
     @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
     def test_max_large_axis(self):
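For manual verification outside the test suite, the same check can be run as a standalone script (assumes a CUDA build). Before this commit each iteration roughly adds one cached fp16 copy of the weight, so the printed delta grows; with the fix it stays at zero:

import torch

linear = torch.nn.Linear(10, 10).to('cuda')
data = torch.randn(1, 10, device='cuda')

with torch.cuda.amp.autocast():
    with torch.no_grad():
        linear(data)                              # warm-up forward pass
        baseline = torch.cuda.memory_allocated()
        for _ in range(3):
            linear(data)
            # With the fix, no new cache entries are created here.
            print(torch.cuda.memory_allocated() - baseline)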

0 commit comments