
FileNotFoundError for best checkpoint when using DDP with Hydra #5512

Closed
azouaoui-cv opened this issue Jan 14, 2021 · 16 comments · Fixed by #5629
Labels
3rd party (Related to a 3rd-party) · bug (Something isn't working) · distributed (Generic distributed-related topic) · help wanted (Open to be worked on)

Comments

@azouaoui-cv

azouaoui-cv commented Jan 14, 2021

🐛 Bug

I am getting a FileNotFoundError for loading the best checkpoint when using trainer.test() after trainer.fit() in DDP mode with Hydra.

My configuration file specifies that hydra.run.dir="/path/to/data/${now:%Y-%m-%d_%H-%M-%S}".
As a result, the first process (rank 0) spawns in "/path/to/data/datetime1" and creates the "ckpts" and "logs" folders there while the second process (rank 1) spawns in "/path/to/data/datetime2" and cannot access the "ckpts" and "logs" folders.
It appears that when calling trainer.test(), the program looks for "/path/to/data/datetime2/ckpts/best.ckpt" which is indeed not there.

Here is the error stack:

Epoch 4: val_acc reached 30.00000 (best 30.00000), saving model to /home/azouaoui/github/PL-Hydra-template/data/runs/2021-01-14_08-03-33/ckpts/epoch=004-val_acc=30.000.ckpt as top 1
[lightning][INFO] - Epoch 4: val_acc reached 30.00000 (best 30.00000), saving model to /home/azouaoui/github/PL-Hydra-template/data/runs/2021-01-14_08-03-33/ckpts/epoch=004-val_acc=30.000.ckpt as top 1
Epoch 4: 100%|███████████| 4/4 [00:00<00:00, 11.70it/s, loss=2.729, v_num=0, val_acc=30, best_val_acc=30Saving latest checkpoint...                                                                               
[lightning][INFO] - Saving latest checkpoint...
Epoch 4: 100%|███████████| 4/4 [00:00<00:00, 11.62it/s, loss=2.729, v_num=0, val_acc=30, best_val_acc=30]
[__main__][CRITICAL] - [Errno 2] No such file or directory: '/home/azouaoui/github/PL-Hydra-template/data/runs/2021-01-14_08-03-35/ckpts/epoch=004-val_acc=30.000.ckpt'
Traceback (most recent call last):
  File "/home/azouaoui/github/PL-Hydra-template/train.py", line 41, in main
    train(cfg)
  File "/home/azouaoui/github/PL-Hydra-template/train.py", line 34, in train
    logger.info(trainer.test())
  File "/scratch/artemis/azouaoui/miniconda3/envs/jz/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 720, in test
    results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
  File "/scratch/artemis/azouaoui/miniconda3/envs/jz/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 750, in __test_using_best_weights
    ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
  File "/scratch/artemis/azouaoui/miniconda3/envs/jz/lib/python3.7/site-packages/pytorch_lightning/utilities/cloud_io.py", line 31, in load
    with fs.open(path_or_url, "rb") as f:
  File "/scratch/artemis/azouaoui/miniconda3/envs/jz/lib/python3.7/site-packages/fsspec/spec.py", line 936, in open
    **kwargs
  File "/scratch/artemis/azouaoui/miniconda3/envs/jz/lib/python3.7/site-packages/fsspec/implementations/local.py", line 117, in _open
    return LocalFileOpener(path, mode, fs=self, **kwargs)
  File "/scratch/artemis/azouaoui/miniconda3/envs/jz/lib/python3.7/site-packages/fsspec/implementations/local.py", line 199, in __init__
    self._open()
  File "/scratch/artemis/azouaoui/miniconda3/envs/jz/lib/python3.7/site-packages/fsspec/implementations/local.py", line 204, in _open
    self.f = open(self.path, mode=self.mode)
FileNotFoundError: [Errno 2] No such file or directory: '/home/azouaoui/github/PL-Hydra-template/data/runs/2021-01-14_08-03-35/ckpts/epoch=004-val_acc=30.000.ckpt'

Please reproduce using the BoringModel

The error is triggered by using DDP with at least 2 GPUs, hence I cannot use Colab.

To Reproduce

Use this repository

Have at least 2 GPUs available.

$ git clone https://github.com/inzouzouwetrust/PL-Hydra-DDP-bug
$ cd PL-Hydra-DDP-bug && pip install -r requirements.txt
$ python bug_report_model.py

Expected behavior

I would expect the program to use the subfolder spawned by the first process (rank 0) when loading the best checkpoint.

Environment

* CUDA:
        - GPU:
                - GeForce GTX TITAN X
                - GeForce GTX TITAN X
        - available:         True
        - version:           10.2
* Packages:
        - numpy:             1.19.4
        - pyTorch_debug:     True
        - pyTorch_version:   1.7.0
        - pytorch-lightning: 1.0.5
        - tqdm:              4.54.1
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - 
        - processor:         x86_64
        - python:            3.7.9
        - version:           #219-Ubuntu SMP Tue Aug 11 12:26:50 UTC 2020

Additional context

  • For further details, please take a look at my recent chat with Hydra main author on Zulip.
  • Take a look at this PL forums topic.
@azouaoui-cv added the bug and help wanted labels on Jan 14, 2021
@github-actions
Contributor

Hi! Thanks for your contribution, great first issue!

@SeanNaren added the distributed label on Jan 14, 2021
@SeanNaren
Contributor

SeanNaren commented Jan 14, 2021

Hey @inzouzouwetrust could you reproduce via the bug_report_model I shared with you and paste it here? Will help me debug

EDIT: it's in the main issue, I missed it... I assume it needs the configs to be specified in the repo as well!

@awaelchli
Contributor

awaelchli commented Jan 14, 2021

Only had a quick glance at this issue, but it could simply be addressed by my fix here: #5155.
It could be entirely unrelated to Hydra.

@romesco
Contributor

romesco commented Jan 14, 2021

I have also run into the issue @awaelchli mentioned, independent of Hydra launching. Thanks for the fix! I'll try pulling it down to see if it makes a difference in this context 🙂

@SeanNaren
Contributor

I tried @awaelchli's fix but it doesn't seem to work. Looking at what happens, this is still related to the output run directory of Hydra (which wraps our output dir).

I had to make a few modifications for the bug to appear, @inzouzouwetrust (the Hydra config, followed by bug_report_model.py):

hydra:
  run:
    dir: "data/${now:%Y-%m-%d_%H-%M-%S}"
  sweep:
    dir: "data/${now:%Y-%m-%d_%H-%M-%S}"
    subdir: ${hydra.job.num}

checkpoint:
  monitor: "x"
  mode: "max"
  verbose: True
  save_top_k: 1
import os

import hydra
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import Dataset


class RandomDataset(Dataset):
    """
    >>> RandomDataset(size=10, length=20)  # doctest: +ELLIPSIS
    <...bug_report_model.RandomDataset object at ...>
    """

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    """
    >>> BoringModel()  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    BoringModel(
      (layer): Linear(...)
    )
    """

    def __init__(self):
        """
        Testing PL Module

        Use as follows:
        - subclass
        - modify the behavior for what you want

        class TestModel(BaseTestModel):
            def training_step(...):
                # do your own thing

        or:

        model = BaseTestModel()
        model.training_epoch_end = None

        """
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def step(self, x):
        x = self.layer(x)
        out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
        return out

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def test_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"y": loss}

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


#  NOTE: If you are using a cmd line to run your script,
#  provide the cmd line as below.
#  opt = "--max_epochs 1 --limit_train_batches 1".split(" ")
#  parser = ArgumentParser()
#  args = parser.parse_args(opt)


@hydra.main(config_path="", config_name="config")
def test_run(cfg):
    print(cfg)
    class TestModel(BoringModel):
        def on_train_epoch_start(self) -> None:
            print("override any method to prove your bug")

    # fake data
    train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    val_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    test_data = torch.utils.data.DataLoader(RandomDataset(32, 64))

    # model
    model = TestModel()
    callbacks = [ModelCheckpoint(**cfg.checkpoint)]
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
        weights_summary=None,
        callbacks=callbacks,
        accelerator="ddp",
        gpus=2,
        deterministic=True,
        benchmark=False,
    )
    trainer.fit(model, train_data, val_data)
    trainer.test(test_dataloaders=test_data)


if __name__ == "__main__":
    test_run()

The issue is that when DDP internally spins up multiple processes, Hydra creates a new output run directory for each of them, which wraps our output directory. We have some additional Hydra logic before we spin up processes, but nothing handles this:

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/accelerators/ddp_accelerator.py#L138-L141

Some potential solutions:

  • Force users to specify a static Hydra run directory. This is not preferred, as it hinders the Hydra integration and requires more manual work imo.
  • Before spinning up the other processes, we need to pass the parent Hydra run dir to the child processes. If this could be done via an environment variable, that would be favourable I think (see the sketch below), cc @omry
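
A minimal sketch of the environment-variable idea (PL_HYDRA_RUN_DIR is a hypothetical variable name, not an existing Lightning or Hydra setting):

import os

# In the parent process (rank 0), just before spinning up the DDP children:
# Hydra has already chdir'ed into the run dir, so os.getcwd() is the resolved path.
os.environ["PL_HYDRA_RUN_DIR"] = os.getcwd()

# In each child process, early in its entry point:
parent_run_dir = os.environ.get("PL_HYDRA_RUN_DIR")
if parent_run_dir is not None:
    # Reuse the parent's run dir instead of the one Hydra re-interpolates
    # (a few seconds later) for this process.
    os.chdir(parent_run_dir)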

@tchaton
Contributor

tchaton commented Jan 15, 2021

Hey @romesco,

I experienced this bug before. Here is my hacky solution: remove -%S from the run dir so that all processes resolve the same directory and you can keep training.

hydra:
  run:
    dir: "data/${now:%Y-%m-%d_%H-%M}"
  sweep:
    dir: "data/${now:%Y-%m-%d_%H-%M}"
    subdir: ${hydra.job.num}

checkpoint:
  monitor: "x"
  mode: "max"
  verbose: True
  save_top_k: 1

From @SeanNaren's comment: "If this could be done via an environment variable, this would be favourable I think." Yes, definitely. Each child process should also clean up its own auto-generated Hydra path and reset its working directory to the new path.

Best,
T.C

@omry
Contributor

omry commented Jan 15, 2021

Before spinning up the other processes, we need to pass the parent hydra run dir to the child processes. If this could be done via an environment variable, this would be favourable I think, cc @omry

Hydra in general does not take environment variables to configure user options.
A user can configure their own config to achieve that; for example, the following would use the HYDRA_RUN_DIR environment variable if present, or a_default_value otherwise.

hydra:
  run:
    dir: ${env:HYDRA_RUN_DIR,a_default_value}

Maybe a better option is to allow users to pass additional command line arguments to the DDP processes.
Unfortunately it does not seem like there is a mechanism to configure the DDP accelerators right now (from a glance at the repro code), but appending hydra.run.dir=os.getcwd() to the command line being passed to each of the workers should likely solve the problem.

Each child process should also clean their own auto-generated path by hydra and reset their os.getcwd to new path.

The idea behind overriding hydra.run.dir to the proper one is that such cleanup or chdir would not be necessary.

@romesco
Contributor

romesco commented Jan 16, 2021

Alright, applying @awaelchli's fix #5155 along with the configuration found here:
https://github.com/romesco/hydra-lightning/blob/examples/ddp/examples/hydra_pl_bug_report_template.py

is working for me. This is with:

hydra-core==1.1.0.dev3
torch==1.7.1
git+https://github.com/pytorch/hydra-torch
git+https://github.com/romesco/hydra-lightning/#subdirectory=hydra-configs-lightning
pytorch-lightning@bugfix/ddp-ckpt

@Borda added the 3rd party label on Jan 17, 2021
@SeanNaren
Contributor

Hey @romesco, I think it works because your run dir is set to:

run_dir = "${now:%Y-%m-%d}"

If you were to add finer granularity (e.g. add -%S back), it would probably break. This is because the processes are spun up internally and each one creates its own run dir, interpolating the string at run time (so they are all out of sync by a few seconds).

I don't like having to force users to remove %S from their run dir because we should strive to have Hydra work with little friction. @omry I don't think manually appending hydra.run.dir=os.getcwd() would fix the issue, as this would override the user's specified run directory, right?

I think in the perfect world, we'd manually pass the interpolated run directory from the main process to the child processes, but this might be more complicated...

@omry
Contributor

omry commented Jan 22, 2021

I don't like having to force users to remove %s from their run dir because we should strive to have Hydra work with little friction. @omry I don't think manually appending hydra.run.dir=os.getcwd() would fix the issue as this would override the user's specified run directory right?

I am proposing to do it when spawning the DDP processes.
Basically, at that point os.getcwd() should point to the actual output directory generated by Hydra.
I will probably provide an API to access the generated output directory in the future, but it's not there yet and os.getcwd() is pretty close.

  • User starts the app.
  • Hydra generates the output dir and chdirs into it.
  • The user function runs with the output directory as the cwd.
  • The user calls PL to spawn DDP; the cwd is still the Hydra output directory.
  • PL appends hydra.run.dir=os.getcwd() to the child processes' command line, ensuring that they share the same working directory as the parent process (see the sketch below).
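
A minimal sketch of that last step, assuming the child command is rebuilt from sys.argv (the helper name is hypothetical; see #5629 for the actual fix):

import os
import sys

def build_ddp_child_command():
    """Sketch: the command line for a DDP child process, with the parent's
    already-resolved Hydra run dir pinned via a command-line override."""
    command = [sys.executable, os.path.abspath(sys.argv[0])] + sys.argv[1:]
    # cwd is still the Hydra-generated output dir of the parent process,
    # so every child resolves the same run dir instead of a fresh timestamp.
    command += [f"hydra.run.dir={os.getcwd()}"]
    return command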

Side issues:
You can also disable the .hydra directory and the logging configuration for the child processes.
See facebookresearch/hydra#910 for workarounds.

@SeanNaren
Contributor

Nice! Apologies, just a lack of understanding on my part. Testing locally, it seems to work fine (not sure what's going to happen in the logs with multiple processes writing to them, but at least there is one folder containing everything).

Will make a PR!

@SeanNaren
Contributor

@romesco @inzouzouwetrust could you check out #5629 and see if this fixes it for you guys?

@azouaoui-cv
Author

I'm late to the party but I can confirm that #5629 fixes it for me :)
Thank you so much!

@Vozf
Contributor

Vozf commented Nov 23, 2021

Well, I'm getting this error now with PL 1.5.2 and Hydra 1.0. It can't actually find the train script and freezes.

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
.../.venv/bin/python: can't open file '.../train': [Errno 2] No such file or directory
Global seed set to 42
initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2

@Vozf
Contributor

Vozf commented Nov 23, 2021

For me the error was related to Poetry, as training runs with the command

poetry run train

and replacing it with

poetry run train.py

fixes the error

@emanuelevivoli

I have the same error.
I'm using DDP on a two-GPU server and I get the following error:

Traceback (most recent call last):                                                                                                                                                                                                                                                                                                   
  File "/equilibrium/evivoli/asmara/src/models/train_model.py", line 125, in <module>                                                                                                                                                                                                                                                
    train()                                                                                                                                                                                                                                                                                                                          
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/hydra/main.py", line 94, in decorated_main                                                                                                                                                                                                                 
    _run_hydra(                                                                                                                                                                                                                                                                                                                      
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra                                                                                                                                                                                                         
    _run_app(                                                                                                                                                                                                                                                                                                                        
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/hydra/_internal/utils.py", line 457, in _run_app                                                                                                                                                                                                           
    run_and_report(                                                                                                                                                                                                                                                                                                                  
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/hydra/_internal/utils.py", line 223, in run_and_report                                                                                                                                                                                                     
    raise ex                                                                                                                                                                                                                                                                                                                         
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/hydra/_internal/utils.py", line 220, in run_and_report                                                                                                                                                                                                     
    return func()
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "/equilibrium/evivoli/asmara/src/models/train_model.py", line 119, in train
    model = model.load_from_checkpoint(best_checkpoint_path)
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/pytorch_lightning/core/saving.py", line 139, in load_from_checkpoint
    return _load_from_checkpoint(
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/pytorch_lightning/core/saving.py", line 160, in _load_from_checkpoint
    checkpoint = pl_load(checkpoint_path, map_location=map_location)
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/lightning_fabric/utilities/cloud_io.py", line 47, in _load
    with fs.open(path_or_url, "rb") as f:
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/fsspec/spec.py", line 1135, in open
    f = self._open(
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/fsspec/implementations/local.py", line 183, in _open
    return LocalFileOpener(path, mode, fs=self, **kwargs)
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/fsspec/implementations/local.py", line 285, in __init__
    self._open()
  File "/home/evivoli/miniconda3/envs/new-env/lib/python3.9/site-packages/fsspec/implementations/local.py", line 290, in _open
    self.f = open(self.path, mode=self.mode)
FileNotFoundError: [Errno 2] No such file or directory: '/equilibrium/evivoli/asmara/.checkpoints/unet/multi/0-holograms/epoch=00-val_loss=2.66-val_acc=0.08.ckpt'

It seems that the file does not exist when I try to load it, and in fact it is created just after the program crashes.

I don't know if it is due to Hydra or not. However, I'm debugging the code (the problematic part is the following):

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=f'{BASEPATH}/.checkpoints/{cfg.model.name}/{cfg.data.task}/{cfg.seed}-{cfg.data.dataset}/',
    filename='{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}',
    save_top_k=3,
    mode='min'
)

# populate the trainer with configurations from the config
trainer = pl.Trainer(
    **cfg.trainer, 
    # use DDPStrategy(find_unused_parameters=False) only when strategy is 'ddp'
    # and find_unused_parameters is False,
    strategy= DDPStrategy(find_unused_parameters=False) 
        if cfg.custom_trainer.strategy == 'ddp' and 
            cfg.custom_trainer.find_unused_parameters == False 
        else 'ddp',
    logger = wandb_logger,
    callbacks=[early_stop_callback, checkpoint_callback],
)

trainer.fit(model, train_loader, val_loader)

# Load the best checkpoint
best_checkpoint_path = trainer.checkpoint_callback.best_model_path
print(f"Loading best checkpoint from {best_checkpoint_path}")
model = model.load_from_checkpoint(best_checkpoint_path)

While debugging, when one subprocess reaches the _load function inside cloud_io.py, the file doesn't exist yet because another process is still inside the function writing it. How can I wait for the checkpoint to be written before trying to load the weights?
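
One possible approach (a sketch reusing the names from the snippet above; it assumes the default torch.distributed process group is initialized, which is the case under the DDP strategy) is to add a barrier between fit() and the load, so every rank waits until rank 0 has finished writing the checkpoint:

import torch.distributed as dist

trainer.fit(model, train_loader, val_loader)

# Wait until rank 0 has finished saving the best checkpoint before any
# rank tries to read it from disk.
if dist.is_available() and dist.is_initialized():
    dist.barrier()

best_checkpoint_path = trainer.checkpoint_callback.best_model_path
model = model.load_from_checkpoint(best_checkpoint_path)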
