
Fix memory leak issue in torch_fx tests #18547

Merged
merged 5 commits into huggingface:main from the fix_torch_fx_test_mem_issue branch on Aug 29, 2022

Conversation

ydshieh
Collaborator

@ydshieh ydshieh commented Aug 9, 2022

What does this PR do?

Question: On GPU VMs, we have to use spawn, see here. However, it still hangs with spawn (I can't figure this out yet). Should we have two branches: one using a new process for CPU VMs (on CircleCI), and another one using the original approach (no new process) for GPU VMs, like on the scheduled CI?

I might have a solution! --> send the model to the child process on CPU and move it to the CUDA device there.

I am going to try torch.multiprocessing first. Not working either.


Run torch_fx tests in a spawn process to avoid the memory issue.

  • See this comment for the effect
  • The reason to use JoinableQueue instead of Queue for the outputs (see the sketch below):

https://discuss.pytorch.org/t/using-torch-tensor-over-multiprocessing-queue-process-fails/2847
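A minimal sketch of the pattern (illustrative names, not the exact test code): the work runs in a child process, results come back through a JoinableQueue, and the child blocks on join() until the parent has consumed the item, which avoids the tensor hand-off problem described in the linked thread.

import multiprocessing
import traceback


def _worker(in_queue, out_queue):
    # Placeholder for the real work (symbolic_trace / torch.jit in the tests).
    result, error = None, None
    try:
        value = in_queue.get(timeout=30)
        result = value * 2
    except Exception:
        error = traceback.format_exc()
    out_queue.put((result, error), timeout=30)
    # Block until the parent has fetched the item and called task_done(),
    # so the child does not exit while its data is still in flight.
    out_queue.join()


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")  # "fork" is faster on CPU-only machines
    in_queue, out_queue = ctx.Queue(), ctx.JoinableQueue()

    process = ctx.Process(target=_worker, args=(in_queue, out_queue))
    process.start()
    in_queue.put(21)

    result, error = out_queue.get(timeout=30)
    out_queue.task_done()
    process.join()
    print(result, error)  # 42 None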

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Aug 9, 2022

The documentation is not available anymore as the PR was closed or merged.

@ydshieh
Collaborator Author

ydshieh commented Aug 10, 2022

  • without new process

    • 2~3 minutes for 100 runs
    • 15 MB leak per run
  • with fork

    • 5 minutes for 100 runs
    • 1 MB leak per run
    • hangs if MKL_NUM_THREADS > 1
  • with spawn

    • 30 minutes for 100 runs
    • 1 MB leak per run
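For reference, a hedged sketch of how per-run memory growth like the figures above can be measured (assuming psutil is available; this is not necessarily how the numbers were produced, and run_single_test is a placeholder for one iteration of the test body):

import os

import psutil


def run_single_test():
    pass  # placeholder: one iteration of the torch_fx test body


def rss_mb():
    return psutil.Process(os.getpid()).memory_info().rss / 1024**2


if __name__ == "__main__":
    n_runs = 100
    before = rss_mb()
    for _ in range(n_runs):
        run_single_test()
    print(f"leak per run: {(rss_mb() - before) / n_runs:.2f} MB")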

@ydshieh
Collaborator Author

ydshieh commented Aug 10, 2022

When using the new process approach, setting ulimit -n 2048 is necessary in some cases
(for example, when running the same test in a loop).

Otherwise, we might get the following error:

tests/models/bart/test_modeling_bart.py::BartModelTest::test_torch_fx Traceback (most recent call last):
  File "/usr/lib/python3.9/multiprocessing/queues.py", line 245, in _feed
  File "/usr/lib/python3.9/multiprocessing/reduction.py", line 51, in dumps
  File "/home/yih_dar_huggingface_co/.local/lib/python3.9/site-packages/torch/multiprocessing/reductions.py", line 358, in reduce_storage
RuntimeError: unable to open shared memory object </torch_46201_690006289_939> in read-write mode: Too many open files (24)

More details:

>   ???

tests/test_modeling_common.py:769: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/test_modeling_common.py:866: in _create_and_check_torch_fx_tracing
    ???
/usr/lib/python3.9/multiprocessing/process.py:121: in start
    ???
/usr/lib/python3.9/multiprocessing/context.py:277: in _Popen
    ???
/usr/lib/python3.9/multiprocessing/popen_fork.py:19: in __init__
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <multiprocessing.popen_fork.Popen object at 0x7fa12a499820>, process_obj = <ForkProcess name='ForkProcess-10' parent=46201 initial>

>   ???
E   OSError: [Errno 24] Too many open files

/usr/lib/python3.9/multiprocessing/popen_fork.py:64: OSError

This seems to be related to torch multiprocessing: https://discuss.pytorch.org/t/runtimeerror-unable-to-open-shared-memory-object-depending-on-the-model/116090

Another related issue (not torch): lava-nc/lava#71
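For completeness, a hedged sketch of raising the descriptor limit from inside Python rather than via the shell (the value 2048 mirrors the ulimit above; whether it is sufficient depends on the machine):

import resource

# Raise the soft open-file limit to 2048, capped by the hard limit.
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
target = 2048 if hard == resource.RLIM_INFINITY else min(2048, hard)
if soft < target:
    resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))

# A commonly suggested alternative for this particular error (assumption: its
# applicability depends on the failure mode in the links above) is switching
# PyTorch's tensor sharing strategy:
#   import torch.multiprocessing
#   torch.multiprocessing.set_sharing_strategy("file_system")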

@ydshieh
Collaborator Author

ydshieh commented Aug 10, 2022

With GPU, we have to use spawn, otherwise

Process ForkProcess-1:
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/transformers/tests/test_modeling_common.py", line 143, in _run_torch_jit
    model, input_names, filtered_inputs = in_queue.get(timeout=30)
  File "/usr/lib/python3.8/multiprocessing/queues.py", line 116, in get
    return _ForkingPickler.loads(res)
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/reductions.py", line 112, in rebuild_cuda_tensor
    torch.cuda._lazy_init()
  File "/usr/local/lib/python3.8/dist-packages/torch/cuda/__init__.py", line 207, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
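A hedged sketch of the branching discussed here (a hypothetical helper, not the PR's code): fork on CPU-only machines for speed, spawn whenever CUDA is involved.

import multiprocessing

import torch


def get_test_mp_context():
    # Hypothetical helper: "fork" is much faster, but CUDA cannot be
    # re-initialized in a forked child, so fall back to "spawn" on GPU.
    # (Note: "fork" is unavailable on Windows, where "spawn" is the default.)
    start_method = "spawn" if torch.cuda.is_available() else "fork"
    return multiprocessing.get_context(start_method)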

# Looks like `MKL_NUM_THREADS > 1` with `fork` will hang if the traced/scripted models call inputs.
# Let's use `spawn` to have a new clean process.
# (we can even use `spawn` on scheduled CI but `fork` on CircleCI if necessary)
ctx = multiprocessing.get_context("fork")
Collaborator Author

Not working with GPU/CUDA - it needs to be spawn in this case.


process = ctx.Process(target=_run_torch_jit, args=(input_queue, output_queue))
process.start()
traced_model, traced_output, scripted_output, error = output_queue.get(timeout=30)
Collaborator Author

This still hangs on the GPU VM, even with spawn.

@ydshieh ydshieh requested review from LysandreJik, michaelbenayoun and sgugger and removed request for michaelbenayoun August 10, 2022 09:09
# To avoid the child process hanging on the line `traced_model = symbolic_trace(model, input_names)`.
# We will run `model.to(torch_device)` in the child process instead.
if torch_device != "cpu":
    model.to("cpu")
Collaborator Author

I don't know the exact reason, but passing the model on CUDA via the queue causes a tracing issue. Let's pass it on CPU and move it to the CUDA device in the subprocess.
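A minimal, self-contained sketch of the hand-off described above (the tiny model and queue names are illustrative, not the PR's code): the parent keeps the model on CPU while it crosses the queue, and the child moves it to the target device.

import multiprocessing

import torch
from torch import nn


def _child(in_queue, out_queue):
    model, inputs = in_queue.get(timeout=30)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)  # move to the target device inside the child
    with torch.no_grad():
        out_queue.put(model(inputs.to(device)).cpu(), timeout=30)


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    in_q, out_q = ctx.Queue(), ctx.Queue()
    process = ctx.Process(target=_child, args=(in_q, out_q))
    process.start()

    model = nn.Linear(4, 2)
    model.to("cpu")  # send on CPU; a CUDA-resident model is what caused trouble here
    in_q.put((model, torch.randn(1, 4)), timeout=30)

    print(out_q.get(timeout=30).shape)  # torch.Size([1, 2])
    process.join()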

Collaborator

You cannot really use multiprocessing with CUDA; they don't play well together.

Collaborator

@sgugger sgugger left a comment

@michaelbenayoun would be the best person to review this :-)

@michaelbenayoun
Member

I think it's safe to only run those tests on CPU.
Also, when running locally it takes ~1 min (although I agree my machine might be more powerful).

Member

@michaelbenayoun michaelbenayoun left a comment

Left a few comments.
Also be careful, because some models have their own "custom" implementation of the test.

@@ -138,6 +138,34 @@ def _config_zero_init(config):
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"


def _run_torch_jit(in_queue, out_queue):
Member

What I don't "like" here is that now a test would fail for torchscripting before failing for output mismatch.

Collaborator Author

I agree. I could even run the match in this function, return the matching results to the parent process, and fail the tests there, if you prefer.

Collaborator Author

I think it's safe to only run those tests on CPU. Also, when running locally it takes ~1 min (although I agree my machine might be more powerful).

Did you run this against my branch?

Member

@michaelbenayoun michaelbenayoun Aug 11, 2022

Yes, that would be perfect!
I did not run anything on your branch; I was speaking in general: those tests are not that long and never fail, but my machine is most likely better than the one running the CI!

Collaborator Author

It is indeed fast. This PR is not addressing the time issue, but the memory issue. Currently, each call to test_torch_fx increases memory usage by ~15 MB.

The time issue comes after the fix, as we create new processes to run (part of) the code. Using fork is fine, but spawn will be quite slow. For now, though, this time issue is insignificant (and spawn is only used on the GPU CI running on our own runners, therefore not a real constraint).


model_output = model(**filtered_inputs)

# Note: `MKL_NUM_THREADS > 1` with `fork` will hang if the traced/scripted models call inputs.
Member

How does it behave locally?

Collaborator Author

On CircleCI, we have MKL_NUM_THREADS = 1 (and we are on CPU) --> no issue.
On the scheduled CI, we have MKL_NUM_THREADS = 8, but the device is GPU, so we use spawn --> no issue.

Locally, it seems that if we don't set MKL_NUM_THREADS explicitly, it behaves like MKL_NUM_THREADS > 1 and hangs when running on CPU. This is not desirable; I will see if I can set it to 1 temporarily inside this test (see the sketch below).
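A small, hedged sketch of that idea (an assumed workaround, not necessarily what the PR ends up doing): cap the intra-op thread count before the heavy work, so the test behaves as if MKL_NUM_THREADS=1 even when the variable is unset locally.

import os

import torch

os.environ.setdefault("MKL_NUM_THREADS", "1")  # only affects libraries not yet initialized
torch.set_num_threads(1)  # takes effect immediately for torch/MKL intra-op work
print(torch.get_num_threads())  # 1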

@ydshieh
Collaborator Author

ydshieh commented Aug 12, 2022

@michaelbenayoun

I moved (almost) the whole testing logic to the child process. One more advantage here is that the model is created in the child process, so we don't need to pass it between processes.

Now, running 100 times, memory usage increases by only 0.15 MB per run.

@ydshieh ydshieh force-pushed the fix_torch_fx_test_mem_issue branch from c140d77 to 3393331 Compare August 12, 2022 11:28
@ydshieh
Collaborator Author

ydshieh commented Aug 17, 2022

@michaelbenayoun You are right, some models overwrite _create_and_check_torch_fx_tracing. This won't break this PR, however: those models will just run the test_torch_fx* tests in the current manner (i.e. not in a child process). I will check whether those overrides are necessary. In any case, we can merge this PR as it is (if you are happy with it), and I will work on those models later.

@michaelbenayoun
Member

I think it's okay now with the changes you've made!

@ydshieh
Collaborator Author

ydshieh commented Aug 22, 2022

I think it's okay now with the changes you've made!

Would love to have an approval from you, @michaelbenayoun.
But no need to rush, as long as you are ultimately happy with the change and click the button.

@ydshieh ydshieh force-pushed the fix_torch_fx_test_mem_issue branch from 3393331 to 3dfa380 Compare August 22, 2022 12:42
@ydshieh
Collaborator Author

ydshieh commented Aug 22, 2022

ready for @sgugger and/or @LysandreJik to have a final check 🚀

Member

@LysandreJik LysandreJik left a comment

Ok! Thanks for working on this, @ydshieh

@@ -138,6 +138,159 @@ def _config_zero_init(config):
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"


def _run_torch_jit(in_queue, out_queue):
import traceback
Member

Should this be at the top of the file?

Collaborator Author

It has to be outside the class (i.e. it can't be a method of the class); otherwise multiprocessing has issues pickling the object.
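A small sketch illustrating the constraint (illustrative functions, not the test code): under spawn, the target is pickled by reference, so it must be importable at module level; locally defined functions (and typically bound methods of test classes) fail to pickle.

import multiprocessing


def module_level_worker(q):
    # Picklable: resolvable as <module>.module_level_worker in the child.
    q.put("ok")


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    q = ctx.Queue()

    p = ctx.Process(target=module_level_worker, args=(q,))
    p.start()
    print(q.get(timeout=30))  # ok
    p.join()

    def local_worker(q):
        q.put("never sent")

    try:
        # Fails in the parent at start(): local functions cannot be pickled.
        ctx.Process(target=local_worker, args=(q,)).start()
    except Exception as err:
        print(type(err).__name__, err)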

Member

sounds good

tests/test_modeling_common.py (outdated review thread, resolved)
@ydshieh
Collaborator Author

ydshieh commented Aug 24, 2022

I will merge this afternoon, after adding a short comment in _create_and_check_torch_fx_tracing explaining why we need this change, with a link to #18525.

@ydshieh
Collaborator Author

ydshieh commented Aug 25, 2022

Hi @michaelbenayoun, I just saw that I fixed a similar issue a few months ago

# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.

(for _create_and_check_torchscript). I am going to change this PR to simply apply that fix. Is that OK with you?
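For context, a hedged sketch of what such a registry-clearing helper can look like (paraphrased from memory and relying on private torch APIs; the actual helper in tests/test_modeling_common.py may differ):

import torch


def clear_torch_jit_class_registry():
    # Drop TorchScript's cached class/type state so repeated tracing/scripting
    # in tests does not keep growing RAM. These are private APIs (assumption:
    # available in the torch versions used by the CI).
    torch._C._jit_clear_class_registry()
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    if hasattr(torch.jit._state, "_clear_class_state"):
        torch.jit._state._clear_class_state()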

@ydshieh ydshieh force-pushed the fix_torch_fx_test_mem_issue branch from 489e80c to 7d03df4 Compare August 29, 2022 08:24
@ydshieh
Collaborator Author

ydshieh commented Aug 29, 2022

Changed the PR to simply call clear_torch_jit_class_registry. The test failure is unrelated to this PR; merging now.

@ydshieh ydshieh changed the title Run torch_fx tests in a spawn process to avoid memory issue Fix memory leak issue in torch_fx tests Aug 29, 2022
@ydshieh ydshieh merged commit 8b67f20 into huggingface:main Aug 29, 2022
@ydshieh ydshieh deleted the fix_torch_fx_test_mem_issue branch August 29, 2022 09:43
oneraghavan pushed a commit to oneraghavan/transformers that referenced this pull request Sep 26, 2022
Co-authored-by: Lysandre Debut <hi@lysand.re>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>