
Updates save/load behavior for breaking change in torch 2.6 #317

Open
billbrod wants to merge 37 commits into main

Conversation

billbrod
Member

@billbrod billbrod commented Feb 27, 2025

As discussed in the related issue (#313), Torch 2.6 introduced a breaking change in how torch.load works: weights_only now defaults to True, which makes loading much more conservative about what it will unpickle. This PR makes plenoptic compatible with that change.

Our previous behavior was incompatible with the changes because we were saving two types of objects:

  • python functions (metamer's loss function, mad competition's metrics)
  • pytorch optimization objects (optimizers and schedulers)

We also save a bunch of python primitives (floats, ints, strings) and tensors, but these are fine.

Now, we have modified save() so that object attributes are placed into three categories (a rough sketch illustrating this follows the list):

  • save_attrs: primitives and tensors that can be saved directly. These are not explicitly set, but are all the attributes that aren't included in the next two categories.
  • save_io_attrs: functions or callable objects that accept and return tensors (e.g., loss functions, metrics, models). These are stateless, in that nothing about them changes over the course of synthesis. For these objects, we save a tuple with their name (using _get_name in synthesis.py), the names of one or more other attributes of the object that can be passed as inputs (e.g., _image, _metamer), and the output when called on those attributes.
  • save_state_dict_attrs: pytorch objects with a state_dict that may change over the course of synthesis (e.g., optimizers whose learning rate may change). We save a tuple with their name and their state_dict.
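
To make the categorization concrete, here is a rough sketch of how save() might assemble these three groups. This is an illustration only: the attribute names (_model, _loss_function, _optimizer, _scheduler, _image), the helper sketch_save, and the dict layout are assumptions, not the exact plenoptic implementation.

import torch

def _get_name(obj):
    # simplified stand-in for the _get_name helper mentioned above
    return f"{type(obj).__module__}.{type(obj).__name__}"

def sketch_save(obj, file_path, io_attrs=("_model", "_loss_function"),
                state_dict_attrs=("_optimizer", "_scheduler")):
    # io attributes: record the callable's name, the names of the attributes
    # used as inputs, and the output when called on those attributes
    save_io_attrs = {}
    for attr in io_attrs:
        fn = getattr(obj, attr)
        input_names = ("_image",)
        inputs = [getattr(obj, n) for n in input_names]
        save_io_attrs[attr] = (_get_name(fn), input_names, fn(*inputs))
    # state_dict attributes: record the object's name and its current state_dict
    save_state_dict_attrs = {
        attr: (_get_name(getattr(obj, attr)), getattr(obj, attr).state_dict())
        for attr in state_dict_attrs
        if getattr(obj, attr, None) is not None
    }
    # everything else (primitives, tensors) is saved directly
    skip = set(io_attrs) | set(state_dict_attrs)
    save_attrs = {k: v for k, v in vars(obj).items() if k not in skip}
    torch.save(
        {"save_attrs": save_attrs,
         "save_io_attrs": save_io_attrs,
         "save_state_dict_attrs": save_state_dict_attrs},
        file_path,
    )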

On load, we do the following (a matching sketch follows the list):

  • (as before) check all attributes set at initialization match those we are trying to load (e.g., range_penalty_lambda, image).
  • check all "io attributes" from save, ensuring that their names and input/output behavior are the same.
  • check that all the "state_dict attributes" have the same name and then load their state_dict (overwriting existing state). Currently, these attributes only include schedulers and optimizers, which are not set at initialization; we cache their name and state_dict, and check them when they are initialized (on the first call to synthesize).
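
A matching sketch of the load-side checks, continuing the assumptions (and the _get_name helper) from the save sketch above; again, this is illustrative, not the actual plenoptic code:

def sketch_load(obj, file_path):
    # weights_only=True is fine here: only primitives, strings, and tensors are stored
    loaded = torch.load(file_path, weights_only=True)
    # io attributes: name and input/output behavior must be unchanged
    for attr, (name, input_names, saved_out) in loaded["save_io_attrs"].items():
        fn = getattr(obj, attr)
        if _get_name(fn) != name:
            raise ValueError(f"{attr} has a different name than when saved!")
        new_out = fn(*[getattr(obj, n) for n in input_names])
        if not torch.allclose(new_out, saved_out):
            raise ValueError(f"{attr} behaves differently than when saved!")
    # plain attributes: in the real code, init-time values are checked against
    # the loaded ones first; here we simply copy everything over
    for k, v in loaded["save_attrs"].items():
        setattr(obj, k, v)
    # state_dict attributes (optimizer, scheduler): cache name and state_dict,
    # to be checked and loaded when they are first created in synthesize()
    obj._cached_state_dict_attrs = loaded["save_state_dict_attrs"]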

For all synthesis objects other than Eigendistortion (for which it's not relevant), we also provide the save_objects and weights_only boolean flags for save and load, respectively. If save_objects=True, we save the callables (loss functions, metrics, schedulers, optimizers, but NOT models, since torch.nn.Module can be very large). In this case, they must then pass weights_only=False to load, to override the new default torch behavior. EDIT: removed these flags, see later comment.

This makes our save/load behavior not backwards compatible because of how we handle the model. Previously, I was implicitly checking it (by checking for e.g., Metamer.target_representation), but now I treat it the same as the loss function. However, this current behavior works with (at least) torch 2.5 and 2.6, so that's good.

Additionally:

  • we add some checks for scheduler to match the checks made by optimizer: the first time synthesize is called, scheduler can be non-None; on every subsequent call, it must be None (a sketch of this check follows the list).
  • geodesic update: in our docs, we say that plenoptic works with models that output 3d or 4d tensors, but the old implementation meant that geodesics only worked with 4d outputs. This adds a small fix for that (and tests).
  • I also removed the ruff actions from ci.yml, because they're being handled by the pre-commit action (and the versions were out of sync, which meant the ci.yml version was failing, preventing me from merging this PR)
  • Had to pin sphinx<8.2 due to an nbsphinx bug ("Cannot run with Sphinx 8.2.0", spatialaudio/nbsphinx#825). I will be dropping nbsphinx when I switch over to myst/mystnb, so this is temporary.
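
A minimal sketch of the scheduler check described in the first bullet above; the flag and attribute names (_synthesize_called, _scheduler) are assumptions for illustration, not the actual plenoptic code:

def _initialize_scheduler(self, scheduler):
    if not getattr(self, "_synthesize_called", False):
        # first call to synthesize(): scheduler may be None or user-supplied
        self._scheduler = scheduler
        self._synthesize_called = True
    elif scheduler is not None:
        # every subsequent call (e.g., when resuming): scheduler must be None
        raise ValueError(
            "synthesize() was already called; scheduler must be None on "
            "subsequent calls."
        )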

I still need to:

  • Add tests for new behavior.
  • Add a user-facing docs page with some details about save/load. -- after removing the weights_only/save_objects flags, I don't think this is necessary.

Questions:

  • Right now, the error message when two tensors are different is not very helpful. If I create a metamer object with a different target image, I raise an error and print out the tensors and their difference, which is ... pretty hard to parse. But I feel that makes more sense than not printing anything at all? I provide more informative messages when possible (e.g., when the tensors' shapes differ).
  • The check that names are the same may fail because someone has restructured their code or updated something (because I include the module in the name). But I think that's fine? In general, it's hard to guarantee things can be saved/loaded across versions (see e.g., sklearn's advice).
    • I was originally going to only raise a warning if the name was different and then, e.g., have the optimizer load the state dict and try to continue. But pytorch lets you load another optimizer's state dict: I can do SGD.load_state_dict(Adam.state_dict()) without a problem, but then when I try to step the optimizer, I get a confusing error message (because it can't find the relevant information). To avoid this problem, I just decided that names have to be stable (a minimal reproduction of this is sketched below).
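
A minimal reproduction of the optimizer experiment mentioned above (the load succeeding is standard pytorch behavior; the confusing failure on step is as described in the bullet and may depend on the torch version):

import torch

param = torch.nn.Parameter(torch.zeros(3))
adam = torch.optim.Adam([param], lr=1e-3)
param.sum().backward()
adam.step()                              # populate Adam's per-parameter state

sgd = torch.optim.SGD([param], lr=1e-2)
sgd.load_state_dict(adam.state_dict())   # succeeds: pytorch does not check the optimizer class
# sgd.step()  # per the comment above, stepping afterwards raises a confusing
#             # error, since SGD now carries Adam's hyperparameters and state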

(relevant torch docs: here, and here)

closes #313

@billbrod
Member Author

I also removed the ruff actions from ci.yml, because they're being handled by the pre-commit action (and the versions were out of sync, which meant the ci.yml version was failing, preventing me from merging this PR)


"""
super().save(file_path, attrs=None)
if not save_objects:


why is this if not save_objects? shouldn't including the loss function in save_io_attrs occur when save_objects is True? or am I misunderstanding something? (this is the same for the other synthesis objects)

Member Author


If save_objects is True, we save the loss function along with all the other attributes (that is, we pickle the whole function). If it's False, then we don't save the function but its behavior, which is what inclusion in save_io_attrs means (by behavior I mean: given this input, it returns this output).

regardless, we never save the whole model object because it can get really big


got it, so I was misunderstanding save_io_attrs, which is the "new" way of saving things about the objects, thus introducing the break in backwards compatibility?

Member Author


Yes, save_io_attrs is one of two new ways of saving info about the objects (the other being save_state_dict_attrs). The breaking of backwards compatibility is because I now explicitly include model in save_io_attrs to check its behavior, whereas previously I would only implicitly check its behavior, by checking whether Metamer._target_representation was identical, which is the cached output of the model on the target image.

What I'm doing now for save_io_attrs (and check_io_attributes in load) is similar to what I was doing before for check_loss_functions in load; the reasoning is the same, but the specifics are different.

@@ -433,13 +433,26 @@ def save(self, file_path: str):
----------
file_path : str
The path to save the Geodesic object to
save_objects :
If True, we use pickle to save all non-model objects (optimizer). To load


what saves when save_objects=False? would it be useful to have this information in the docstring?

Member Author


My inclination is no -- the user shouldn't worry about what's saved. In either case, calling load appropriately will give you back the synthesis object at the same state, so it doesn't matter which way you got there. The difference is that save_objects is potentially unsafe. This is what I'm going to try and explain in the doc page I need to write.

This docstring should probably be updated then. Let me write the doc page first, and then see if that helps me find a clearer wording here.

@billbrod
Member Author

After some conversations and thinking, I've removed the weights_only and save_objects flags. They only saved a minimal amount of code in an advanced use case (not using the default optimizer), and so are not worth the extra effort.

without flags:

import torch
import plenoptic as po

img = po.data.einstein()
model = po.simul.Gaussian(30)
po.tools.remove_grad(model)
met = po.synth.Metamer(img, model)
optimizer = torch.optim.SGD([met.metamer])
met.synthesize(5, optimizer=optimizer)
met.save('test.pt')
met_copy = po.synth.Metamer(img, model)
met_copy.load('test.pt')
optimizer = torch.optim.SGD([met_copy.metamer])
met_copy.synthesize(5, optimizer=optimizer)

with flags:

img = po.data.einstein()
model = po.simul.Gaussian(30)
po.tools.remove_grad(model)
met = po.synth.Metamer(img, model)
optimizer = torch.optim.SGD([met.metamer])
met.synthesize(5, optimizer=optimizer)
met.save('test.pt', save_objects=True)
met_copy = po.synth.Metamer(img, model)
met_copy.load('test.pt', weights_only=False)
met_copy.synthesize(5)

The normal use case (using the default Adam optimizer) looks like:

img = po.data.einstein()
model = po.simul.Gaussian(30)
po.tools.remove_grad(model)
met = po.synth.Metamer(img, model)
met.synthesize(5)
met.save('test.pt')
met_copy = po.synth.Metamer(img, model)
met_copy.load('test.pt')
met_copy.synthesize(5)

so in the standard case, the flags don't save the user any effort

@billbrod
Member Author

Had to pin sphinx<8.2 due to a bug with nbsphinx. I will be dropping nbsphinx when I switch over to myst/mystnb, so this is temporary.

@billbrod
Member Author

Documentation built by flatiron-jenkins at http://docs.plenoptic.org/docs//pulls/317


codecov bot commented Feb 28, 2025

Codecov Report

Attention: Patch coverage is 93.78238% with 12 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Missing lines
src/plenoptic/synthesize/metamer.py | 73.33% | 4 ⚠️
src/plenoptic/synthesize/synthesis.py | 96.19% | 4 ⚠️
src/plenoptic/tools/data.py | 85.71% | 2 ⚠️
src/plenoptic/synthesize/geodesic.py | 91.66% | 1 ⚠️
src/plenoptic/synthesize/mad_competition.py | 90.00% | 1 ⚠️

Files with missing lines | Coverage | Δ
src/plenoptic/synthesize/eigendistortion.py | 98.90% <100.00%> | (+<0.01%) ⬆️
src/plenoptic/tools/__init__.py | 100.00% <100.00%> | (ø)
src/plenoptic/tools/io.py | 100.00% <100.00%> | (ø)
src/plenoptic/synthesize/geodesic.py | 95.26% <91.66%> | (-2.92%) ⬇️
src/plenoptic/synthesize/mad_competition.py | 93.46% <90.00%> | (+0.56%) ⬆️
src/plenoptic/tools/data.py | 78.91% <85.71%> | (+0.71%) ⬆️
src/plenoptic/synthesize/metamer.py | 92.09% <73.33%> | (-0.35%) ⬇️
src/plenoptic/synthesize/synthesis.py | 93.45% <96.19%> | (+3.24%) ⬆️

billbrod added 2 commits March 6, 2025 11:16
also makes sure geodesic supports models with 3d outputs
@billbrod
Member Author

billbrod commented Mar 6, 2025

Okay, I've finally added all the needed tests. My remaining question is: do I need to update my load docstrings? If so, what should I say? I don't want to expose the details of what's happening to users; instead, I tried to make the error messages informative so they point to what to do. I currently don't say that loading should be done with the same pytorch and plenoptic versions as saving, but that's generally good practice -- does it belong there?

I am going to use sphinx's versionchanged directive to note that the behavior has changed

billbrod added 4 commits March 6, 2025 11:39
we had some old bits of code lying around, which were supporting old
behavior:

- removed code that would allow for the loading of code objects

- removed code that would allow for state_dict attributes set at
initialization. that's possible, but currently not done.
tests were improperly formatted, was apparently using curie for einstein
image
Contributor

@BalzaniEdoardo BalzaniEdoardo left a comment


Hey Billy,
I have reviewed the machinery of the save/load, but I did not focus too much on the documentation or the tests. I think we can first do a pass on these comments, and then I'll take a look at the tests as well.

@BalzaniEdoardo BalzaniEdoardo self-requested a review March 7, 2025 21:47
Contributor

@BalzaniEdoardo BalzaniEdoardo left a comment


Once the changes discussed are implemented, this is approved for me!


@sjvenditto sjvenditto left a comment


I took a closer look at the tests, and in general they all look good. I just had a couple of comments/questions:

  • there are a few instances where some parameters aren't being explicitly tested (usually _allowed_range) for a mismatch. however, this parametrize was not added in this PR, so maybe there's a reason they're not being tested?
  • why is the geodesic optimizer being initialized with a private variable? I've also noticed the warning about Geodesic not being robust enough, so it might not be used very often anyway, so maybe this isn't a general issue?

I'd appreciate the clarification as I continue to learn the package! Depending on your responses, they might not lead to any changes -- in which case I'd approve the PR

@pytest.mark.parametrize("model", ["frontend.OnOff.nograd"], indirect=True)
@pytest.mark.parametrize(
"model", ["frontend.LinearNonlinear.nograd"], indirect=True
)
@pytest.mark.parametrize(
"fail",
[False, "img_a", "img_b", "model", "n_steps", "init", "range_penalty"],


There are additional attributes in Geodesic load's check_attributes that aren't being tested here, specifically _allowed_range and pixelfade. Should these be tested as well?

This is also the case for test_save_load for MAD and Metamer, both have _allowed_range in check_attributes, but aren't explicitly being tested for a mismatch

optimizer = None
if optim_opts is not None:
if optim_opts == "Adam":
optimizer = torch.optim.Adam([geod._geodesic])


this is more a question than something that needs to be addressed for this PR: are Geodesic optimizers supposed to be initialized from a private variable? how would a user that isn't aware of the private variable initialize the optimizer from the loaded synthesis?

@billbrod
Member Author

billbrod commented Mar 10, 2025

> I took a closer look at the tests, and in general they all look good. I just had a couple of comments/questions:

> there are a few instances where some parameters aren't being explicitly tested (usually _allowed_range) for a mismatch. however, this parametrize was not added in this PR, so maybe there's a reason they're not being tested?

No, this is just a mistake. I'll add tests for those.

> why is the geodesic optimizer being initialized with a private variable? I've also noticed the warning about Geodesic not being robust enough, so it might not be used very often anyway, so maybe this isn't a general issue?

Geodesic is deprecated and about to be removed (it will live over in https://github.com/plenoptic-org/geodesics for the time being). But this is an annoyance -- for most objects, the public version of the to-be-optimized tensor is the same as the private (metamer / _metamer, _mad_image / mad_image), I just use the property to avoid users accidentally overwriting it. However, geodesic is the concatenation of image_a, _geodesic, image_b. The geodesics consist of N different images. Two of those are the endpoints (image_a, image_b), which are set by the user at initialization and never change. The remaining N-2 are stored as _geodesic, are the transition between those two endpoints and are what we change during optimization. Thus, when you initialize the optimizer, it needs to be _geodesic, because if you used geodesic, you'd be changing the endpoints. Does that make sense? geodesic is a bit of a convenience variable here, because any time you visualize or compute a diagnostic, you want to make sure you're including the endpoints.

I'm not sure the best way to handle that, because I don't want users to have to interact with a private variable. I suppose I could have a geodesic_with_endpoints or something. If you have a suggestion, I'm all ears, otherwise I might just punt on this for now and open an issue in the geodesic repo so I remember to figure it out at some point.
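
For context, a rough sketch of the relationship described above; the class and attribute handling here are illustrative assumptions, not the actual Geodesic implementation:

import torch

class GeodesicSketch:
    def __init__(self, image_a, image_b, n_interior):
        self.image_a = image_a    # fixed endpoint, set at initialization
        self.image_b = image_b    # fixed endpoint, set at initialization
        # the N-2 interior frames: the only tensor optimization should modify
        self._geodesic = torch.nn.Parameter(
            torch.stack([image_a.clone()] * n_interior)
        )

    @property
    def geodesic(self):
        # convenience view for visualization/diagnostics: includes the endpoints
        return torch.cat([self.image_a[None], self._geodesic, self.image_b[None]])

# the optimizer therefore has to be built from the private tensor:
# optimizer = torch.optim.Adam([geod._geodesic])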

@sjvenditto

> Geodesic is deprecated and about to be removed (it will live over in https://github.com/plenoptic-org/geodesics for the time being). [...] geodesic is a bit of a convenience variable here, because any time you visualize or compute a diagnostic, you want to make sure you're including the endpoints.

Yeah I noticed that the geodesic property was not the same as _geodesic, which is why I asked!

> I'm not sure the best way to handle that, because I don't want users to have to interact with a private variable. [...] otherwise I might just punt on this for now and open an issue in the geodesic repo so I remember to figure it out at some point.

My suggestion is also to punt on it, especially since Geodesic is deprecated. My thought is that since it's a riskier synthesis method, users who take that risk should be familiar enough with it to know how to initialize the optimizer, whether that means taking the correct slice of geodesic or using the private variable. This can be sorted out once Geodesic becomes more stable.


@sjvenditto sjvenditto left a comment


I've added some more minor comments on the save and load docstrings

@@ -82,6 +82,11 @@ Furthermore:
of the reference metric in a list, ``_reference_metric_loss``, but the
``reference_metric_loss`` attribute converts this list to a tensor before
returning it, as that's how the user will most often want to interact with it.
* All attributes should be initialized at object initialization, though they can
be "False-y" (e.g., an empty list, ``None``). At least one attribute should be
``None`` or an empty list at initialization. which we use when loading to


Suggested change
``None`` or an empty list at initialization. which we use when loading to
``None`` or an empty list at initialization, which we use when loading to

@@ -494,13 +503,14 @@ def load(
):
r"""Load all relevant stuff from a .pt file.
This should be called by an initialized ``Geodesic`` object -- we will


Why was this removed from the docstring? A similar message was left in place for Eigendistortion and MAD

@@ -426,17 +436,19 @@ def load(
):
r"""Load all relevant stuff from a .pt file.

This should be called by an initialized ``Metamer`` object -- we will


Similar to Geodesic, why was this removed?

them as tuples of (name, input_names, outputs). On load, we check that the
initialized object's name hasn't changed, and that when called on the same
inputs, we get the same outputs. Intended for models, metrics, loss
functions. Used to avoid saving callable, which is brittle and unsafe.


Suggested change
functions. Used to avoid saving callable, which is brittle and unsafe.
functions. Used to avoid saving callables, which is brittle and unsafe.

Development

Successfully merging this pull request may close these issues.

Torch 2.6 change in load behavior
3 participants