Updates save/load behavior for breaking change in torch 2.6 #317
base: main
Conversation
I also removed the ruff actions from

src/plenoptic/synthesize/metamer.py (Outdated)
```
        """
        super().save(file_path, attrs=None)
        if not save_objects:
```
why is this `if not save_objects`? shouldn't including the loss function in `save_io_attrs` occur when `save_objects` is True? or am I misunderstanding something? (this is the same for the other synthesis objects)
if `save_objects` is True, we save the loss function along with all the other attributes (that is, we pickle the whole function). If it's False, then we don't save the function but its behavior, which is what inclusion in `save_io_attrs` means (by behavior I mean: given this input, return this output).

regardless, we never save the whole model object because it can get really big
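Roughly, "saving the behavior" could look something like the following sketch (hypothetical names, not the actual plenoptic implementation):

```python
import torch

def l2_loss(x, y):
    # stand-in for a user-supplied loss function
    return torch.linalg.vector_norm(x - y)

inp_a, inp_b = torch.rand(3), torch.rand(3)
# instead of pickling the callable, store its name plus a cached input/output pair
saved = ("l2_loss", (inp_a, inp_b), l2_loss(inp_a, inp_b))

# on load, the user re-supplies the loss function and we check it behaves the same
name, inputs, expected = saved
assert torch.allclose(l2_loss(*inputs), expected)
```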
got it, so I was misunderstanding `save_io_attrs`, which is the "new" way of saving things about the objects, thus introducing the break in backwards compatibility?
Yes, `save_io_attrs` is one of two new ways of saving info about the objects (the other being `save_state_dict_attrs`). The breaking of backwards compatibility is because I now explicitly include `model` in `save_io_attrs` to check its behavior, whereas previously I would only implicitly check its behavior, by checking whether `Metamer._target_representation` was identical, which is the cached output of the model on the target image.

What I'm doing now for `save_io_attrs` (and `check_io_attributes` in load) is similar to what I was doing before for `check_loss_functions` in load; the reasoning is the same, but the specifics are different.
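For the `save_state_dict_attrs` side, a minimal sketch of the idea (again hypothetical, not the actual implementation) is to store `(name, state_dict)` for stateful torch objects and restore that state into a freshly constructed object on load:

```python
import torch

param = torch.zeros(5, requires_grad=True)
opt = torch.optim.Adam([param], lr=1e-2)
param.sum().backward()
opt.step()

# save the optimizer's name and state_dict rather than pickling the optimizer itself
saved = (type(opt).__name__, opt.state_dict())

# on load, construct an optimizer of the same type and restore its state,
# so synthesis can pick up where it left off
new_opt = torch.optim.Adam([param])
assert saved[0] == type(new_opt).__name__
new_opt.load_state_dict(saved[1])
```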
src/plenoptic/synthesize/geodesic.py (Outdated)
```
@@ -433,13 +433,26 @@ def save(self, file_path: str):
        ----------
        file_path : str
            The path to save the Geodesic object to
        save_objects :
            If True, we use pickle to save all non-model objects (optimizer). To load
```
what saves when `save_objects=False`? would it be useful to have this information in the docstring?
My inclination is no -- the user shouldn't worry about what's saved. In either case, calling load appropriately will give you back the synthesis object at the same state, so it doesn't matter which way you got there. The difference is that `save_objects` is potentially unsafe. This is what I'm going to try and explain in the doc page I need to write.

This docstring should probably be updated then. Let me write the doc page first, and then see if that helps me find a clearer wording here.
After some conversations and thinking, I've removed the flags.

without flags:
```python
img = po.data.einstein()
model = po.simul.Gaussian(30)
po.tools.remove_grad(model)
met = po.synth.Metamer(img, model)
optimizer = torch.optim.SGD([met.metamer])
met.synthesize(5, optimizer=optimizer)
met.save('test.pt')
met_copy = po.synth.Metamer(img, model)
optimizer = torch.optim.SGD([met.metamer])
met_copy.load('test.pt')
met.synthesize(5, optimizer=optimizer)
```

with flags:
```python
img = po.data.einstein()
model = po.simul.Gaussian(30)
po.tools.remove_grad(model)
met = po.synth.Metamer(img, model)
optimizer = torch.optim.SGD([met.metamer])
met.synthesize(5, optimizer=optimizer)
met.save('test.pt', save_objects=True)
met_copy = po.synth.Metamer(img, model)
met_copy.load('test.pt', weights_only=False)
met.synthesize(5)
```

The normal use case (using the default Adam optimizer) looks like:
```python
img = po.data.einstein()
model = po.simul.Gaussian(30)
po.tools.remove_grad(model)
met = po.synth.Metamer(img, model)
met.synthesize(5)
met.save('test.pt')
met_copy = po.synth.Metamer(img, model)
met_copy.load('test.pt')
met.synthesize(5)
```

so in the standard case, the flags don't save the user any effort
Had to pin

Documentation built by flatiron-jenkins at http://docs.plenoptic.org/docs//pulls/317
more useful info
this is now handled earlier
private function to check whether tensors have same device, size, dtype and values (in that order)
also makes sure geodesic supports models with 3d outputs
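A check like the one described in the commit message above might look roughly like this (a hypothetical sketch, not the actual private helper):

```python
import torch

def _tensors_match(x: torch.Tensor, y: torch.Tensor) -> bool:
    # compare device, size, dtype, then values, in that order
    if x.device != y.device:
        return False
    if x.shape != y.shape:
        return False
    if x.dtype != y.dtype:
        return False
    return torch.equal(x, y)
```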
Okay, I've finally added all the needed tests. My remaining question is: do I need to update my

I am going to use sphinx's
we had some old bits of code lying around, which were supporting old behavior:
- removed code that would allow for the loading of code objects
- removed code that would allow for state_dict attributes set at initialization. that's possible, but currently not done.
tests were improperly formatted, was apparently using curie for einstein image
Hey Billy,
I have reviewed the machinery of the save/load, but I did not focus too much on the documentation nor on the tests. I think we can first do a pass on these comments, and then I'll take a look at the tests as well.
there's something with Adam
since we're requiring dtype to be the same, shouldn't be an issue
Once the changes discussed are implemented, that's approved for me!
I took a closer look at the tests, and in general they all look good. I just had a couple of comments/questions:

- there are a few instances where some parameters aren't being explicitly tested (usually `_allowed_range`) for a mismatch. however, this parametrize was not added in this PR, so maybe there's a reason they're not being tested?
- why is the geodesic optimizer being initialized with a private variable? I've also noticed the warning about Geodesic not being robust enough, so it might not be used very often anyway, so maybe this isn't a general issue?

I'd appreciate the clarification as I continue to learn the package! Depending on your responses, they might not lead to any changes -- in which case I'd approve the PR
@pytest.mark.parametrize("model", ["frontend.OnOff.nograd"], indirect=True) | ||
@pytest.mark.parametrize( | ||
"model", ["frontend.LinearNonlinear.nograd"], indirect=True | ||
) | ||
@pytest.mark.parametrize( | ||
"fail", | ||
[False, "img_a", "img_b", "model", "n_steps", "init", "range_penalty"], |
There are additional attributes in Geodesic `load`'s `check_attributes` that aren't being tested here, specifically `_allowed_range` and `pixelfade`. Should these be tested as well?

This is also the case for `test_save_load` for MAD and Metamer; both have `_allowed_range` in `check_attributes`, but aren't explicitly being tested for a mismatch
```python
optimizer = None
if optim_opts is not None:
    if optim_opts == "Adam":
        optimizer = torch.optim.Adam([geod._geodesic])
```
this is more a question than something that needs to be addressed for this PR: are Geodesic optimizers supposed to be initialized from a private variable? how would a user that isn't aware of the private variable initialize the optimizer from the loaded synthesis?
No, this is just a mistake. I'll add tests for those.
Geodesic is deprecated and about to be removed (it will live over in https://github.com/plenoptic-org/geodesics for the time being). But this is an annoyance -- for most objects, the public version of the to-be-optimized tensor is the same as the private one. I'm not sure the best way to handle that, because I don't want users to have to interact with a private variable. I suppose I could have a
Yeah I noticed that the

My suggestion is also to punt it, especially since Geodesic is deprecated. My thought is that since it's a riskier synthesis method, users that take that risk should be familiar enough with it to know how to initialize the optimizer, whether that is taking the correct slice of
I've added some more minor comments on the save and load docstrings
```
@@ -82,6 +82,11 @@ Furthermore:
  of the reference metric in a list, ``_reference_metric_loss``, but the
  ``reference_metric_loss`` attribute converts this list to a tensor before
  returning it, as that's how the user will most often want to interact with it.
* All attributes should be initialized at object initialization, though they can
  be "False-y" (e.g., an empty list, ``None``). At least one attribute should be
  ``None`` or an empty list at initialization. which we use when loading to
```
Suggested change:
```
- ``None`` or an empty list at initialization. which we use when loading to
+ ``None`` or an empty list at initialization, which we use when loading to
```
```
@@ -494,13 +503,14 @@ def load(
    ):
        r"""Load all relevant stuff from a .pt file.
        This should be called by an initialized ``Geodesic`` object -- we will
```
Why was this removed from the docstring? A similar message was left in place for Eigendistortion and MAD
```
@@ -426,17 +436,19 @@ def load(
    ):
        r"""Load all relevant stuff from a .pt file.

        This should be called by an initialized ``Metamer`` object -- we will
```
Similar to Geodesic, why was this removed?
```
them as tuples of (name, input_names, outputs). On load, we check that the
initialized object's name hasn't changed, and that when called on the same
inputs, we get the same outputs. Intended for models, metrics, loss
functions. Used to avoid saving callable, which is brittle and unsafe.
```
Suggested change:
```
- functions. Used to avoid saving callable, which is brittle and unsafe.
+ functions. Used to avoid saving callables, which is brittle and unsafe.
```
As discussed in the related issue (#313), Torch 2.6 introduced a breaking change in how load works, by setting `weights_only=True` by default and being much more conservative. This PR makes plenoptic compatible with that change.

Our previous behavior was incompatible with the changes because we were saving two types of objects:

We also save a bunch of python primitives (floats, ints, strings) and tensors, but these are fine.

Now, we have modified `save()` so that object attributes are placed into three categories:

- `save_attrs`: primitives and tensors that can be saved directly. These are not explicitly set, but they're all attributes that aren't included in the next two categories.
- `save_io_attrs`: functions or callable objects that accept and return tensors (e.g., loss functions, metrics, models). These are state-less, in that nothing changes over the course of synthesis. For these objects, we save a tuple with their name (using `_get_name` in `synthesis.py`), the name of one or more other attributes of the object that can be passed as inputs (e.g., `_image`, `_metamer`), and the output when called on those objects.
- `save_state_dict_attrs`: pytorch objects with a state_dict that may change over the course of synthesis (e.g., optimizers whose learning rate may change). We save a tuple with their name and their state_dict.

On load, we check that the initialization arguments (e.g., `range_penalty_lambda`, `image`) match and that synthesis can be resumed (by calling `synthesize`).

EDIT: removed these flags, see later comment.
For all synthesis objects other than `Eigendistortion` (for which it's not relevant), we also provide the `save_objects` and `weights_only` boolean flags for save and load, respectively. If `save_objects=True`, we save the callables (loss functions, metrics, schedulers, optimizers, but NOT models, since `torch.nn.Module` can be very large). In this case, they must then pass `weights_only=False` to load, to override the new default torch behavior.

This makes our save/load behavior not backwards compatible because of how we handle the model. Previously, I was implicitly checking it (by checking for e.g., `Metamer.target_representation`), but now I treat it the same as the loss function. However, this current behavior works with (at least) torch 2.5 and 2.6, so that's good.

Additionally:
- the first time `synthesize` is called, the scheduler can be non-None; every subsequent time, it must be None.

I still need to:

Questions:
- you can call `SGD.load_state_dict(Adam.state_dict())` without a problem, but then when I try to step the optimizer, I get a confusing error message (because it can't find the relevant information). To avoid this problem, I just decided that names have to be stable.

(relevant torch docs: here, and here)

closes #313
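As a rough illustration of the optimizer question above (a minimal sketch using only standard torch, not plenoptic code): loading an Adam state_dict into an SGD optimizer succeeds, but stepping afterwards fails, because the restored parameter groups carry Adam's hyperparameters and lack SGD's.

```python
import torch

param = torch.zeros(3, requires_grad=True)
adam = torch.optim.Adam([param])
param.sum().backward()
adam.step()

# load_state_dict only checks that group/parameter counts match, so this succeeds...
sgd = torch.optim.SGD([param], lr=0.1)
sgd.load_state_dict(adam.state_dict())

# ...but stepping then fails with a confusing error (e.g., a missing 'momentum' key)
try:
    sgd.step()
except Exception as e:
    print("stepping failed:", type(e).__name__, e)
```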