Bug description

My custom subclass of LightningModule accepts a tensor hyperparameter in its __init__ method (context: I am doing multi-task learning, and I pass a tensor of weights for combining per-task losses).

I use CSVLogger, and in my module's __init__ I call self.save_hyperparameters(logger=True). As a result, the hyperparameters are saved to hparams.yaml, and in particular the weight tensor is serialized with a Python-specific YAML tag (python/object/apply:torch._utils._rebuild_tensor_v2) rather than as plain YAML data.
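For reference, a minimal sketch of the setup described above (the class and argument names are placeholders, not the actual code from this report):

```python
import torch
from lightning.pytorch import LightningModule


class MultiTaskModule(LightningModule):
    def __init__(self, task_weights: torch.Tensor):
        super().__init__()
        # Stores task_weights in self.hparams; with logger=True the attached
        # logger (CSVLogger here) also writes it out to hparams.yaml.
        self.save_hyperparameters(logger=True)
        self.task_weights = task_weights
```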
The module's load_from_checkpoint method (which I didn't override) can't load these weights from hparams.yaml when I explicitly pass hparams_file="hparams.yaml"; it fails with the following exception:
ConstructorError: could not determine a constructor for the tag 'tag:yaml.org,2002:python/object/apply:torch._utils._rebuild_tensor_v2'
in "hparams.yaml", line 4, column 20
I have found that replacing yaml.full_load with yaml.unsafe_load in lightning.pytorch.core.saving.load_hparams_from_yaml (pytorch-lightning/src/lightning/pytorch/core/saving.py, line 307 at 48c39ce) fixes this issue. I am experiencing this in v2.1.3 and haven't tried updating to the latest version, but since the problem is caused by the safe YAML deserialization, I believe it persists in the latest version as well.
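A minimal sketch of the difference in loader behaviour, assuming an hparams.yaml produced by the setup above:

```python
import yaml

with open("hparams.yaml") as fh:
    text = fh.read()

# FullLoader refuses Python-specific tags such as
# python/object/apply:torch._utils._rebuild_tensor_v2 and raises
# yaml.constructor.ConstructorError, which is what load_hparams_from_yaml hits:
# yaml.full_load(text)

# UnsafeLoader is allowed to call arbitrary Python constructors, so it can
# rebuild the tensor (at the cost of executing code named in the file).
hparams = yaml.unsafe_load(text)
```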
What version are you seeing the problem on?
v2.1
How to reproduce the bug
No response
Error messages and logs
No response
Environment
No response
More info
No response
Using the hparams file to load the checkpoint is a bit of a special use case from the past that I don't recommend using in general. Since you are already using save_hyperparameters(), Lightning has saved your hyperparameters into the checkpoint itself, so you don't need to pass hparams_file=... to .load_from_checkpoint(). Just drop it and it will work fine. Can you try?
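In other words, something like this should be enough (checkpoint path and class name are placeholders):

```python
# The hyperparameters recorded by save_hyperparameters() live inside the
# checkpoint, so no separate hparams_file is needed.
model = MultiTaskModule.load_from_checkpoint("path/to/checkpoint.ckpt")
```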
As for the "unsafe" loading, thanks for investigating this. Using safe loading was a deliberate choice: #11099. I think we should keep it that way.