
load_from_checkpoint with hparams_file pointing to .yaml can't load tensor hyperparameter #19559

Closed
sergey-protserov-uhn opened this issue Mar 2, 2024 · 2 comments
Labels
bug Something isn't working ver: 2.1.x

Comments

@sergey-protserov-uhn

Bug description

My custom subclass of LightningModule accepts a tensor hyperparameter in its __init__ method (context: I am doing multi-task learning, and I pass a tensor of weights for combining per-task losses).

I use CSVLogger, and in my module's __init__ I call self.save_hyperparameters(logger=True). As a result, the hyperparameters are saved to hparams.yaml, and in particular this weight tensor is serialized like this:

task_loss_weights: !!python/object/apply:torch._utils._rebuild_tensor_v2
- !!python/object/apply:torch.storage._load_from_bytes
  - !!binary |
    gAKKCmz8nEb5IGqoUBkugAJN6QMugAJ9cQAoWBAAAABwcm90b2NvbF92ZXJzaW9ucQFN6QNYDQAA
    AGxpdHRsZV9lbmRpYW5xAohYCgAAAHR5cGVfc2l6ZXNxA31xBChYBQAAAHNob3J0cQVLAlgDAAAA
    aW50cQZLBFgEAAAAbG9uZ3EHSwR1dS6AAihYBwAAAHN0b3JhZ2VxAGN0b3JjaApGbG9hdFN0b3Jh
    Z2UKcQFYDgAAADk0MDk2NTkwNTM0MzM2cQJYAwAAAGNwdXEDSwNOdHEEUS6AAl1xAFgOAAAAOTQw
    OTY1OTA1MzQzMzZxAWEuAwAAAAAAAACrqqo+q6qqPquqqj4=
- 0
- !!python/tuple
  - 3
- !!python/tuple
  - 1
- false
- !!python/object/apply:collections.OrderedDict
  - []
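
For context, a minimal sketch of the setup described above (the class name, layer sizes, and weight values are placeholders, not the exact code from this report):

import torch
import lightning.pytorch as pl


class MultiTaskModule(pl.LightningModule):
    def __init__(self, task_loss_weights: torch.Tensor):
        super().__init__()
        # Records task_loss_weights in self.hparams; with CSVLogger attached,
        # it also ends up in hparams.yaml as the pickled-tensor entry shown above.
        self.save_hyperparameters(logger=True)
        self.head = torch.nn.Linear(8, task_loss_weights.numel())


model = MultiTaskModule(task_loss_weights=torch.tensor([1 / 3, 1 / 3, 1 / 3]))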

The module's load_from_checkpoint method (which I didn't override) can't load these weights from hparams.yaml when I explicitly pass hparams_file=hparams.yaml; it fails with the following exception:

ConstructorError: could not determine a constructor for the tag 'tag:yaml.org,2002:python/object/apply:torch._utils._rebuild_tensor_v2'
  in "hparams.yaml", line 4, column 20

I found that replacing yaml.full_load with yaml.unsafe_load in lightning.pytorch.core.saving.load_hparams_from_yaml, i.e. the line

hparams = yaml.full_load(fp)

fixes this issue. I am experiencing this on v2.1.3 and haven't tried updating to the latest version, but since the issue is caused by using safe YAML deserialization, I believe it should persist in the latest version.
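
The difference can be reproduced with PyYAML alone: full_load (FullLoader) refuses to construct arbitrary Python objects such as the python/object/apply entries above, while unsafe_load does construct them (sketch, assuming the hparams.yaml shown earlier is in the working directory):

import yaml

with open("hparams.yaml") as fp:
    yaml.full_load(fp)  # raises yaml.constructor.ConstructorError on the tensor entry

with open("hparams.yaml") as fp:
    hparams = yaml.unsafe_load(fp)  # rebuilds the tensor, but may execute arbitrary code from the file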

What version are you seeing the problem on?

v2.1

How to reproduce the bug

No response

Error messages and logs

No response

Environment

No response

More info

No response

@sergey-protserov-uhn added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels Mar 2, 2024
@awaelchli
Contributor

Hey @sergey-protserov-uhn

Using the hparams file to load the checkpoint is a bit of a special use case from the past that I don't recommend using in general. Since you are already using save_hyperparameters(), Lightning has saved your hyperparameters into the checkpoint, so you don't need to pass .load_from_checkpoint(hparams_file=...). Just drop that argument and it will work fine. Can you try?
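
A minimal sketch of what that looks like (class and path names are placeholders):

model = MultiTaskModule.load_from_checkpoint("path/to/checkpoint.ckpt")
# task_loss_weights is restored from the hyperparameters stored inside the checkpoint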

As for the "unsafe" loading, thanks for investigating this. Safe loading was chosen deliberately: #11099. I think we should keep it that way.

@awaelchli removed the needs triage (Waiting to be triaged by maintainers) label Mar 3, 2024
@sergey-protserov-uhn
Author

Thank you for your answer!

Yes, it works perfectly without specifying hparams_file, thank you for the hint!

Since using this argument is generally discouraged and everything I need works without it anyway, I will close this issue.
