lightning checkpointing #9

Open
jinmang2 opened this issue Aug 18, 2023 · 5 comments

@jinmang2
Owner

Goal

  • Dig through the lightning source code and understand how save/load checkpoint works
  • Implement it hf-style
@jinmang2 jinmang2 self-assigned this Aug 18, 2023
@jinmang2
Owner Author

jinmang2 commented Aug 18, 2023

How to save checkpoint in lightning?

Note: the code analysis below is based on lightning version 2.0.7, so paths and source code may change in later versions.

Checkpoints are usually saved either by calling pl.Trainer's save_checkpoint method after training finishes, or by registering pl.callbacks.ModelCheckpoint in callbacks.
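For reference, a typical setup combining the two looks roughly like this (a minimal sketch; the monitor key, dirpath, filename pattern, and the model / train_loader objects are placeholders of my own, not from this issue):

import lightning.pytorch as pl

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="checkpoints/",
    filename="{epoch}-{val_loss:.2f}",
    monitor="val_loss",   # a metric logged by the LightningModule
    save_top_k=2,         # keep the two best checkpoints
    save_last=True,       # also keep a rolling "last" checkpoint
)

trainer = pl.Trainer(max_epochs=3, callbacks=[checkpoint_callback])
trainer.fit(model, train_dataloaders=train_loader)

# or save once explicitly after training is done
trainer.save_checkpoint("checkpoints/final.ckpt")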

pl.callbacks.ModelCheckpoint can save a checkpoint at any point, driven by its configured arguments (last/non-monitor/update-best, etc.), but every one of those paths goes through the method below.

# lightning.pytorch.callbacks.model_checkpoint.py
class ModelCheckpoint(Checkpoint):
...
    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        trainer.save_checkpoint(filepath, self.save_weights_only)

        self._last_global_step_saved = trainer.global_step

        # notify loggers
        if trainer.is_global_zero:
            for logger in trainer.loggers:
                logger.after_save_checkpoint(proxy(self))

So the important piece is pl.Trainer's save_checkpoint method. Digging into that as well:

# lightning.pytorch.trainer.trainer.py
class Trainer:
...
    def save_checkpoint(
        self, filepath: _PATH, weights_only: bool = False, storage_options: Optional[Any] = None
    ) -> None:
        r"""Runs routine to create a checkpoint.

        Args:
            filepath: Path where checkpoint is saved.
            weights_only: If ``True``, will only save the model weights.
            storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin

        Raises:
            AttributeError:
                If the model is not attached to the Trainer before calling this method.

        """
        if self.model is None:
            raise AttributeError(
                "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
                " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
            )
        checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
        self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
        self.strategy.barrier("Trainer.save_checkpoint")

Here self.strategy refers to the distributed training strategy (ddp, etc.), and the checkpoint dict itself is actually built by self._checkpoint_connector's dump_checkpoint method. Extracting and analyzing only the important parts of that code:

# lightning.pytorch.trainer.trainer.py
class Trainer:
    @_defaults_from_env_vars
    def __init__(
        self,
        ...
    ) -> None:
        ...
        self._checkpoint_connector = _CheckpointConnector(self)
        ...
...
# lightning.pytorch.trainer.connectors.checkpoint_connector.py
class _CheckpointConnector:
    ...
    def dump_checkpoint(self, weights_only: bool = False) -> dict:
        """Creating a model checkpoint dictionary object from various component states.

        Args:
            weights_only: saving model weights only
        Return:
            structured dictionary: {
                'epoch':                     training epoch
                'global_step':               training global step
                'pytorch-lightning_version': The version of PyTorch Lightning that produced this checkpoint
                'callbacks':                 "callback specific state"[] # if not weights_only
                'optimizer_states':          "PT optim's state_dict"[]   # if not weights_only
                'lr_schedulers':             "PT sched's state_dict"[]   # if not weights_only
                'state_dict':                Model's state_dict (e.g. network weights)
                precision_plugin.__class__.__qualname__:  precision plugin state_dict # if not weights_only
                CHECKPOINT_HYPER_PARAMS_NAME:
                CHECKPOINT_HYPER_PARAMS_KEY:
                CHECKPOINT_HYPER_PARAMS_TYPE:
                something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
                LightningDataModule.__class__.__qualname__: pl DataModule's state
            }

        """
        trainer = self.trainer
        model = trainer.lightning_module
        datamodule = trainer.datamodule

        checkpoint = {
            # the epoch and global step are saved for compatibility but they are not relevant for restoration
            "epoch": trainer.current_epoch,
            "global_step": trainer.global_step,
            "pytorch-lightning_version": pl.__version__,
            "state_dict": self._get_lightning_module_state_dict(),
            "loops": self._get_loops_state_dict(),
        }

        if not weights_only:
            # dump callbacks
            checkpoint["callbacks"] = call._call_callbacks_state_dict(trainer)

            optimizer_states = []
            for i, optimizer in enumerate(trainer.optimizers):
                # Rely on accelerator to dump optimizer state
                optimizer_state = trainer.strategy.optimizer_state(optimizer)
                optimizer_states.append(optimizer_state)

            checkpoint["optimizer_states"] = optimizer_states

            # dump lr schedulers
            lr_schedulers = []
            for config in trainer.lr_scheduler_configs:
                lr_schedulers.append(config.scheduler.state_dict())
            checkpoint["lr_schedulers"] = lr_schedulers

            # precision plugin
            prec_plugin = trainer.precision_plugin
            prec_plugin_state_dict = prec_plugin.state_dict()
            if prec_plugin_state_dict:
                checkpoint[prec_plugin.__class__.__qualname__] = prec_plugin_state_dict
            prec_plugin.on_save_checkpoint(checkpoint)

        # dump hyper-parameters
        for obj in (model, datamodule):
            if obj and obj.hparams:
                if hasattr(obj, "_hparams_name"):
                    checkpoint[obj.CHECKPOINT_HYPER_PARAMS_NAME] = obj._hparams_name
                # dump arguments
                if _OMEGACONF_AVAILABLE and isinstance(obj.hparams, Container):
                    checkpoint[obj.CHECKPOINT_HYPER_PARAMS_KEY] = obj.hparams
                    checkpoint[obj.CHECKPOINT_HYPER_PARAMS_TYPE] = type(obj.hparams)
                else:
                    checkpoint[obj.CHECKPOINT_HYPER_PARAMS_KEY] = dict(obj.hparams)

        # dump stateful datamodule
        if datamodule is not None:
            datamodule_state_dict = call._call_lightning_datamodule_hook(trainer, "state_dict")
            if datamodule_state_dict:
                checkpoint[datamodule.__class__.__qualname__] = datamodule_state_dict

        # on_save_checkpoint hooks
        if not weights_only:
            # if state is returned from callback's on_save_checkpoint
            # it overrides the returned state from callback's state_dict
            # support for returning state in on_save_checkpoint
            # will be removed in v1.8
            call._call_callbacks_on_save_checkpoint(trainer, checkpoint)
        call._call_lightning_module_hook(trainer, "on_save_checkpoint", checkpoint)
        return checkpoint

Analyzing the code above, the following facts can be confirmed.

  1. At the very top of the function the checkpoint dict is created with epoch, global_step, pytorch-lightning_version, state_dict, and loops.
  2. When the weights_only option is False, callbacks, optimizer_states, lr_schedulers, the precision plugin state, etc. are also added to the checkpoint dict.
  3. The hparams of pl.LightningModule and pl.LightningDataModule are likewise recorded under each object's CHECKPOINT_HYPER_PARAMS_{}.format(NAME | KEY | TYPE) keys.
  4. _call_callbacks_on_save_checkpoint in lightning.pytorch.trainer.call.py runs the on_save_checkpoint method of every checkpoint callback registered on the trainer.
  5. _call_lightning_module_hook in lightning.pytorch.trainer.call.py runs your custom on_save_checkpoint if the LightningModule overrides that method (inherited from CheckpointHooks) to write extra state into the checkpoint (see the sketch after this list).
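To illustrate point 5, a minimal sketch of such an override (the extra key and attribute below are made up for illustration):

import lightning.pytorch as pl


class MyModule(pl.LightningModule):
    def on_save_checkpoint(self, checkpoint: dict) -> None:
        # runs at the very end of dump_checkpoint; anything added here
        # becomes a top-level key in the saved ckpt dict
        checkpoint["something_cool_i_want_to_save"] = {"custom": "state"}

    def on_load_checkpoint(self, checkpoint: dict) -> None:
        # the symmetric hook on the load side
        self.custom_state = checkpoint.get("something_cool_i_want_to_save")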

From the analysis so far:

  1. The ckpt's state_dict is obtained through _get_lightning_module_state_dict, a method _CheckpointConnector provides out of the box.
    • _CheckpointConnector cannot be passed as an argument to pl.Trainer; you have to subclass and swap it in the constructor (see the sketch after this list).
  2. The on_save_checkpoint defined on the LightningModule runs last of all; similarly, the callbacks' on_save_checkpoint hooks also run at the end.
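A minimal sketch of that subclassing idea, assuming _CheckpointConnector keeps its _CheckpointConnector(trainer) constructor as shown above (the class names here are hypothetical):

import lightning.pytorch as pl
from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector


class HFCheckpointConnector(_CheckpointConnector):
    def dump_checkpoint(self, weights_only: bool = False) -> dict:
        checkpoint = super().dump_checkpoint(weights_only)
        # post-process the dict here (e.g. strip/rename keys for hf-style saving)
        return checkpoint


class HuggingfaceTrainer(pl.Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # swap the private connector after the parent constructor has built it
        self._checkpoint_connector = HFCheckpointConnector(self)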

Possible ways to implement this hf-style:

  1. Create a HuggingfaceTrainer that subclasses pl.Trainer and inject, in the constructor, a _CheckpointConnector object whose dump_checkpoint has been modified.
  2. Wrap pl.Trainer's self._checkpoint_connector.dump_checkpoint so that it performs the desired behavior.
  3. In the LightningModule's on_save_checkpoint method, pop the state_dict from the existing checkpoint and call self.model.save_pretrained, passing only the path into the checkpoint (sketched below).

Which approach to take depends on how load_checkpoint works.
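A rough sketch of option 3, assuming self.model is a transformers PreTrainedModel (the directory layout and the hf_pretrained_path key are my own invention):

import os

import lightning.pytorch as pl


class HFModule(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model  # e.g. a transformers PreTrainedModel

    def on_save_checkpoint(self, checkpoint: dict) -> None:
        # drop the raw weights from the lightning ckpt ...
        checkpoint.pop("state_dict", None)
        # ... and save them hf-style next to it, keeping only the path
        save_dir = os.path.join(self.trainer.default_root_dir, "hf_checkpoint")
        self.model.save_pretrained(save_dir)
        checkpoint["hf_pretrained_path"] = save_dir

Note that popping state_dict means the stock load path would have to be adjusted accordingly, which is exactly why the decision hinges on how load_checkpoint works.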

@jinmang2
Owner Author

jinmang2 commented Aug 18, 2023

How to load checkpoint in lightning?

I asked chatgpt about the ways to load a checkpoint in lightning.

  1. Use the load_from_checkpoint classmethod
  • A classmethod provided by LightningModule; you can load the model by pointing it directly at a checkpoint file path.
model = MyModel.load_from_checkpoint(checkpoint_path="path/to/checkpoint.ckpt")
  2. Use Trainer's resume_from_checkpoint parameter
  • To resume training, pass the checkpoint file path via the resume_from_checkpoint parameter when creating the Trainer.
trainer = pl.Trainer(resume_from_checkpoint="path/to/checkpoint.ckpt")
trainer.fit(model, dataloader)
  3. Use the ModelCheckpoint callback
  • The ModelCheckpoint callback automatically manages the best / most recent checkpoints, which makes them easy to load later.
checkpoint_callback = pl.callbacks.ModelCheckpoint(...)
trainer = pl.Trainer(callbacks=[checkpoint_callback])
trainer.fit(model, dataloader)

# load the best checkpoint later
best_model_path = checkpoint_callback.best_model_path
model = MyModel.load_from_checkpoint(best_model_path)
  4. Use torch.load manually
  • PyTorch Lightning checkpoints are saved as plain dictionaries, so you can also load them manually with PyTorch's torch.load if needed.
checkpoint = torch.load("path/to/checkpoint.ckpt")
model.load_state_dict(checkpoint['state_dict'])

When asked again, chatgpt noted that these are the approaches that work in pl 1.x, so I want to check whether they still apply in 2.x and dig further into the docs.

@jinmang2
Owner Author

First, in the case of resume_from_checkpoint, lightning issue 9501 shows that it was deprecated and replaced by the ckpt_path argument of fit.

The change seems to have been made in issue 9006 in order to minimize the number of Trainer arguments.

To reorganize the above, three approaches are available: 1. use torch.load naively and write the loading yourself, 2. use LightningModule's load_from_checkpoint method, 3. pass the ckpt_path argument to pl.Trainer's fit method.

Since one goal of this repo is also to build the ability to analyze and use lightning, I will avoid implementing it directly with torch.load and instead dig into how options 2 and 3 work in the source code.

@jinmang2
Owner Author

jinmang2 commented Aug 18, 2023

load_from_checkpoint

pl.LightningModule's load_from_checkpoint classmethod is very simple.

# lightning.pytorch.core.module.py
...
from typing import cast
...
from typing_extensions import Self
...
from lightning.pytorch.core.saving import _load_from_checkpoint
...

class LightningModule(
    _DeviceDtypeModuleMixin,
    HyperparametersMixin,
    ModelHooks,
    DataHooks,
    CheckpointHooks,
    Module,
):
    ...
    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path: Union[_PATH, IO],
        map_location: _MAP_LOCATION_TYPE = None,
        hparams_file: Optional[_PATH] = None,
        strict: bool = True,
        **kwargs: Any,
    ) -> Self:
        loaded = _load_from_checkpoint(
            cls,
            checkpoint_path,
            map_location,
            hparams_file,
            strict,
            **kwargs,
        )
        return cast(Self, loaded)

_load_from_checkpoint looks like this:

# lightning.pytorch.core.saving.py
def _load_from_checkpoint(
    cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
    checkpoint_path: Union[_PATH, IO],
    map_location: _MAP_LOCATION_TYPE = None,
    hparams_file: Optional[_PATH] = None,
    strict: Optional[bool] = None,
    **kwargs: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
    map_location = map_location or _default_map_location
    with pl_legacy_patch():
        checkpoint = pl_load(checkpoint_path, map_location=map_location)

    # convert legacy checkpoints to the new format
    checkpoint = _pl_migrate_checkpoint(
        checkpoint, checkpoint_path=(checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None)
    )

    if hparams_file is not None:
        extension = str(hparams_file).split(".")[-1]
        if extension.lower() == "csv":
            hparams = load_hparams_from_tags_csv(hparams_file)
        elif extension.lower() in ("yml", "yaml"):
            hparams = load_hparams_from_yaml(hparams_file)
        else:
            raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")

        # overwrite hparams by the given file
        checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

    # TODO: make this a migration:
    # for past checkpoint need to add the new key
    checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {})
    # override the hparams with values that were passed in
    checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)

    if issubclass(cls, pl.LightningDataModule):
        return _load_state(cls, checkpoint, **kwargs)
    if issubclass(cls, pl.LightningModule):
        model = _load_state(cls, checkpoint, strict=strict, **kwargs)
        state_dict = checkpoint["state_dict"]
        if not state_dict:
            rank_zero_warn(f"The state dict in {checkpoint_path!r} contains no parameters.")
            return model

        device = next((t for t in state_dict.values() if isinstance(t, torch.Tensor)), torch.tensor(0)).device
        assert isinstance(model, pl.LightningModule)
        return model.to(device)

    raise NotImplementedError(f"Unsupported {cls}")

The code looks long, but it is actually simple: the first line loads the checkpoint with pl_load (inside pl_legacy_patch, a patch that lets checkpoints saved by older versions still be loaded in the current one); _pl_migrate_checkpoint then converts a legacy ckpt to the new format; if an hparams_file was passed in, it is read and added to the checkpoint; the default key is set; and finally _load_state instantiates the datamodule/lightningmodule and returns it.

In other words, pl_load on the first line is the most important part (it is where the checkpoint is actually loaded).

Let's analyze them one by one. First, let's check how pl_load loads the checkpoint file.

# lightning.fabric.utilities.cloud_io.py
def _load(
    path_or_url: Union[IO, _PATH],
    map_location: _MAP_LOCATION_TYPE = None,
) -> Any:
    """Loads a checkpoint.

    Args:
        path_or_url: Path or URL of the checkpoint.
        map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations.

    """
    if not isinstance(path_or_url, (str, Path)):
        # any sort of BytesIO or similar
        return torch.load(
            path_or_url,
            map_location=map_location,  # type: ignore[arg-type] # upstream annotation is not correct
        )
    if str(path_or_url).startswith("http"):
        return torch.hub.load_state_dict_from_url(
            str(path_or_url),
            map_location=map_location,  # type: ignore[arg-type]
        )
    fs = get_filesystem(path_or_url)
    with fs.open(path_or_url, "rb") as f:
        return torch.load(f, map_location=map_location)  # type: ignore[arg-type]

torch.load accepts as f either a file-like object (implementing the read, readline, tell, and seek methods) or a str/os.PathLike object (containing the file name). pl_load passes a file-like object straight to torch.load; if the path starts with http it fetches the state_dict from the hub via torch.hub.load_state_dict_from_url. Otherwise it uses fsspec's url_to_fs function to determine the protocol, gets back a filesystem class that can act as a context manager, opens the file at that path, and calls torch.load on it to get the state_dict. pl_legacy_patch, as mentioned above, is a context manager that temporarily registers legacy artifacts that exist in old checkpoints but are no longer used, and unregisters them once loading is done.
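In practice this means the same call works for local paths, URLs, and fsspec-backed remote stores alike. A quick sketch (the paths are made up, the s3 case assumes an fsspec backend such as s3fs is installed, and pl_load in the text refers to this _load):

from lightning.fabric.utilities.cloud_io import _load as pl_load

ckpt = pl_load("lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt")  # local file
ckpt = pl_load("https://example.com/checkpoints/model.ckpt")                  # torch.hub download
ckpt = pl_load("s3://my-bucket/checkpoints/model.ckpt", map_location="cpu")   # remote via fsspec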

As mentioned, _load is nothing more and nothing less than a call to torch.load, and while pl_legacy_patch lets checkpoints from older versions be loaded without problems, they still have to be brought into the format of the version currently in use. The migrate_checkpoint machinery takes care of this (see lightning.pytorch.utilities.migration.utils.py).

Once the checkpoint file has been loaded into memory with torch.load, the parts the LightningModule needs, such as the state_dict, have to be loaded from it. Lightning handles this in the _load_state function, which works as follows.

# lightning.pytorch.core.saving.py
def _load_state(
    cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
    checkpoint: Dict[str, Any],
    strict: Optional[bool] = None,
    **cls_kwargs_new: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
    cls_spec = inspect.getfullargspec(cls.__init__)
    cls_init_args_name = inspect.signature(cls.__init__).parameters.keys()

    self_var, args_var, kwargs_var = parse_class_init_keys(cls)
    drop_names = [n for n in (self_var, args_var, kwargs_var) if n]
    cls_init_args_name = list(filter(lambda n: n not in drop_names, cls_init_args_name))

    cls_kwargs_loaded = {}
    # pass in the values we saved automatically
    if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
        if issubclass(cls, pl.LightningModule):
            # TODO: make this a migration:
            # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
            for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
                cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))

        # 2. Try to restore model hparams from checkpoint using the new key
        cls_kwargs_loaded.update(checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_KEY, {}))

        # 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace
        cls_kwargs_loaded = _convert_loaded_hparams(cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))

        # 4. Update cls_kwargs_new with cls_kwargs_old, such that new has higher priority
        args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
        if args_name and args_name in cls_init_args_name:
            cls_kwargs_loaded = {args_name: cls_kwargs_loaded}

    _cls_kwargs = {}
    _cls_kwargs.update(cls_kwargs_loaded)
    _cls_kwargs.update(cls_kwargs_new)

    if not cls_spec.varkw:
        # filter kwargs according to class init unless it allows any argument via kwargs
        _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}

    obj = cls(**_cls_kwargs)

    if isinstance(obj, pl.LightningModule):
        # give model a chance to load something
        obj.on_load_checkpoint(checkpoint)

    if isinstance(obj, pl.LightningDataModule):
        if obj.__class__.__qualname__ in checkpoint:
            obj.load_state_dict(checkpoint[obj.__class__.__qualname__])
        return obj

    # load the state_dict on the model automatically
    assert strict is not None
    keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)

    if not strict:
        if keys.missing_keys:
            rank_zero_warn(
                f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
            )
        if keys.unexpected_keys:
            rank_zero_warn(
                f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
            )

    return obj

It obtains cls_init_args_name, the constructor arguments of the LightningModule or LightningDataModule, using Python's built-in inspect library.

LightningModule inherits from HyperparametersMixin, so it can use the save_hyperparameters method, which registers the args/kwargs passed to the constructor in self.hparams; these are automatically saved in the checkpoint and can be accessed through CHECKPOINT_HYPER_PARAMS_{}.format(KEY | TYPE | NAME). The middle part of the function restores them (details to be checked later). Since these are the arguments needed to instantiate the LightningModule, they are collected into _cls_kwargs and the object is then instantiated. For a LightningDataModule the function simply calls the object's load_state_dict method and returns, and since that is not our current concern, only LightningModule is examined here.
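A minimal sketch of how this plays out with save_hyperparameters (the module and its arguments are hypothetical):

import lightning.pytorch as pl
import torch.nn as nn


class LitModel(pl.LightningModule):
    def __init__(self, hidden_size: int = 128, lr: float = 1e-3):
        super().__init__()
        # records hidden_size/lr in self.hparams; they are written into the
        # ckpt under CHECKPOINT_HYPER_PARAMS_KEY and read back by _load_state
        self.save_hyperparameters()
        self.layer = nn.Linear(hidden_size, hidden_size)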

After creating the LightningModule object, _load_state runs the following two methods.

# give model a chance to load something
obj.on_load_checkpoint(checkpoint)

# load the state_dict on the model automatically
assert strict is not None
keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)

on_load_checkpoint is a user-defined hook: you write it yourself to decide which keys to read from the checkpoint file and store on the object.
Because LightningModule inherits from nn.Module, load_state_dict simply runs torch's Module.load_state_dict and hands the weights to the model.

Finally, _load_from_checkpoint moves the lightning module to the device of the checkpoint's state_dict and returns it.
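Putting the pieces together, a call then looks roughly like this (reusing the hypothetical LitModel from the sketch above; the keyword overrides work because _load_state merges **kwargs over the saved hparams):

model = LitModel.load_from_checkpoint(
    "path/to/checkpoint.ckpt",
    map_location="cpu",  # forwarded down to torch.load via pl_load
    strict=False,        # tolerate missing/unexpected keys in the state_dict
    lr=1e-4,             # overrides the lr stored in the checkpoint hparams
)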

@jinmang2
Owner Author

jinmang2 commented Aug 22, 2023

lightning.pytorch.Trainer(...).fit(model=model, ckpt_path={YOUR_CKPT_PATH})

pl.Trainer.fit is heavily abstracted, so it looks remarkably simple.

# lightning.pytorch.trainer.trainer.py
class Trainer(...):
    ...
    def fit(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None,
    ) -> None:
        model = _maybe_unwrap_optimized(model)
        self.strategy._lightning_module = model
        _verify_strategy_supports_compile(model, self.strategy)
        call._call_and_handle_interrupt(
            self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
        )

The first three lines wrap and verify the model/strategy; the part that actually matters is the last line.

_call_and_handle_interrupt in lightning.pytorch.trainer.call.py is designed to handle errors around pl.Trainer's main entry points fit, validate, test, and predict. If trainer.strategy.launcher exists, trainer_fn is executed through that launcher; otherwise trainer_fn is called directly. Errors are handled according to the three cases below.

  • _TunerExitException: call teardown, then set the trainer status to FINISHED
  • KeyboardInterrupt: when the user tries to interrupt the process, attempt a graceful shutdown, then set the trainer status to INTERRUPTED
    • the trainer's on_exception hook is called, which appears to let you save settings and so on when the user interrupts
  • BaseException: set the trainer status to INTERRUPTED and mark the run as failed on the loggers

So the important part is self._fit_impl.

# lightning.pytorch.trainer.trainer.py
class Trainer(...):
    ...
    def _fit_impl(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None,
    ) -> None:
        log.debug(f"{self.__class__.__name__}: trainer fit stage")

        self.state.fn = TrainerFn.FITTING
        self.state.status = TrainerStatus.RUNNING
        self.training = True

        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(train_dataloaders, LightningDataModule):
            datamodule = train_dataloaders
            train_dataloaders = None
        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
        if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
            raise MisconfigurationException(
                "You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`"
            )

        # links data to the trainer
        self._data_connector.attach_data(
            model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
        )

        ckpt_path = self._checkpoint_connector._select_ckpt_path(
            self.state.fn,
            ckpt_path,
            model_provided=True,
            model_connected=self.lightning_module is not None,
        )
        self._run(model, ckpt_path=ckpt_path)

        assert self.state.stopped
        self.training = False
        return

The code performs the training setup and then does the real work through self._run. Before looking at _run, let's check what _select_ckpt_path does.

# lightning.pytorch.trainer.trainer.py
class Trainer(...):
    @_defaults_from_env_vars
    def __init__(
        self,
        ...
    ) -> None:
        ...
        self._checkpoint_connector = _CheckpointConnector(self)
        ...
    ...
    @property
    def ckpt_path(self) -> Optional[_PATH]:
        """Set to the path/URL of a checkpoint loaded via :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`,
        :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`,
        :meth:`~lightning.pytorch.trainer.trainer.Trainer.test`, or
        :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. ``None`` otherwise."""
        return self._checkpoint_connector._ckpt_path

    @ckpt_path.setter
    def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None:
        """Allows you to manage which checkpoint is loaded statefully.

        .. code-block:: python

            trainer = Trainer()
            trainer.ckpt_path = "my/checkpoint/file.ckpt"
            trainer.fit(model)
            ...

            # you will be in charge of resetting this
            trainer.ckpt_path = None
            trainer.test(model)

        """
        self._checkpoint_connector._ckpt_path = ckpt_path
        self._checkpoint_connector._user_managed = bool(ckpt_path)
...
# lightning.pytorch.trainer.connectors.checkpoint_connector.py
class _CheckpointConnector:
    ...
    def _select_ckpt_path(
        self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool
    ) -> Optional[_PATH]:
        """Called by the ``Trainer`` to select the checkpoint path source."""
        if self._user_managed:
            if ckpt_path:
                rank_zero_warn(
                    f"`trainer.ckpt_path = {self._ckpt_path!r}` was called but then you"
                    f" passed `trainer.fit(ckpt_path={ckpt_path!r})`. The latter will be loaded."
                )
                # reset the previous path
                self._ckpt_path = None
                self._user_managed = False
                ckpt_path = self._parse_ckpt_path(
                    state_fn,
                    ckpt_path,
                    model_provided=model_provided,
                    model_connected=model_connected,
                )
            else:
                ckpt_path = self._ckpt_path
        else:
            ckpt_path = self._parse_ckpt_path(
                state_fn,
                ckpt_path,
                model_provided=model_provided,
                model_connected=model_connected,
            )
        return ckpt_path

    def _parse_ckpt_path(
        self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool
    ) -> Optional[_PATH]:
        """Converts the ``ckpt_path`` special values into an actual filepath, depending on the trainer
        configuration."""
        if ckpt_path is None and SLURMEnvironment.detect() and self._hpc_resume_path is not None:
            ckpt_path = "hpc"

        from lightning.pytorch.callbacks.on_exception_checkpoint import OnExceptionCheckpoint

        ft_checkpoints = [cb for cb in self.trainer.callbacks if isinstance(cb, OnExceptionCheckpoint)]
        fn = state_fn.value
        if ckpt_path is None and ft_checkpoints and self.trainer.state.fn == TrainerFn.FITTING:
            ckpt_path = "last"
            rank_zero_warn(
                f"`.{fn}(ckpt_path=None)` was called without a model."
                " The last model of the previous `fit` call will be used."
                f" You can pass `{fn}(ckpt_path='best')` to use the best model or"
                f" `{fn}(ckpt_path='last')` to use the last model."
                " If you pass a value, this warning will be silenced."
            )

        if model_provided and ckpt_path is None:
            # use passed model to function without loading weights
            return None

        if model_connected and ckpt_path is None:
            ckpt_path = "best"
            ft_tip = (
                " There is also an on-exception checkpoint available, however it is used by default only when fitting."
                if ft_checkpoints
                else ""
            )
            rank_zero_warn(
                f"`.{fn}(ckpt_path=None)` was called without a model."
                " The best model of the previous `fit` call will be used."
                + ft_tip
                + f" You can pass `.{fn}(ckpt_path='best')` to use the best model or"
                f" `.{fn}(ckpt_path='last')` to use the last model."
                " If you pass a value, this warning will be silenced."
            )

        if ckpt_path == "best":
            if len(self.trainer.checkpoint_callbacks) > 1:
                rank_zero_warn(
                    f'`.{fn}(ckpt_path="best")` is called with Trainer configured with multiple `ModelCheckpoint`'
                    " callbacks. It will use the best checkpoint path from first checkpoint callback."
                )

            if not self.trainer.checkpoint_callback:
                raise ValueError(f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.')

            has_best_model_path = self.trainer.checkpoint_callback.best_model_path
            if hasattr(self.trainer.checkpoint_callback, "best_model_path") and not has_best_model_path:
                if self.trainer.fast_dev_run:
                    raise ValueError(
                        f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.'
                        f" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
                    )
                raise ValueError(
                    f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
                )
            # load best weights
            ckpt_path = getattr(self.trainer.checkpoint_callback, "best_model_path", None)

        elif ckpt_path == "last":
            candidates = {getattr(ft, "ckpt_path", None) for ft in ft_checkpoints}
            for callback in self.trainer.checkpoint_callbacks:
                if isinstance(callback, ModelCheckpoint):
                    candidates |= callback._find_last_checkpoints(self.trainer)
            candidates_fs = {path: get_filesystem(path) for path in candidates if path}
            candidates_ts = {path: fs.modified(path) for path, fs in candidates_fs.items() if fs.exists(path)}
            if not candidates_ts:
                # not an error so it can be set and forget before the first `fit` run
                rank_zero_warn(
                    f'.{fn}(ckpt_path="last") is set, but there is no last checkpoint available.'
                    " No checkpoint will be loaded."
                )
                return None
            ckpt_path = max(candidates_ts, key=candidates_ts.get)  # type: ignore[arg-type]

        elif ckpt_path == "hpc":
            if not self._hpc_resume_path:
                raise ValueError(
                    f'`.{fn}(ckpt_path="hpc")` is set but no HPC checkpoint was found.'
                    " Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`"
                )
            ckpt_path = self._hpc_resume_path

        if not ckpt_path:
            raise ValueError(
                f"`.{fn}()` found no path for the best weights: {ckpt_path!r}. Please"
                f" specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`"
            )
        return ckpt_path

I was curious what the trainer._checkpoint_connector._user_managed attribute is for, and pl.Trainer reveals its purpose. When you assign trainer.ckpt_path = {YOUR_CKPT_PATH}, it means the user manages the ckpt path directly: _user_managed is set to True and the path is stored in the _checkpoint_connector's _ckpt_path. If _user_managed is True but a ckpt_path is still passed to fit, the stored state is reset, a warning is emitted, and _parse_ckpt_path runs on the passed path. If _user_managed is False, _parse_ckpt_path likewise runs to select the ckpt_path; and if no ckpt_path was passed to trainer.fit while _user_managed is True, the connector's stored _ckpt_path is selected.

_parse_ckpt_path is the _CheckpointConnector method that converts ckpt_path into an actual filepath depending on the trainer configuration. If a concrete ckpt_path was passed, it is returned as-is without any extra work; if ckpt_path is None, "hpc", "last", or "best" is chosen based on the trainer's status, callbacks, and so on, and the checkpoint path matching that setting is then resolved.
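In usage terms, assuming a trainer with a ModelCheckpoint configured and a model defined earlier (the exact filename is made up):

# resume from an exact file: returned as-is by _parse_ckpt_path
trainer.fit(model, ckpt_path="checkpoints/epoch=2-step=300.ckpt")

# special values resolved by _parse_ckpt_path
trainer.fit(model, ckpt_path="last")       # most recently modified "last" checkpoint
trainer.validate(model, ckpt_path="best")  # first ModelCheckpoint's best_model_path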

For trainer._run, only the relevant parts are shown, briefly.

# lightning.pytorch.trainer.trainer.py
class Trainer(...):
    ...
    def _run(
        self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
    ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
        ...
        # ----------------------------
        # SET UP THE TRAINER
        # ----------------------------
        ...
        # check if we should delay restoring checkpoint till later
        if not self.strategy.restore_checkpoint_after_setup:
            log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
            self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
        ... setup ...
        if self.strategy.restore_checkpoint_after_setup:
            log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
            self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
        ...
        return results

Depending on the trainer.strategy.restore_checkpoint_after_setup attribute, the only difference is whether the checkpoint is restored before or after setup.

# lightning.pytorch.trainer.connectors.checkpoint_connector.py
class _CheckpointConnector:
    ...
    def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
        """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:

        1. from HPC weights if `checkpoint_path` is ``None`` and on SLURM or passed keyword `"hpc"`.
        2. from fault-tolerant auto-saved checkpoint if found
        3. from `checkpoint_path` file if provided
        4. don't restore

        """
        self._ckpt_path = checkpoint_path
        if not checkpoint_path:
            log.debug("`checkpoint_path` not specified. Skipping checkpoint loading.")
            return

        rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
        with pl_legacy_patch():
            loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
        self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path)
    ...
    def restore_model(self) -> None:
        """Restores a model's weights from a PyTorch Lightning checkpoint.

        Hooks are called first to give the LightningModule a chance to modify the contents, then finally the model gets
        updated with the loaded weights.

        """
        if not self._loaded_checkpoint:
            return

        trainer = self.trainer
        # hook: give user access to checkpoint if needed.
        call._call_lightning_module_hook(trainer, "on_load_checkpoint", self._loaded_checkpoint)

        # restore model state_dict
        trainer.strategy.load_model_state_dict(self._loaded_checkpoint)
    ...
    def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
        # restore modules after setup
        self.resume_start(checkpoint_path)
        self.restore_model()
        self.restore_datamodule()
        if self.trainer.state.fn == TrainerFn.FITTING:
            # restore callback states
            self.restore_callbacks()

# lightning.pytorch.strategies.strategy.py
class Strategy(ABC):
    ...
    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        torch.cuda.empty_cache()
        return self.checkpoint_io.load_checkpoint(checkpoint_path)

    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        assert self.lightning_module is not None
        self.lightning_module.load_state_dict(checkpoint["state_dict"])

The behavior is the same as the process we saw in load_from_checkpoint; the difference is that the load_checkpoint and load_model_state_dict methods owned by trainer.strategy are used instead. This appears to be structured that way because, for example, the deepspeed strategy overrides them as follows.

# lightning.pytorch.strategies.deepspeed.py
class DeepSpeedStrategy(DDPStrategy):
    strategy_name = "deepspeed"
    DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH"
    ...
    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        if self.load_full_weights and self.zero_stage_3:
            # Broadcast to ensure we load from the rank 0 checkpoint
            # This doesn't have to be the case when using deepspeed sharded checkpointing
            checkpoint_path = self.broadcast(checkpoint_path)
            return super().load_checkpoint(checkpoint_path)

        _validate_checkpoint_directory(checkpoint_path)

        # Rely on deepspeed to load the checkpoint and necessary information
        assert self.lightning_module is not None

        from lightning.pytorch.trainer.states import TrainerFn

        is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING

        _, client_state = self.deepspeed_engine.load_checkpoint(
            checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=False
        )
        if client_state is None:
            raise MisconfigurationException(
                "DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint "
                "or a single checkpoint file with `Trainer(strategy=DeepSpeedStrategy(load_full_weights=True))`."
            )
        return client_state
    ...
    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
        if self.load_full_weights and self.zero_stage_3:
            self.model_to_device()
            self._restore_zero_state(checkpoint)
