Skip to content

Commit

Permalink
removed checkpoint from prediction (#821)
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisbader authored Feb 27, 2022
1 parent cbe49bf commit 0d267e0
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,13 +1223,8 @@ def predict_from_dataset(
trainer = trainer if trainer is not None else self.trainer
self._setup_trainer(trainer=trainer, verbose=verbose, epochs=self.n_epochs)

# if model checkpoint was loaded without calling fit afterwards (when `load_ckpt_path is not None`),
# trainer needs to be instantiated here
ckpt_path = self.load_ckpt_path
self.load_ckpt_path = None

# prediction output comes as nested list: list of predicted `TimeSeries` for each batch.
predictions = self.trainer.predict(self.model, pred_loader, ckpt_path=ckpt_path)
predictions = self.trainer.predict(self.model, pred_loader)
# flatten and return
return [ts for batch in predictions for ts in batch]

Expand Down Expand Up @@ -1316,6 +1311,16 @@ def save_model(self, path: str) -> None:
def load_model(path: str) -> "TorchForecastingModel":
"""loads a model from a given file path. The file name should end with '.pth.tar'
Example for loading a :class:`RNNModel`:
.. highlight:: python
.. code-block:: python
from darts.models import RNNModel
model_loaded = RNNModel.load_model("my_model.pth.tar")
..
Parameters
----------
path
Expand All @@ -1339,8 +1344,20 @@ def load_from_checkpoint(
"""
Load the model from automatically saved checkpoints under '{work_dir}/darts_logs/{model_name}/checkpoints/'.
This method is used for models that were created with ``save_checkpoints=True``.
If you manually saved your model, consider using :meth:`load_model() <TorchForeCastingModel.load_model()>`.
Example for loading a :class:`RNNModel` from checkpoint (``model_name`` is the ``model_name`` used at model
creation):
.. highlight:: python
.. code-block:: python
from darts.models import RNNModel
model_loaded = RNNModel.load_from_checkpoint(model_name, best=True)
..
If ``file_name`` is given, returns the model saved under
'{work_dir}/darts_logs/{model_name}/checkpoints/{file_name}'.
Expand Down Expand Up @@ -1387,7 +1404,7 @@ def load_from_checkpoint(
file_name = _get_checkpoint_fname(work_dir, model_name, best=best)

file_path = os.path.join(checkpoint_dir, file_name)
logger.info("loading {}".format(file_name))
logger.info(f"loading {file_name}")

model.model = model.model.__class__.load_from_checkpoint(file_path)
model.load_ckpt_path = file_path
Expand Down

0 comments on commit 0d267e0

Please sign in to comment.