
Small fix for rnn_loss.py #1183

Merged 2 commits into k2-fsa:master from patch-1 on Apr 30, 2023
Conversation

@yfyeung (Contributor) commented Apr 30, 2023

No description provided.

@danpovey (Collaborator) commented Apr 30, 2023

This is another fix for a memory issue in rnnt_loss.py, similar to #1177: the problem is that a too-large zero tensor is allocated in the backward pass, which can lead to OOM errors in training with large batch sizes. I discovered the problem using a version of PyTorch built with debug info. I ran training with a large batch size (1450) and a single job, under gdb --args:

 gdb --args python3 ./pruned_transducer_stateless7/train.py --master-port 43021 --world-size 1 --num-workers 0 --num-epochs 24 --full-libri 0 \
   --exp-dir pruned_transducer_stateless7/scaled_adam_exp1117_4job_md1250_full --max-duration 1450 --use-fp16 True --decoder-dim 512 --joiner-dim 512 --start-epoch=1 --base-lr=0.035
(gdb) catch throw
(gdb) r

When it stopped at the point where it was about to issue this error:

RuntimeError: CUDA out of memory. Tried to allocate 4.89 GiB (GPU 0; 31.75 GiB total capacity; 23.75 GiB already allocated; 4.71 GiB free; 25.86 GiB reserved in total by PyTorch)
<snip>
2023-04-28 02:09:35,802 INFO [checkpoint.py:75] Saving checkpoint to pruned_transducer_stateless7/scaled_adam_exp1117_4job_md1250_full/bad-model-0.pt
2023-04-28 02:09:45,045 INFO [train.py:1299] Saving batch to pruned_transducer_stateless7/scaled_adam_exp1117_4job_md1250_full/batch-d9cd8db7-a730-5db6-e534-90be42e420b6.pt
2023-04-28 02:09:45,224 INFO [train.py:1305] features shape: torch.Size([81, 1777, 80])

... I went to stack frame #65, which had the size info:

#62 0x00007fffcb7c63ff in c10::Dispatcher::callWithDispatchKey<at::Tensor, at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, \
long, at::Tensor const&, bool)> const&, c10::DispatchKey, at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool) const (this=0x7fffd86eada0 <c10::Dispatcher::singleton()::_singleton>, op=...,
    dispatchKey=c10::DispatchKey::AutogradCUDA) at /star-fj/fangjun/open-source/pytorch/aten/src/ATen/core/dispatch/Dispatcher.h:376
#63 0x00007fffcb7204c3 in c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, long, at::Tenso\
r const&, bool)> const&, at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool) const (this=0x7fffd86eada0 <c10::Dispatcher::singleton()::_singleton>, op=...)
    at /star-fj/fangjun/open-source/pytorch/aten/src/ATen/core/dispatch/Dispatcher.h:387
#64 0x00007fffcb6ae3a4 in c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool)>::call(at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool) const (
    this=0x7fffd86f4d00 <at::gather_backward(at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool)::op>, args#0=..., args#1=..., args#2=2, args#3=..., args#4=false)
    at /star-fj/fangjun/open-source/pytorch/aten/src/ATen/core/dispatch/Dispatcher.h:327
#65 0x00007fffcb655052 in at::gather_backward (grad=..., self=..., dim=2, index=..., sparse_grad=false) at /star-fj/fangjun/open-source/pytorch/build/aten/src/ATen/Functions.cpp:8184
#66 0x00007fffcd4e3f17 in torch::autograd::generated::GatherBackward::apply (this=0x9b102f40, grads=...) at /star-fj/fangjun/open-source/pytorch/torch/csrc/autograd/generated/Functions.cpp:1563
#67 0x00007fffcdc8ef62 in torch::autograd::Node::operator() (this=0x9b102f40, inputs=...) at /star-fj/fangjun/open-source/pytorch/torch/csrc/autograd/function.h:155
#68 0x00007fffcdc89027 in torch::autograd::call_function (graph_task=std::shared_ptr<torch::autograd::GraphTask> (use count 3, weak count 4) = {...}, func=0x9b102f40, inputBuffer=...)
    at /star-fj/fangjun/open-source/pytorch/torch/csrc/autograd/engine.cpp:676
#69 0x00007fffcdc89582 in torch::autograd::Engine::evaluate_function (this=0x7fffdaef21c0 <torch::autograd::python::PythonEngine::get_python_engine()::engine>,
    graph_task=std::shared_ptr<torch::autograd::GraphTask> (use count 3, weak count 4) = {...}, func=0x9b102f40, inputs=..., cpu_ready_queue=std::shared_ptr<torch::autograd::ReadyQueue> (use count 2, weak count 0) = {..\
.})

and printed out the size info:

(gdb) p ((long*) index.impl_.target_.sizes_.BeginX)[0]
$36 = 81
(gdb) p ((long*) index.impl_.target_.sizes_.BeginX)[1]
$37 = 443
(gdb) p ((long*) index.impl_.target_.sizes_.BeginX)[2]
$38 = 5
(gdb) p ((long*) index.impl_.target_.sizes_.BeginX)[3]
$39 = 512
(gdb) p ((long[4]) self.impl_.target_.sizes_.BeginX)
$40 = {2574253384, 2574253416, 2574253424, 81}
(gdb) p ((long*) self.impl_.target_.sizes_.BeginX)[0]
$41 = 81 # B
(gdb) p ((long*) self.impl_.target_.sizes_.BeginX)[1]
$42 = 443 # T
(gdb) p ((long*) self.impl_.target_.sizes_.BeginX)[2]
$43 = 143 # S
(gdb) p ((long*) self.impl_.target_.sizes_.BeginX)[3]
$44 = 512 # C
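
As a quick sanity check on these sizes (my own arithmetic, not part of the PR): a zero gradient buffer with self's shape (B, T, S, C) = (81, 443, 143, 512), at 2 bytes per element since training runs with --use-fp16 True, comes out to exactly the 4.89 GiB the allocator was asked for in the error above:

# Sanity check (my arithmetic): the backward buffer has self's full shape.
B, T, S, C = 81, 443, 143, 512
bytes_per_elem = 2  # assuming fp16 activations, given --use-fp16 True
print(B * T * S * C * bytes_per_elem / 2**30)  # ~4.89 (GiB), matching the OOM message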

(I figured out these gdb expressions from looking at the types printed for these variables.)
From the size info I figured out which part of the code it related to: it was a torch.gather expression.
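
To illustrate why the gather shows up here (a minimal sketch, not the actual rnnt_loss.py code): torch.gather's backward allocates a zero tensor shaped like the source and scatter-adds the incoming gradient into it, so the backward cost is proportional to the source's full size even when only a few positions are gathered.

import torch

# Minimal sketch: gather along dim=2 from a (B, T, S, C) source with a pruned
# index of shape (B, T, S', C).  The forward output is small, but the backward
# pass materializes a gradient buffer with the source's full shape.
B, T, S, C, S_pruned = 2, 10, 143, 512, 5  # tiny B and T just for illustration
src = torch.randn(B, T, S, C, requires_grad=True)
index = torch.randint(S, (B, T, S_pruned, C))
out = torch.gather(src, dim=2, index=index)  # shape (B, T, S_pruned, C)
out.sum().backward()
print(src.grad.shape)  # torch.Size([2, 10, 143, 512]) -- same size as src

Keeping the gathered-from tensor small (for example, by indexing or pruning before it is expanded to the full (B, T, S, C) shape) is one way to avoid this allocation; I am not claiming that is the exact change made in this PR.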

@danpovey danpovey merged commit 1d4dd96 into k2-fsa:master Apr 30, 2023
@yfyeung yfyeung deleted the patch-1 branch April 30, 2023 05:18
@csukuangfj csukuangfj mentioned this pull request Apr 30, 2023