
[BUG] ValueError: 'cuda' is not a valid DistributedType #802

Closed
gsamaras opened this issue Feb 16, 2022 · 7 comments
Labels
bug Something isn't working triage Issue waiting for triaging

Comments

@gsamaras
Contributor

gsamaras commented Feb 16, 2022

Describe the bug
After relaunching my notebook, I was suddenly unable to train N-BEATS on the GPU and got ValueError: 'cuda' is not a valid DistributedType, even though I had not changed anything in the code.

To Reproduce
Install darts in a Jupyter notebook like this:

!pip install 'u8darts[torch]'

and then try to train any model on the GPU, e.g. an N-BEATS model like this:

import torch.optim as optim
from darts.models import NBEATSModel

model_nbeats = NBEATSModel(
    input_chunk_length=2,
    output_chunk_length=1,
    generic_architecture=True,
    num_stacks=2,
    num_blocks=1,
    num_layers=1,
    layer_widths=2,
    n_epochs=20,
    nr_epochs_val_period=1,
    batch_size=2,
    random_state=0,
    optimizer_cls=optim.Adam,
    optimizer_kwargs={"lr": 1e-3},
    lr_scheduler_cls=optim.lr_scheduler.ReduceLROnPlateau,
    lr_scheduler_kwargs={"optimizer": optim.Adam, "threshold": 0.0001, "verbose": True},
    torch_device_str="cuda:0",
)

which gives the error:

ValueError                                Traceback (most recent call last)
Input In [6], in <module>
    167 train, val = series.split_after(trainset_size)
    169 model_nbeats = define_NBEATS_model(train_set=train, val_set=val, gridsearch=False)
--> 170 model_nbeats.fit(series=train, val_series=val, verbose=True)
    172 pred_series = model_nbeats.historical_forecasts(
    173     series,
    174     start=trainset_size,
   (...)
    178     verbose=True,
    179 )
    180 display_forecast(pred_series, series, "1 horizon", start_date=trainset_size)

File /opt/conda/lib/python3.9/site-packages/darts/utils/torch.py:70, in random_method.<locals>.decorator(self, *args, **kwargs)
     68 with fork_rng():
     69     manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
---> 70     return decorated(self, *args, **kwargs)

File /opt/conda/lib/python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py:726, in TorchForecastingModel.fit(self, series, past_covariates, future_covariates, val_series, val_past_covariates, val_future_covariates, trainer, verbose, epochs, max_samples_per_ts, num_loader_workers)
    722     val_dataset = None
    724 logger.info(f"Train dataset contains {len(train_dataset)} samples.")
--> 726 return self.fit_from_dataset(
    727     train_dataset, val_dataset, trainer, verbose, epochs, num_loader_workers
    728 )

File /opt/conda/lib/python3.9/site-packages/darts/utils/torch.py:70, in random_method.<locals>.decorator(self, *args, **kwargs)
     68 with fork_rng():
     69     manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
---> 70     return decorated(self, *args, **kwargs)

File /opt/conda/lib/python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py:864, in TorchForecastingModel.fit_from_dataset(self, train_dataset, val_dataset, trainer, verbose, epochs, num_loader_workers)
    861 verbose = True if verbose is None else verbose
    863 # setup trainer
--> 864 self._setup_trainer(trainer, verbose, train_num_epochs)
    866 # TODO: multiple training without loading from checkpoint is not trivial (I believe PyTorch-Lightning is still
    867 #  working on that, see https://github.com/PyTorchLightning/pytorch-lightning/issues/9636)
    868 if self.epochs_trained > 0 and not self.load_ckpt_path:

File /opt/conda/lib/python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py:486, in TorchForecastingModel._setup_trainer(self, trainer, verbose, epochs)
    480 self.trainer_params["enable_model_summary"] = (
    481     verbose if self.model.epochs_trained == 0 else False
    482 )
    483 self.trainer_params["enable_progress_bar"] = verbose
    485 self.trainer = (
--> 486     self._init_trainer(trainer_params=self.trainer_params, max_epochs=epochs)
    487     if trainer is None
    488     else trainer
    489 )

File /opt/conda/lib/python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py:500, in TorchForecastingModel._init_trainer(trainer_params, max_epochs)
    497 if max_epochs is not None:
    498     trainer_params_copy["max_epochs"] = max_epochs
--> 500 return pl.Trainer(**trainer_params_copy)

File /opt/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/env_vars_connector.py:38, in _defaults_from_env_vars.<locals>.insert_env_defaults(self, *args, **kwargs)
     35 kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
     37 # all args were already moved to kwargs
---> 38 return fn(self, **kwargs)

File /opt/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:431, in Trainer.__init__(self, logger, checkpoint_callback, enable_checkpointing, callbacks, default_root_dir, gradient_clip_val, gradient_clip_algorithm, process_position, num_nodes, num_processes, devices, gpus, auto_select_gpus, tpu_cores, ipus, log_gpu_memory, progress_bar_refresh_rate, enable_progress_bar, overfit_batches, track_grad_norm, check_val_every_n_epoch, fast_dev_run, accumulate_grad_batches, max_epochs, min_epochs, max_steps, min_steps, max_time, limit_train_batches, limit_val_batches, limit_test_batches, limit_predict_batches, val_check_interval, flush_logs_every_n_steps, log_every_n_steps, accelerator, strategy, sync_batchnorm, precision, enable_model_summary, weights_summary, weights_save_path, num_sanity_val_steps, resume_from_checkpoint, profiler, benchmark, deterministic, reload_dataloaders_every_n_epochs, reload_dataloaders_every_epoch, auto_lr_find, replace_sampler_ddp, detect_anomaly, auto_scale_batch_size, prepare_data_per_node, plugins, amp_backend, amp_level, move_metrics_to_cpu, multiple_trainloader_mode, stochastic_weight_avg, terminate_on_nan)
    428 # init connectors
    429 self._data_connector = DataConnector(self, multiple_trainloader_mode)
--> 431 self._accelerator_connector = AcceleratorConnector(
    432     num_processes,
    433     devices,
    434     tpu_cores,
    435     ipus,
    436     accelerator,
    437     strategy,
    438     gpus,
    439     gpu_ids,
    440     num_nodes,
    441     sync_batchnorm,
    442     benchmark,
    443     replace_sampler_ddp,
    444     deterministic,
    445     precision,
    446     amp_backend,
    447     amp_level,
    448     plugins,
    449 )
    450 self.logger_connector = LoggerConnector(self, log_gpu_memory)
    451 self._callback_connector = CallbackConnector(self)

File /opt/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:166, in AcceleratorConnector.__init__(self, num_processes, devices, tpu_cores, ipus, accelerator, strategy, gpus, gpu_ids, num_nodes, sync_batchnorm, benchmark, replace_sampler_ddp, deterministic, precision, amp_type, amp_level, plugins)
    164     self._set_training_type_plugin()
    165 else:
--> 166     self.set_distributed_mode()
    168 self.handle_given_plugins()
    169 self._set_distrib_type_if_training_type_plugin_passed()

File /opt/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:882, in AcceleratorConnector.set_distributed_mode(self, strategy)
    880     self._device_type = DeviceType.IPU
    881 elif self.distributed_backend and self._distrib_type is None:
--> 882     self._distrib_type = DistributedType(self.distributed_backend)
    884 if self.num_gpus > 0 and not _use_cpu:
    885     self._device_type = DeviceType.GPU

File /opt/conda/lib/python3.9/enum.py:384, in EnumMeta.__call__(cls, value, names, module, qualname, type, start)
    359 """
    360 Either returns an existing member, or creates a new enum class.
    361 
   (...)
    381 `type`, if set, will be mixed in as the first base class.
    382 """
    383 if names is None:  # simple value lookup
--> 384     return cls.__new__(cls, value)
    385 # otherwise, functional API: we're creating a new Enum type
    386 return cls._create_(
    387         value,
    388         names,
   (...)
    392         start=start,
    393         )

File /opt/conda/lib/python3.9/enum.py:702, in Enum.__new__(cls, value)
    700 ve_exc = ValueError("%r is not a valid %s" % (value, cls.__qualname__))
    701 if result is None and exc is None:
--> 702     raise ve_exc
    703 elif exc is None:
    704     exc = TypeError(
    705             'error in %s._missing_: returned %r instead of None or a valid member'
    706             % (cls.__name__, result)
    707             )

ValueError: 'cuda' is not a valid DistributedType
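The last frames of the traceback show where the error comes from: Lightning passes the backend string to the DistributedType enum constructor, and Python's Enum raises ValueError for any value that is not one of its members. The stdlib-only sketch below reproduces that lookup with a simplified stand-in enum; its member list is illustrative only and is not Lightning's actual definition.

```python
from enum import Enum


class DistributedType(str, Enum):
    # Simplified stand-in: a few strategy names only, not Lightning's full list.
    DP = "dp"
    DDP = "ddp"
    DDP_SPAWN = "ddp_spawn"


# A valid strategy name resolves to the matching enum member ...
assert DistributedType("ddp") is DistributedType.DDP

# ... but a device name is not a strategy name, so the value lookup fails.
try:
    DistributedType("cuda")
except ValueError as err:
    print(err)  # 'cuda' is not a valid DistributedType
```

In other words, the string 'cuda' reached a code path that expected a distributed strategy name, not a device name.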

My instance has a GPU:

import torch
torch.cuda.is_available()
torch.cuda.get_device_name(0)  # outputs 'Quadro M4000'

Expected behavior
Training on the GPU should work.

System information:

  • Python version: 3.9.7
  • darts version: 0.17.0
  • pytorch-lightning: 1.5.10
  • torch: 1.10.2

Could this be caused by a change in the Torch dependencies?

Additional context
Related: #801

@gsamaras gsamaras added bug Something isn't working triage Issue waiting for triaging labels Feb 16, 2022
@dennisbader
Collaborator

dennisbader commented Feb 16, 2022

We released darts version 0.17.0 yesterday.
Our TorchForecastingModels are now built on top of PyTorch Lightning.

Passing torch_device_str should have raised a DeprecationWarning.
Instead, the device should now be set at model creation through pl_trainer_kwargs (a dict of PyTorch Lightning Trainer parameters, see here).

Could you try the code below instead and let us know whether it works?
For further information about setting the device, see:

import torch.optim as optim
from darts.models import NBEATSModel

model_nbeats = NBEATSModel(
    input_chunk_length=2,
    output_chunk_length=1,
    generic_architecture=True,
    num_stacks=2,
    num_blocks=1,
    num_layers=1,
    layer_widths=2,
    n_epochs=20,
    nr_epochs_val_period=1,
    batch_size=2,
    random_state=0,
    optimizer_cls=optim.Adam,
    optimizer_kwargs={"lr": 1e-3},
    lr_scheduler_cls=optim.lr_scheduler.ReduceLROnPlateau,
    lr_scheduler_kwargs={"optimizer": optim.Adam, "threshold": 0.0001, "verbose": True},
    # torch_device_str="cuda:0",
    pl_trainer_kwargs={
        "accelerator": "gpu",
        "gpus": [0]
    }
)
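For notebooks that still pass torch_device_str, the translation to pl_trainer_kwargs suggested here can be sketched as a small helper. device_str_to_trainer_kwargs is a hypothetical name, not part of darts; it only covers the "cpu", "cuda" and "cuda:N" forms and assumes the PyTorch Lightning 1.5 convention that gpus takes either a device count or a list of device indices.

```python
def device_str_to_trainer_kwargs(device_str):
    """Hypothetical helper: map an old-style torch_device_str to pl_trainer_kwargs."""
    # "cuda:0" -> ("cuda", ":", "0"); "cpu" -> ("cpu", "", "")
    device, _, index = device_str.partition(":")
    if device == "cpu":
        return {"accelerator": "cpu"}
    if device == "cuda":
        # "cuda:N" pins GPU index N; a bare "cuda" asks for one GPU.
        gpus = [int(index)] if index else 1
        return {"accelerator": "gpu", "gpus": gpus}
    raise ValueError(f"unsupported device string: {device_str!r}")
```

With this helper, device_str_to_trainer_kwargs("cuda:0") returns {"accelerator": "gpu", "gpus": [0]}, the same dict passed as pl_trainer_kwargs in the snippet above.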

@gsamaras
Contributor Author

gsamaras commented Feb 16, 2022

@dennisbader, thanks for the prompt reply. This got the GPU working, but something else seems to have broken with the switch to PyTorch Lightning, and I now get this error:

MisconfigurationException: `configure_optimizers` must include a monitor when a `ReduceLROnPlateau` scheduler is used. For example: {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}

I tried passing "monitor": "val_loss" in either the optimizer kwargs or the LR scheduler kwargs, but neither solved the issue. Any ideas?

@dennisbader
Collaborator

Hey @gsamaras, and thanks for reporting this.

This is indeed a bug that occurs when using ReduceLROnPlateau (see Lightning-AI/pytorch-lightning#4454).

We will fix this soon. For now, you can either use the model without ReduceLROnPlateau or downgrade darts to version 0.16.1.
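Lightning insists on a monitor key because ReduceLROnPlateau, unlike epoch-based schedulers, decides when to lower the learning rate from a metric value it receives at every step. The pure-Python sketch below illustrates that rule under simplifying assumptions (mode="min", a relative threshold, no cooldown or min_lr); PlateauSketch is a hypothetical class, not torch's implementation.

```python
class PlateauSketch:
    """Simplified plateau rule: mode='min', relative threshold, no cooldown/min_lr."""

    def __init__(self, lr, factor=0.1, patience=10, threshold=1e-4):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, metric):
        # An improvement must beat the best value so far by a relative margin.
        if metric < self.best * (1 - self.threshold):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # After more than `patience` epochs without improvement, shrink the LR.
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr
```

Feeding it the validation losses [1.0, 0.9, 0.9, 0.9, 0.9] with lr=1e-3 and patience=2 keeps the learning rate at 1e-3 for the first four steps and lowers it to roughly 1e-4 on the fifth. Without a monitored metric there is nothing to pass to step(), which is what the Lightning error is guarding against.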

@dennisbader
Collaborator

No, that's fine, I can do it.

Thanks again!

@dennisbader
Collaborator

Darts 0.17.1 has been released; it fixes both the torch_device_str issue and the ReduceLROnPlateau bug.

@gsamaras
Contributor Author

gsamaras commented Feb 17, 2022

@dennisbader, indeed, I was able to get this working. I also checked that the documentation was updated, thanks!

May I ask whether I'll be able to simply use a TPU like this:

pl_trainer_kwargs={
    "accelerator": "tpu",
    "tpus": [0]
}

or is this something that darts won't handle as seamlessly as the GPU case? To be honest, I don't know whether TPUs can work with local data (i.e., data that does not live in Google Cloud).


PS: After upgrading darts to 0.17.1, historical_forecasts() takes significantly longer (40 minutes for fewer than 3,500 data points), whereas with darts 0.16.1 it took just a few minutes. I'll investigate further and open a new issue if needed.

@hrzn
Contributor

hrzn commented Feb 18, 2022

> @dennisbader indeed I was able to have this working. I also checked that the documentation was updated, thanks!
>
> May I ask if I'll be able to simply use a TPU like:
>
> pl_trainer_kwargs={
>     "accelerator": "tpu",
>     "tpus": [0]
> }

I think that should work, but if you are on Colab, PyTorch Lightning (which Darts relies on) requires an extra step to make TPUs work: https://pytorch-lightning.readthedocs.io/en/stable/advanced/tpu.html#colab-tpus

I'll close this issue for now as the GPU issue is solved. Don't hesitate to open a new one if you spot other issues.
