lightning checkpointing #9
How to save a checkpoint in Lightning?

Note: the analysis below is based on Lightning 2.0.7, so paths and source code may change in later versions. Usually, after training ends (or at the points configured by the callback), a checkpoint is written through the `ModelCheckpoint` callback:
```python
# lightning.pytorch.callbacks.model_checkpoint.py
class ModelCheckpoint(Checkpoint):
    ...

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        trainer.save_checkpoint(filepath, self.save_weights_only)

        self._last_global_step_saved = trainer.global_step

        # notify loggers
        if trainer.is_global_zero:
            for logger in trainer.loggers:
                logger.after_save_checkpoint(proxy(self))
```

In other words, the important call is `Trainer.save_checkpoint`:

```python
# lightning.pytorch.trainer.trainer.py
class Trainer:
    ...

    def save_checkpoint(
        self, filepath: _PATH, weights_only: bool = False, storage_options: Optional[Any] = None
    ) -> None:
        r"""Runs routine to create a checkpoint.

        Args:
            filepath: Path where checkpoint is saved.
            weights_only: If ``True``, will only save the model weights.
            storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin

        Raises:
            AttributeError:
                If the model is not attached to the Trainer before calling this method.
        """
        if self.model is None:
            raise AttributeError(
                "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
                " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
            )
        checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
        self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
        self.strategy.barrier("Trainer.save_checkpoint")
```

Here, `self._checkpoint_connector` is created in `Trainer.__init__`:

```python
# lightning.pytorch.trainer.trainer.py
class Trainer:
    @_defaults_from_env_vars
    def __init__(
        self,
        ...
    ) -> None:
        ...
        self._checkpoint_connector = _CheckpointConnector(self)
        ...

    ...
```
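For context, the save path above is normally driven by attaching a `ModelCheckpoint` callback to the `Trainer`, and `Trainer.save_checkpoint` can also be called manually. A minimal sketch, assuming a `LitModel` LightningModule and `train_loader`/`val_loader` dataloaders (all placeholder names) and that the module logs `val_loss`:

```python
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

# Attaching ModelCheckpoint makes the trainer go through
# ModelCheckpoint._save_checkpoint -> Trainer.save_checkpoint as shown above.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",  # where .ckpt files are written
    monitor="val_loss",      # assumes the module logs "val_loss"
    save_top_k=1,            # keep only the best checkpoint
)

trainer = pl.Trainer(max_epochs=3, callbacks=[checkpoint_callback])
trainer.fit(LitModel(), train_loader, val_loader)

# The same routine can also be triggered manually after fit():
trainer.save_checkpoint("checkpoints/manual.ckpt")
```

The `_CheckpointConnector` created above is what actually assembles the checkpoint dictionary, in `dump_checkpoint`: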
```python
# lightning.pytorch.trainer.connectors.checkpoint_connector.py
class _CheckpointConnector:
    ...

    def dump_checkpoint(self, weights_only: bool = False) -> dict:
        """Creating a model checkpoint dictionary object from various component states.

        Args:
            weights_only: saving model weights only

        Return:
            structured dictionary: {
                'epoch':                     training epoch
                'global_step':               training global step
                'pytorch-lightning_version': The version of PyTorch Lightning that produced this checkpoint
                'callbacks':                 "callback specific state"[] # if not weights_only
                'optimizer_states':          "PT optim's state_dict"[]   # if not weights_only
                'lr_schedulers':             "PT sched's state_dict"[]   # if not weights_only
                'state_dict':                Model's state_dict (e.g. network weights)
                precision_plugin.__class__.__qualname__: precision plugin state_dict # if not weights_only
                CHECKPOINT_HYPER_PARAMS_NAME:
                CHECKPOINT_HYPER_PARAMS_KEY:
                CHECKPOINT_HYPER_PARAMS_TYPE:
                something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
                LightningDataModule.__class__.__qualname__: pl DataModule's state
            }
        """
        trainer = self.trainer
        model = trainer.lightning_module
        datamodule = trainer.datamodule

        checkpoint = {
            # the epoch and global step are saved for compatibility but they are not relevant for restoration
            "epoch": trainer.current_epoch,
            "global_step": trainer.global_step,
            "pytorch-lightning_version": pl.__version__,
            "state_dict": self._get_lightning_module_state_dict(),
            "loops": self._get_loops_state_dict(),
        }

        if not weights_only:
            # dump callbacks
            checkpoint["callbacks"] = call._call_callbacks_state_dict(trainer)

            optimizer_states = []
            for i, optimizer in enumerate(trainer.optimizers):
                # Rely on accelerator to dump optimizer state
                optimizer_state = trainer.strategy.optimizer_state(optimizer)
                optimizer_states.append(optimizer_state)
            checkpoint["optimizer_states"] = optimizer_states

            # dump lr schedulers
            lr_schedulers = []
            for config in trainer.lr_scheduler_configs:
                lr_schedulers.append(config.scheduler.state_dict())
            checkpoint["lr_schedulers"] = lr_schedulers

            # precision plugin
            prec_plugin = trainer.precision_plugin
            prec_plugin_state_dict = prec_plugin.state_dict()
            if prec_plugin_state_dict:
                checkpoint[prec_plugin.__class__.__qualname__] = prec_plugin_state_dict
            prec_plugin.on_save_checkpoint(checkpoint)

        # dump hyper-parameters
        for obj in (model, datamodule):
            if obj and obj.hparams:
                if hasattr(obj, "_hparams_name"):
                    checkpoint[obj.CHECKPOINT_HYPER_PARAMS_NAME] = obj._hparams_name
                # dump arguments
                if _OMEGACONF_AVAILABLE and isinstance(obj.hparams, Container):
                    checkpoint[obj.CHECKPOINT_HYPER_PARAMS_KEY] = obj.hparams
                    checkpoint[obj.CHECKPOINT_HYPER_PARAMS_TYPE] = type(obj.hparams)
                else:
                    checkpoint[obj.CHECKPOINT_HYPER_PARAMS_KEY] = dict(obj.hparams)

        # dump stateful datamodule
        if datamodule is not None:
            datamodule_state_dict = call._call_lightning_datamodule_hook(trainer, "state_dict")
            if datamodule_state_dict:
                checkpoint[datamodule.__class__.__qualname__] = datamodule_state_dict

        # on_save_checkpoint hooks
        if not weights_only:
            # if state is returned from callback's on_save_checkpoint
            # it overrides the returned state from callback's state_dict
            # support for returning state in on_save_checkpoint
            # will be removed in v1.8
            call._call_callbacks_on_save_checkpoint(trainer, checkpoint)
        call._call_lightning_module_hook(trainer, "on_save_checkpoint", checkpoint)

        return checkpoint
```

Analyzing the code above, the following could be confirmed:
- Based on the analysis so far, …
- As for how to implement this in hf-style, …
- Which approach to use for the implementation is …
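For reference, the structure documented in `dump_checkpoint` can be checked directly: a Lightning `.ckpt` file is a `torch`-serialized dictionary whose top-level keys mirror the docstring above. A minimal sketch, assuming the placeholder path `checkpoints/manual.ckpt` from the earlier sketch:

```python
import torch

# Load the serialized checkpoint dictionary onto the CPU.
ckpt = torch.load("checkpoints/manual.ckpt", map_location="cpu")

# Expected top-level keys per dump_checkpoint: 'epoch', 'global_step',
# 'pytorch-lightning_version', 'state_dict', 'loops', and (when not
# weights_only) 'callbacks', 'optimizer_states', 'lr_schedulers',
# plus hyper-parameter and datamodule entries.
print(sorted(ckpt.keys()))
print(ckpt["pytorch-lightning_version"], ckpt["epoch"], ckpt["global_step"])
```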
How to load a checkpoint in Lightning?

I asked ChatGPT about the ways to load a checkpoint in Lightning:
```python
# Load a model directly from a checkpoint
model = MyModel.load_from_checkpoint(checkpoint_path="path/to/checkpoint.ckpt")

# Resume training from a checkpoint via the Trainer
trainer = pl.Trainer(resume_from_checkpoint="path/to/checkpoint.ckpt")
trainer.fit(model, dataloader)

# Use a ModelCheckpoint callback during training
checkpoint_callback = pl.callbacks.ModelCheckpoint(...)
trainer = pl.Trainer(callbacks=[checkpoint_callback])
trainer.fit(model, dataloader)

# load the best checkpoint later
best_model_path = checkpoint_callback.best_model_path
model = MyModel.load_from_checkpoint(best_model_path)

# Load the raw state_dict manually
checkpoint = torch.load("path/to/checkpoint.ckpt")
model.load_state_dict(checkpoint["state_dict"])
```

When I asked ChatGPT again, it said these are methods that work in PL 1.x, so I plan to check whether they can still be used in 2.x and to go through the docs as well.
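One thing that is already clear for 2.x: the `resume_from_checkpoint` argument of `Trainer` was removed in Lightning 2.0, and resuming is now done by passing `ckpt_path` to `Trainer.fit()` (or `validate`/`test`/`predict`). A minimal sketch, reusing the placeholder `MyModel`/`dataloader` names from above:

```python
import lightning.pytorch as pl

model = MyModel()
trainer = pl.Trainer(max_epochs=10)

# Lightning 2.x: restore weights, optimizer states, and loop progress by
# passing ckpt_path to fit(), instead of Trainer(resume_from_checkpoint=...).
trainer.fit(model, train_dataloaders=dataloader, ckpt_path="path/to/checkpoint.ckpt")
```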
First, it looks like the change in Issue 9006 (referenced in item 3) was made in order to minimize the `Trainer`'s arguments. To summarize the above again: 1. naively, this repo …