diff --git a/classy_vision/losses/classy_loss.py b/classy_vision/losses/classy_loss.py
index 94f9a26b21..1be8afa7c4 100644
--- a/classy_vision/losses/classy_loss.py
+++ b/classy_vision/losses/classy_loss.py
@@ -42,3 +42,44 @@ def forward(self, output, target):
         Refer to :class:`torch.nn.Module` for more details.
         """
         raise NotImplementedError
+
+    def get_optimizer_params(self, bn_weight_decay=False):
+        """Gets optimizer params.
+
+        The default implementation returns every parameter that requires a
+        gradient as a regularized param; most losses have no learned parameters.
+        """
+        params = [
+            param for param in self.parameters(recurse=True) if param.requires_grad
+        ]
+        return {"regularized_params": params, "unregularized_params": []}
+
+    def get_classy_state(self) -> Dict[str, Any]:
+        """Get the state of the ClassyLoss.
+
+        The returned state is used for checkpointing. Note that most losses are
+        stateless and do not need to save any state.
+
+        Returns:
+            A state dictionary containing the state of the loss.
+        """
+        return self.state_dict()
+
+    def set_classy_state(self, state: Dict[str, Any]) -> None:
+        """Set the state of the ClassyLoss.
+
+        Args:
+            state: The state dictionary. Must be the output of a call to
+                :func:`get_classy_state`.
+
+        This is used to load the state of the loss from a checkpoint. Note
+        that most losses are stateless and do not need to load any state.
+        """
+        return self.load_state_dict(state)
+
+    def has_learned_parameters(self) -> bool:
+        """Does this loss have learned parameters?"""
+        for _, params in self.get_optimizer_params().items():
+            if len(params) > 0:
+                return True
+        return False
diff --git a/classy_vision/optim/adam.py b/classy_vision/optim/adam.py
index 211f2fd552..aededf44ab 100644
--- a/classy_vision/optim/adam.py
+++ b/classy_vision/optim/adam.py
@@ -29,8 +29,8 @@ def __init__(
         self.parameters.weight_decay = weight_decay
         self.parameters.amsgrad = amsgrad
 
-    def init_pytorch_optimizer(self, model) -> None:
-        super().init_pytorch_optimizer(model)
+    def init_pytorch_optimizer(self, model, **kwargs) -> None:
+        super().init_pytorch_optimizer(model, **kwargs)
         self.optimizer = torch.optim.Adam(
             self.param_groups_override,
             lr=self.parameters.lr,
diff --git a/classy_vision/optim/classy_optimizer.py b/classy_vision/optim/classy_optimizer.py
index fa95aebc5f..1f9bca2a50 100644
--- a/classy_vision/optim/classy_optimizer.py
+++ b/classy_vision/optim/classy_optimizer.py
@@ -4,9 +4,10 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Union
 
 import torch
+from classy_vision.losses import ClassyLoss
 from classy_vision.models import ClassyModel
 
 from .param_scheduler import ClassyParamScheduler, UpdateInterval
@@ -61,24 +62,9 @@ def set_param_schedulers(
         )
         return self
 
-    def _validate_and_get_optimizer_params(self, model: ClassyModel) -> Dict[str, Any]:
-        """
-        Validate and return the optimizer params.
-
-        The optimizer params are fetched from
-        :fun:`models.ClassyModel.get_optimizer_params`.
-
-        Args:
-            model: The model to get the params from.
-
-        Returns:
-            A dict containing "regularized_params" and "unregularized_params".
-            Weight decay will only be applied to "regularized_params".
-        """
-        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-            optimizer_params = model.module.get_optimizer_params()
-        else:
-            optimizer_params = model.get_optimizer_params()
+    @staticmethod
+    def _validate_optimizer_params(model: Union[ClassyLoss, ClassyModel]):
+        optimizer_params = model.get_optimizer_params()
 
         assert isinstance(optimizer_params, dict) and set(optimizer_params.keys()) == {
             "regularized_params",
@@ -100,6 +86,39 @@ def _validate_and_get_optimizer_params(self, model: ClassyModel) -> Dict[str, An
 
         return optimizer_params
 
+    def _validate_and_get_optimizer_params(
+        self, model: ClassyModel, loss: Optional[Union[ClassyLoss, Any]] = None
+    ) -> Dict[str, Any]:
+        """
+        Validate and return the optimizer params.
+
+        The optimizer params are fetched from
+        :func:`models.ClassyModel.get_optimizer_params`.
+
+        Args:
+            model: The model to get the params from.
+            loss: The loss. If present and a ClassyLoss, the loss may also
+                contribute parameters.
+
+        Returns:
+            A dict containing "regularized_params" and "unregularized_params".
+            Weight decay will only be applied to "regularized_params".
+        """
+        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+            model = model.module
+
+        optimizer_params = self._validate_optimizer_params(model)
+
+        if loss is not None and isinstance(loss, ClassyLoss):
+            loss_params = self._validate_optimizer_params(loss)
+            # Merge loss and model params.
+            optimizer_params = {
+                key: value + loss_params[key]
+                for (key, value) in optimizer_params.items()
+            }
+
+        return optimizer_params
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "ClassyOptimizer":
         """Instantiates a ClassyOptimizer from a configuration.
@@ -112,7 +131,9 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassyOptimizer":
         """
         raise NotImplementedError
 
-    def init_pytorch_optimizer(self, model: ClassyModel) -> None:
+    def init_pytorch_optimizer(
+        self, model: ClassyModel, loss: Optional[Union[ClassyLoss, Any]] = None
+    ) -> None:
         """
         Initialize the underlying :class:`torch.optim.Optimizer` instance.
 
@@ -129,7 +150,7 @@ def init_pytorch_optimizer(self, model: ClassyModel) -> None:
         This should called only after the model has been moved to the correct
         device.
         """
-        self.optimizer_params = self._validate_and_get_optimizer_params(model)
+        self.optimizer_params = self._validate_and_get_optimizer_params(model, loss)
 
         param_groups_override = []
         self.contains_unregularized_params = False
diff --git a/classy_vision/optim/rmsprop.py b/classy_vision/optim/rmsprop.py
index 23627bfe3d..a47d01bea1 100644
--- a/classy_vision/optim/rmsprop.py
+++ b/classy_vision/optim/rmsprop.py
@@ -32,8 +32,8 @@ def __init__(
         self.parameters.eps = eps
         self.parameters.centered = centered
 
-    def init_pytorch_optimizer(self, model):
-        super().init_pytorch_optimizer(model)
+    def init_pytorch_optimizer(self, model, **kwargs):
+        super().init_pytorch_optimizer(model, **kwargs)
         self.optimizer = torch.optim.RMSprop(
             self.param_groups_override,
             lr=self.parameters.lr,
diff --git a/classy_vision/optim/rmsprop_tf.py b/classy_vision/optim/rmsprop_tf.py
index 8b08edfc17..640ee5f6ac 100644
--- a/classy_vision/optim/rmsprop_tf.py
+++ b/classy_vision/optim/rmsprop_tf.py
@@ -168,8 +168,8 @@ def __init__(
         self.parameters.eps = eps
         self.parameters.centered = centered
 
-    def init_pytorch_optimizer(self, model):
-        super().init_pytorch_optimizer(model)
+    def init_pytorch_optimizer(self, model, **kwargs):
+        super().init_pytorch_optimizer(model, **kwargs)
         self.optimizer = RMSpropTFOptimizer(
             self.param_groups_override,
             lr=self.parameters.lr,
diff --git a/classy_vision/optim/sgd.py b/classy_vision/optim/sgd.py
index 51a9ff2c7f..5f9ddd96b9 100644
--- a/classy_vision/optim/sgd.py
+++ b/classy_vision/optim/sgd.py
@@ -27,8 +27,8 @@ def __init__(
         self.parameters.weight_decay = weight_decay
         self.parameters.nesterov = nesterov
 
-    def init_pytorch_optimizer(self, model):
-        super().init_pytorch_optimizer(model)
+    def init_pytorch_optimizer(self, model, **kwargs):
+        super().init_pytorch_optimizer(model, **kwargs)
         self.optimizer = torch.optim.SGD(
             self.param_groups_override,
             lr=self.parameters.lr,
diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py
index 985618f9fb..a4d4e374f7 100644
--- a/classy_vision/tasks/classification_task.py
+++ b/classy_vision/tasks/classification_task.py
@@ -478,15 +478,14 @@ def prepare(
 
         # move the model and loss to the right device
         if use_gpu:
-            self.loss.cuda()
-            self.base_model = copy_model_to_gpu(self.base_model)
+            self.base_model, self.loss = copy_model_to_gpu(self.base_model, self.loss)
         else:
             self.loss.cpu()
             self.base_model.cpu()
 
         # initialize the pytorch optimizer now since the model has been moved to
         # the appropriate device
-        self.optimizer.init_pytorch_optimizer(self.base_model)
+        self.optimizer.init_pytorch_optimizer(self.base_model, loss=self.loss)
 
         classy_state_dict = (
             None
@@ -528,6 +527,11 @@ def init_distributed_data_parallel_model(self):
         self.distributed_model = init_distributed_data_parallel_model(
             self.base_model, broadcast_buffers=broadcast_buffers
         )
+        if isinstance(self.loss, ClassyLoss) and self.loss.has_learned_parameters():
+            logging.info("Initializing distributed loss")
+            self.loss = init_distributed_data_parallel_model(
+                self.loss, broadcast_buffers=broadcast_buffers
+            )
 
     @property
     def where(self):
@@ -569,7 +573,10 @@ def get_classy_state(self, deep_copy: bool = False):
             "num_updates": self.num_updates,
             "losses": self.losses,
             "hooks": {hook.name(): hook.get_classy_state() for hook in self.hooks},
+            "loss": {},
         }
+        if isinstance(self.loss, ClassyLoss):
+            classy_state_dict["loss"] = self.loss.get_classy_state()
         if deep_copy:
             classy_state_dict = copy.deepcopy(classy_state_dict)
         return classy_state_dict
@@ -592,6 +599,9 @@ def set_classy_state(self, state):
 
         self.base_model.set_classy_state(state["base_model"])
         self.optimizer.set_classy_state(state["optimizer"])
+        if state.get("loss") and isinstance(self.loss, ClassyLoss):
+            self.loss.set_classy_state(state["loss"])
+
         for hook in self.hooks:
             # we still want to be able to run when new hooks are added or old
             # hooks are removed
@@ -821,6 +831,7 @@ def _set_model_train_mode(self):
         """
         phase = self.phases[self.phase_idx]
         self.base_model.train(phase["train"])
+        self.loss.train(phase["train"])
 
         if (
             self.broadcast_buffers_mode == BroadcastBuffersMode.BEFORE_EVAL
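
For context, the sketch below walks through the new ClassyLoss hooks introduced in this diff. It is a hypothetical example, not part of the patch: the LearnedScaleLoss class and its "scale" parameter are invented for illustration, and only the ClassyLoss methods and the loss= argument shown above are assumed.

# Hypothetical usage sketch (not part of the patch); the loss class and its
# "scale" parameter are invented for illustration.
import torch
import torch.nn as nn

from classy_vision.losses import ClassyLoss


class LearnedScaleLoss(ClassyLoss):
    """Toy loss with a single learned parameter."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, output, target):
        return nn.functional.cross_entropy(self.scale * output, target)


loss = LearnedScaleLoss()

# The default get_optimizer_params() returns every requires_grad parameter as
# a regularized param, so this loss reports learned parameters.
params = loss.get_optimizer_params()
assert len(params["regularized_params"]) == 1
assert loss.has_learned_parameters()

# Checkpointing round-trips through the new state hooks.
loss.set_classy_state(loss.get_classy_state())

# When such a loss is used in a task, it is handed to the optimizer so its
# parameters are merged with the model's, e.g.:
#     optimizer.init_pytorch_optimizer(base_model, loss=loss)

Losses without learned parameters keep their previous behavior: the merged loss param lists are empty, has_learned_parameters() returns False, and the task does not wrap the loss in DistributedDataParallel.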