Fix problematic behavior of optimizer/scheduler in FeatureInversionTask #101
base: dev
Conversation
Looks good to me.
Based on Otsuka-san's suggestion, I have revised the type definitions as follows:

```diff
  build_optimizer_factory: (type[Optimizer], _GetParamsFnType) -> _OptimizerFactoryType
- _GetParamsFnType: TypeAlias = (BaseGenerator, BaseLatent) -> Iterator[Parameter]
+ _GetParamsFnType: TypeAlias = (BaseGenerator, BaseLatent) -> _ParamsT
  _OptimizerFactoryType: TypeAlias = (BaseGenerator, BaseLatent) -> Optimizer
+ _ParamsT: TypeAlias = Iterable[Tensor] | Iterable[Dict[str, Any]] | Iterable[Tuple[str, Tensor]]
```

Reasons behind this modification
The previous type annotations were not compatible with param_groups-style usage such as the following:

```python
optimizer_factory = build_optimizer_factory(
    SGD,
    get_params_fn=lambda generator, latent: [
        {"params": latent.parameters(), "lr": latent_lr},
        {"params": generator.parameters(), "lr": generator_lr},
    ],
    lr=base_lr, momentum=0.9
)
```

Why did we redefine the same concept in our codebase instead of just importing it from PyTorch?
I decided to define the type in our own codebase (see bdpy/bdpy/recon/torch/modules/optimizer.py, lines 14 to 18 in c23afe7). The name of this type was […]

Note on the type definition of […]
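For reference, here is a minimal sketch of how build_optimizer_factory could be implemented under these type definitions. It is an illustration under assumptions, not the actual bdpy implementation; BaseGenerator and BaseLatent are bdpy's own classes and appear only as forward references.

```python
# A sketch, under assumptions, of a build_optimizer_factory that closes over the
# optimizer class, the parameter-selection function, and extra keyword arguments.
# The actual code in bdpy/recon/torch/modules/optimizer.py may differ.
from typing import Any, Callable, Dict, Iterable, Tuple, Type, Union

from torch import Tensor
from torch.optim import Optimizer

# BaseGenerator / BaseLatent are referenced only as forward references here.
_ParamsT = Union[Iterable[Tensor], Iterable[Dict[str, Any]], Iterable[Tuple[str, Tensor]]]
_GetParamsFnType = Callable[["BaseGenerator", "BaseLatent"], _ParamsT]
_OptimizerFactoryType = Callable[["BaseGenerator", "BaseLatent"], Optimizer]


def build_optimizer_factory(
    optimizer_class: Type[Optimizer],
    get_params_fn: _GetParamsFnType,
    **kwargs: Any,
) -> _OptimizerFactoryType:
    """Return a factory that builds optimizer_class over the params chosen by get_params_fn."""
    def factory(generator: "BaseGenerator", latent: "BaseLatent") -> Optimizer:
        # The params (tensors or param_groups) are selected lazily, at factory-call time.
        return optimizer_class(get_params_fn(generator, latent), **kwargs)

    return factory
```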
Problem
The current implementation of FeatureInversionTask has several limitations/problems in its use of the optimizer and scheduler. Concrete examples:

- Initialization using param_groups works only one time
- A learning rate scheduler cannot be used
Cause
The cause of the problem is in the implementation of reset_states() (see bdpy/bdpy/recon/torch/task/inversion.py, lines 216 to 226 in 9ffe7bc).
Originally, this method was implemented based on the following assumptions: […]

In reality, neither of these assumptions held. In addition, optimizers generally depend on the generator and latent instances, and learning rate schedulers depend on the optimizer instance, so when any of these dependencies is re-instantiated, the references to it must be replaced accordingly.
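As a standalone PyTorch illustration of this dependency chain (not bdpy code): once the parameter a pre-built optimizer was constructed on is re-created, the optimizer, and any scheduler attached to it, keeps referencing the stale object.

```python
# Minimal PyTorch illustration of stale optimizer/scheduler references.
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

latent = torch.nn.Parameter(torch.zeros(3))
optimizer = SGD([{"params": [latent], "lr": 0.1}])
scheduler = ExponentialLR(optimizer, gamma=0.9)

# Re-creating the latent (as a reset between inversions would do) leaves the
# optimizer bound to the discarded tensor, so stepping it no longer updates the
# new latent, and the scheduler keeps driving the stale optimizer.
latent = torch.nn.Parameter(torch.zeros(3))
assert optimizer.param_groups[0]["params"][0] is not latent
```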
Solution
Instead of receiving instances of the optimizer and learning rate scheduler themselves, FeatureInversionTask now receives factory functions for creating them. The following is an example use of the newly designed API:
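A sketch of such usage, reusing the build_optimizer_factory helper from the review discussion above. The remaining constructor arguments (encoder, critic, the learning rates, ...) are shown with assumed names and are not changed by this PR.

```python
# Example use of the new factory-based API; names other than
# optimizer_factory / scheduler_factory are illustrative assumptions.
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

latent_lr, generator_lr = 0.1, 0.001

optimizer_factory = build_optimizer_factory(
    SGD,
    get_params_fn=lambda generator, latent: [
        {"params": latent.parameters(), "lr": latent_lr},
        {"params": generator.parameters(), "lr": generator_lr},
    ],
    momentum=0.9,
)

task = FeatureInversionTask(
    encoder=encoder,
    generator=generator,
    latent=latent,
    critic=critic,
    optimizer_factory=optimizer_factory,  # (BaseGenerator, BaseLatent) -> Optimizer
    scheduler_factory=lambda optimizer: ExponentialLR(optimizer, gamma=0.95),  # Optimizer -> LRScheduler
)
```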
Breaking changes in API

- FeatureInversionTask takes optimizer_factory: (BaseGenerator, BaseLatent) -> Optimizer instead of optimizer: Optimizer as an input argument
- FeatureInversionTask takes scheduler_factory: Optimizer -> LRScheduler instead of scheduler: LRScheduler as an input argument
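For downstream code, the migration is roughly the following sketch; the old call in the comments and the elided constructor arguments are assumptions for illustration.

```python
# Hypothetical call site before this PR: instances were created once and passed in.
#   optimizer = SGD(latent.parameters(), lr=0.1)
#   scheduler = ExponentialLR(optimizer, gamma=0.9)
#   task = FeatureInversionTask(..., optimizer=optimizer, scheduler=scheduler)
#
# After this PR, the same site passes factories so the task can rebuild both
# whenever the generator/latent are re-initialized:
#   task = FeatureInversionTask(
#       ...,
#       optimizer_factory=lambda generator, latent: SGD(latent.parameters(), lr=0.1),
#       scheduler_factory=lambda optimizer: ExponentialLR(optimizer, gamma=0.9),
#   )
```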