Fix problematic behavior of optimizer/scheduler in FeatureInversionTask #101

Open · wants to merge 3 commits into base: dev
Conversation

@ganow ganow (Contributor) commented Feb 18, 2025

Problem

The current implementation of FeatureInversionTask has several limitations/problems in its use of the optimizer and scheduler. Here are concrete examples:

Initialization using param_groups works only once

optimizer = optim.SGD([
  {"params": latent.parameters(), "lr": latent_lr},
  {"params": generator.parameters(), "lr" generator_lr},
], lr=base_lr)
task = FeatureInversionTask(encoder, generator, latent, critic, optimizer, num_iterations=100)
reconstructed = task(target_features)  # uses latent_lr for the latent module and generator_lr for the generator module
task.reset_states()
reconstructed_2 = task(target_features)  # uses base_lr for both the latent and generator modules

Cannot use a learning rate scheduler

optimizer = optim.SGD(latent.parameters(), lr=0.01)
scheduler = ExponentialLR(optimizer, gamma=0.9)
task = FeatureInversionTask(encoder, generator, latent, critic, optimizer, scheduler=scheduler, num_iterations=100)
task.reset_states()
reconstructed = task(target_features)  # learning rate scheduler does not work properly

Cause

The cause of the problem is in the implementation of reset_states():

def reset_states(self) -> None:
    """Reset the state of the task."""
    self._generator.reset_states()
    self._latent.reset_states()
    self._optimizer = self._optimizer.__class__(
        chain(
            self._generator.parameters(),
            self._latent.parameters(),
        ),
        **self._optimizer.defaults,
    )

Originally this method was implemented based on the following assumptions:

  • We can dynamically provide a consistent way of re-initializing an optimizer from its instance alone (L220-226)
  • Learning rate schedulers do not need to be re-initialized

In reality, neither of these assumptions was true. In addition, since optimizers generally hold references to the generator and latent instances, and learning rate schedulers hold references to the optimizer instance, whenever any of these dependencies is re-instantiated, those references must be replaced accordingly.
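
To make the first assumption concrete, the following is a minimal, self-contained sketch (illustrative code, not taken from this repository) showing how rebuilding an optimizer from only its class and its defaults, the way reset_states() does, silently drops per-group learning rates:

from itertools import chain

import torch.nn as nn
import torch.optim as optim

latent = nn.Linear(4, 4)
generator = nn.Linear(4, 4)

# Optimizer configured with per-module learning rates via param_groups
optimizer = optim.SGD([
    {"params": latent.parameters(), "lr": 0.1},
    {"params": generator.parameters(), "lr": 0.001},
], lr=0.01)
print([g["lr"] for g in optimizer.param_groups])  # [0.1, 0.001]

# Rebuilding it the way reset_states() does keeps only the defaults
rebuilt = optimizer.__class__(
    chain(generator.parameters(), latent.parameters()),
    **optimizer.defaults,
)
print([g["lr"] for g in rebuilt.param_groups])  # [0.01] -- per-group learning rates are lost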

Solution

Instead of receiving optimizer and learning rate scheduler instances themselves, FeatureInversionTask now receives factory methods for creating them. The following is an example use of the newly designed API:

encoder = build_encoder(...)
generator = build_generator(...)
latent = ArbitraryLatent(...)
critic = TargetNormalizedMSE(...)

optimizer_factory = build_optimizer_factory(
    SGD,
    get_params_fn=lambda generator, latent: [
        {"params": latent.parameters(), "lr": latent_lr},
        {"params": generator.parameters(), "lr" generator_lr},
    ],
    lr=base_lr, momentum=0.9
)
scheduler_factory = build_scheduler_factory(ExponentialLR, gamma=0.9)

task = FeatureInversionTask(
    encoder, generator, latent, critic,
    optimizer_factory, scheduler_factory,
    num_iterations=100,
)
reconstructed = task(target_features)

...
## In reset_states() of FeatureInversionTask
optimizer = optimizer_factory(generator, latent)
scheduler = scheduler_factory(optimizer)
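
For reference, a minimal sketch of what these factory helpers could look like (the function names follow this PR, but the bodies below are illustrative assumptions, not the actual implementation):

from typing import Any, Callable

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler  # _LRScheduler on torch < 2.0


def build_optimizer_factory(
    optimizer_class: type[Optimizer],
    get_params_fn: Callable[..., Any],
    **kwargs: Any,
) -> Callable[..., Optimizer]:
    """Return a factory that builds a fresh optimizer from generator/latent."""
    def factory(generator, latent) -> Optimizer:
        return optimizer_class(get_params_fn(generator, latent), **kwargs)
    return factory


def build_scheduler_factory(
    scheduler_class: type[LRScheduler],
    **kwargs: Any,
) -> Callable[[Optimizer], LRScheduler]:
    """Return a factory that builds a fresh scheduler bound to the given optimizer."""
    def factory(optimizer: Optimizer) -> LRScheduler:
        return scheduler_class(optimizer, **kwargs)
    return factory

Because reset_states() calls these factories again, every reset produces a brand-new optimizer (with its original param_groups) and a brand-new scheduler bound to that new optimizer.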

Breaking changes in API

  • FeatureInversionTask takes optimizer_factory: (BaseGenerator, BaseLatent) -> Optimizer instead of optimizer: Optimizer as an input argument
  • FeatureInversionTask takes scheduler_factory: Optimizer -> LRScheduler instead of scheduler: LRScheduler as an input argument

github-actions bot commented Feb 18, 2025

Tests: 109 · Skipped: 0 💤 · Failures: 0 ❌ · Errors: 1 🔥 · Time: 11.401s ⏱️


@KenyaOtsuka KenyaOtsuka left a comment

Looks good to me.

@ganow ganow marked this pull request as ready for review February 19, 2025 04:19
@ganow ganow (Contributor, Author) commented Feb 20, 2025

Based on Otsuka-san's suggestion, I have revised the type definitions as follows.

   build_optimizer_factory: (type[Optimizer], _GetParamsFnType) -> _OptimizerFactoryType
-  _GetParamsFnType: TypeAlias = (BaseGenerator, BaseLatent) -> Iterator[Parameter]
+  _GetParamsFnType: TypeAlias = (BaseGenerator, BaseLatent) -> _ParamsT
   _OptimizerFactoryType: TypeAlias = (BaseGenerator, BaseLatent) -> Optimizer
+  _ParamsT: TypeAlias = Iterable[Tensor] | Iterable[Dict[str, Any]] | Iterable[Tuple[str, Tensor]]
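
Expressed as runnable Python, the revised aliases would read roughly as follows (a sketch based on the diff above; BaseGenerator and BaseLatent stand for the package's own base classes and are written as forward references here):

from typing import Any, Callable, Dict, Iterable, Tuple, Union

try:  # TypeAlias lives in typing from Python 3.10
    from typing import TypeAlias
except ImportError:
    from typing_extensions import TypeAlias

from torch import Tensor
from torch.optim import Optimizer

_ParamsT: TypeAlias = Union[
    Iterable[Tensor], Iterable[Dict[str, Any]], Iterable[Tuple[str, Tensor]]
]
_GetParamsFnType: TypeAlias = Callable[["BaseGenerator", "BaseLatent"], _ParamsT]
_OptimizerFactoryType: TypeAlias = Callable[["BaseGenerator", "BaseLatent"], Optimizer]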

Reasons behind this modification

The previous type annotations were not compatible with uses of build_optimizer_factory like the following:

optimizer_factory = build_optimizer_factory(
    SGD,
    get_params_fn=lambda generator, latent: [
        {"params": latent.parameters(), "lr": latent_lr},
        {"params": generator.parameters(), "lr" generator_lr},
    ],
    lr=base_lr, momentum=0.9
)

Here, get_params_fn returns a list of dictionaries, while our previous type annotation assumed the return type to be an iterable of torch.nn.Parameter. We therefore aligned the type annotation with the one defined in PyTorch.

Why did we redefine the same concept in our codebase instead of just importing it from PyTorch?

I decided to define _ParamsT in our codebase instead of importing the equivalent type from PyTorch:

# NOTE: The definition of `_ParamsT` is the same as in `torch.optim.optimizer`
# in torch>=2.2.0. We define it here for compatibility with older versions.
_ParamsT: TypeAlias = Union[
    Iterable[Tensor], Iterable[Dict[str, Any]], Iterable[Tuple[str, Tensor]]
]

This type was named _params_t before PyTorch v2.2.0 and was renamed to ParamsT in v2.2.0. I decided to redefine it locally because I could not find a consistent way to import it that is independent of the PyTorch version.

Note on the type definition of _ParamsT

Since v2.6.0, PyTorch optimizers can also accept named parameters. In line with this, the definition of ParamsT was changed in PyTorch v2.6.0 as follows:

-  ParamsT: TypeAlias = Iterable[Tensor] | Iterable[Dict[str, Any]]
+  ParamsT: TypeAlias = Iterable[Tensor] | Iterable[Dict[str, Any]] | Iterable[Tuple[str, Tensor]]

This PR uses a type definition that reflects this latest change. Therefore, example code like the following will pass type checking but will raise an exception at runtime if a PyTorch version earlier than v2.6.0 is installed in the environment:

def get_params_fn(generator: BaseGenerator, latent: BaseLatent):
    return chain(generator.named_parameters(), latent.named_parameters())

optimizer_factory = build_optimizer_factory(SGD, get_params_fn=get_params_fn, lr=0.1)

In short, this pattern will fail at runtime in any environment where PyTorch earlier than v2.6.0 is installed. As I could not come up with a type definition that works appropriately across different PyTorch versions, I adopted this implementation as a tentative decision.
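
As a workaround, if compatibility with PyTorch versions earlier than v2.6.0 is needed, callers can pass unnamed parameters instead, which every version accepts. A sketch reusing the example above with parameters() in place of named_parameters():

from itertools import chain

def get_params_fn(generator: BaseGenerator, latent: BaseLatent):
    # Unnamed parameters are accepted by torch.optim optimizers on all versions.
    return chain(generator.parameters(), latent.parameters())

optimizer_factory = build_optimizer_factory(SGD, get_params_fn=get_params_fn, lr=0.1)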
