
[Core][Tune] Ray tune cannot be used with pytorch-lightning 1.7.0 due to processes spawned with fork. #27493

@Alfredvc

Description

What happened + What you expected to happen

As part of the "Add support for DDP fork" change included in pytorch-lightning 1.7.0, calls to:

torch.cuda.device_count()
torch.cuda.is_available()

in the pytorch-lightning codebase were replaced with new functions:

pytorch_lightning.utilities.device_parser.num_cuda_devices()
pytorch_lightning.utilities.device_parser.is_cuda_available()

These functions internally create a multiprocessing.Pool using the "fork" start method:

with multiprocessing.get_context("fork").Pool(1) as pool:
    return pool.apply(torch.cuda.device_count)
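
For reference, the same pattern completes normally when run in a plain Python process outside of Ray. A minimal sanity check (it assumes CUDA has not already been initialized in the parent process):

import multiprocessing

import torch

# Outside of a Ray worker this returns the device count and the pool exits cleanly.
with multiprocessing.get_context("fork").Pool(1) as pool:
    print(pool.apply(torch.cuda.device_count))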

Inside a Ray actor, however, this call waits forever. The traceback below is from a Tune trial; a standalone sketch follows it.

(train pid=139, ip=172.22.0.3) 	File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
(train pid=139, ip=172.22.0.3) 		self._bootstrap_inner()
(train pid=139, ip=172.22.0.3) 	File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
(train pid=139, ip=172.22.0.3) 		self.run()
(train pid=139, ip=172.22.0.3) 	File "/usr/local/lib/python3.8/dist-packages/ray/tune/function_runner.py", line 277, in run
(train pid=139, ip=172.22.0.3) 		self._entrypoint()
(train pid=139, ip=172.22.0.3) 	File "/usr/local/lib/python3.8/dist-packages/ray/tune/function_runner.py", line 349, in entrypoint
(train pid=139, ip=172.22.0.3) 		return self._trainable_func(
(train pid=139, ip=172.22.0.3) 	File "/usr/local/lib/python3.8/dist-packages/ray/util/tracing/tracing_helper.py", line 462, in _resume_span
(train pid=139, ip=172.22.0.3) 		return method(self, *_args, **_kwargs)
(train pid=139, ip=172.22.0.3) 	File "/usr/local/lib/python3.8/dist-packages/ray/tune/function_runner.py", line 645, in _trainable_func
(train pid=139, ip=172.22.0.3) 		output = fn()
(train pid=139, ip=172.22.0.3) 	File "test.py", line 9, in train
(train pid=139, ip=172.22.0.3) 		pl.Trainer(
(train pid=139, ip=172.22.0.3) 	File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/argparse.py", line 345, in insert_env_defaults
(train pid=139, ip=172.22.0.3) 		return fn(self, **kwargs)
(train pid=139, ip=172.22.0.3) 	File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 537, in __init__
(train pid=139, ip=172.22.0.3) 		self._setup_on_init(num_sanity_val_steps)
(train pid=139, ip=172.22.0.3) 	File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 618, in _setup_on_init
(train pid=139, ip=172.22.0.3) 		self._log_device_info()
(train pid=139, ip=172.22.0.3) 	File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1739, in _log_device_info
(train pid=139, ip=172.22.0.3) 		if CUDAAccelerator.is_available():
(train pid=139, ip=172.22.0.3) 	File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/accelerators/cuda.py", line 91, in is_available
(train pid=139, ip=172.22.0.3) 		return device_parser.num_cuda_devices() > 0
(train pid=139, ip=172.22.0.3) 	File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/device_parser.py", line 346, in num_cuda_devices
(train pid=139, ip=172.22.0.3) 		return pool.apply(torch.cuda.device_count)
(train pid=139, ip=172.22.0.3) 	File "/usr/lib/python3.8/multiprocessing/pool.py", line 736, in __exit__
(train pid=139, ip=172.22.0.3) 		self.terminate()
(train pid=139, ip=172.22.0.3) 	File "/usr/lib/python3.8/multiprocessing/pool.py", line 654, in terminate
(train pid=139, ip=172.22.0.3) 		self._terminate()
(train pid=139, ip=172.22.0.3) 	File "/usr/lib/python3.8/multiprocessing/util.py", line 224, in __call__
(train pid=139, ip=172.22.0.3) 		res = self._callback(*self._args, **self._kwargs)
(train pid=139, ip=172.22.0.3) 	File "/usr/lib/python3.8/multiprocessing/pool.py", line 729, in _terminate_pool
(train pid=139, ip=172.22.0.3) 		p.join()
(train pid=139, ip=172.22.0.3) 	File "/usr/lib/python3.8/multiprocessing/process.py", line 149, in join
(train pid=139, ip=172.22.0.3) 		res = self._popen.wait(timeout)
(train pid=139, ip=172.22.0.3) 	File "/usr/lib/python3.8/multiprocessing/popen_fork.py", line 47, in wait
(train pid=139, ip=172.22.0.3) 		return self.poll(os.WNOHANG if timeout == 0.0 else 0)
(train pid=139, ip=172.22.0.3) 	File "/usr/lib/python3.8/multiprocessing/popen_fork.py", line 27, in poll
(train pid=139, ip=172.22.0.3) 		pid, sts = os.waitpid(self.pid, flag)

This is a critical breaking change: pytorch_lightning.Trainer calls these functions during construction, so it cannot be used inside a Ray Tune trial at all.

The reproduction script below always hangs. However, during my experimentation I found that creating a minimal reproduction script was difficult: sometimes a script works and then fails when re-run, and sometimes changing a seemingly unrelated line of code makes a working script fail. I haven't dug deep enough into the Ray codebase to understand why this is the case.

For my larger projects, Ray Tune simply cannot be used with pytorch-lightning 1.7.0, as these calls always hang. My current workaround is to monkeypatch torch.multiprocessing.get_all_start_methods:

import torch.multiprocessing

# Hide the "fork" start method so pytorch-lightning skips the fork-based CUDA check.
patched_start_methods = [m for m in torch.multiprocessing.get_all_start_methods() if m != "fork"]
torch.multiprocessing.get_all_start_methods = lambda: patched_start_methods
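
I apply the patch inside the trial function, before the Trainer is constructed; my assumption is that it has to run in the trial process, since the trainable executes in a separate Ray worker. A sketch based on the reproduction script below:

import pytorch_lightning as pl
import torch.multiprocessing


def train(config):
    # Monkeypatch in the trial process, before pl.Trainer probes CUDA availability.
    patched_start_methods = [
        m for m in torch.multiprocessing.get_all_start_methods() if m != "fork"
    ]
    torch.multiprocessing.get_all_start_methods = lambda: patched_start_methods

    pl.Trainer(accelerator="gpu", devices=1)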

As far as I can tell, it is known that Ray does not work with forked processes (https://discuss.ray.io/t/best-solution-to-have-multiprocess-working-in-actor/2165/8). However, given that pytorch-lightning is such a widely used library in the ML ecosystem, this issue may be worth looking into.

Versions / Dependencies

ray-tune 1.13.0
pytorch 1.12.0
pytorch-lightning 1.7.0
python 3.8.10
OS: Ubuntu 20.04.4 LTS

Reproduction script

import pytorch_lightning as pl
from ray import tune


def train(config):
    pl.Trainer(accelerator="gpu", devices=1)


def run():
    tune.run(
        train,
        resources_per_trial={"cpu": 8, "gpu": 1},
        log_to_file=["stdout.txt", "stderr.txt"], # For some reason removing this line makes the script work
        config={},
        num_samples=1,
        name="Test",
    )


if __name__ == "__main__":
    run()

Submitted to a Ray cluster with:

ray job submit --runtime-env-json='{"working_dir": "./"}' -- python test.py

Issue Severity

Medium: It is a significant difficulty but I can work around it.

Labels

P1: Issue that should be fixed within a few weeks
bug: Something that is supposed to be working; but isn't
core: Issues that should be addressed in Ray Core
tune: Tune-related issues