This repository has been archived by the owner on Apr 16, 2024. It is now read-only.

S-aiueo32/hiraishin

Hiraishin

A thin PyTorch Lightning wrapper for building configuration-based DL pipelines with Hydra.

Dependencies

  • PyTorch Lightning
  • Hydra
  • Pydantic
  • etc.

Installation

$ pip install -U hiraishin

Basic workflow

1. Model initialization with type annotations

Define a model class that has training components with type annotations.

import torch.nn as nn
import torch.optim as optim
from omegaconf import DictConfig

from hiraishin.models import BaseModel


class ToyModel(BaseModel):

    net: nn.Linear
    criterion: nn.CrossEntropyLoss
    optimizer: optim.Adam
    scheduler: optim.lr_scheduler.ExponentialLR

    def __init__(self, config: DictConfig) -> None:
        super().__init__(config)

Attributes whose names begin with one of the following prefixes are instantiated by role-specific logic.

  • net
  • criterion
  • optimizer
  • scheduler

The same notation can be used to define components other than the training components listed above (e.g., tokenizers). You can also declare constants of built-in types, as long as they are YAML serializable (e.g., int).

class ToyModel(BaseModel):

    net: nn.Linear
    criterion: nn.CrossEntropyLoss
    optimizer: optim.Adam
    scheduler: optim.lr_scheduler.ExponentialLR

    # additional components and constants
    tokenizer: MyTokenizer
    n_classes: int

    def __init__(self, config: DictConfig) -> None:
        super().__init__(config)
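
As a rough illustration of prefix-based dispatch, the sketch below reads a class's annotations with typing.get_type_hints and groups them by the four prefixes listed above. It is not Hiraishin's actual implementation: group_by_role, ROLE_PREFIXES, and the stand-in classes are hypothetical names, and plain Python classes replace the PyTorch types so the snippet runs without extra dependencies.

```python
from typing import Dict, get_type_hints

# Prefixes taken from the list above; everything else is treated as "other".
ROLE_PREFIXES = ("net", "criterion", "optimizer", "scheduler")


# Hypothetical stand-ins for nn.Linear, nn.CrossEntropyLoss, and so on.
class Linear: ...
class CrossEntropyLoss: ...
class Adam: ...
class ExponentialLR: ...
class MyTokenizer: ...


class ToyModel:
    net: Linear
    criterion: CrossEntropyLoss
    optimizer: Adam
    scheduler: ExponentialLR
    tokenizer: MyTokenizer
    n_classes: int


def group_by_role(cls: type) -> Dict[str, Dict[str, type]]:
    """Group annotated attributes by the role prefix of their name."""
    groups: Dict[str, Dict[str, type]] = {role: {} for role in ROLE_PREFIXES}
    groups["other"] = {}
    for name, annotation in get_type_hints(cls).items():
        # Pick the first matching prefix, falling back to "other".
        role = next((p for p in ROLE_PREFIXES if name.startswith(p)), "other")
        groups[role][name] = annotation
    return groups


groups = group_by_role(ToyModel)
print(sorted(groups["other"]))  # ['n_classes', 'tokenizer']
```

Because the grouping depends only on attribute names, an attribute such as net_encoder would land in the same group as net under this sketch.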

2. Configuration file generation

Hiraishin provides a CLI command that automatically generates a configuration file based on type annotations.

For example, if ToyModel is defined in models.py (i.e., from models import ToyModel works in your code), the following command generates the configuration file.

$ hiraishin generate models.ToyModel --output_dir config/model
The config has been generated! --> config/model/ToyModel.yaml

Let's take a look at the generated file.

_target_: models.ToyModel
_recursive_: false
config:

  networks:
    net:
      args:
        _target_: torch.nn.Linear
        out_features: ???
        in_features: ???
      weights:
        initializer: null
        path: null

  losses:
    criterion:
      args:
        _target_: torch.nn.CrossEntropyLoss
      weight: 1.0

  optimizers:
    optimizer:
      args:
        _target_: torch.optim.Adam
      params:
      - ???
      scheduler:
        args:
          _target_: torch.optim.lr_scheduler.ExponentialLR
          gamma: ???
        interval: epoch
        frequency: 1
        strict: true
        monitor: null

  tokenizer:
    _target_: MyTokenizer
  n_classes: ???

The generated file is compatible with hydra.utils.instantiate: each _target_ key names the class to construct.
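
To make the role of _target_ concrete, here is a simplified, standard-library-only sketch of the idea behind hydra.utils.instantiate: import the object named by the dotted path, then call it with the remaining keys as keyword arguments. resolve_target and simple_instantiate are hypothetical helpers, and the real Hydra function does considerably more (nested configs, partial instantiation, argument conversion). fractions.Fraction serves as the target only so that the example runs without PyTorch.

```python
import importlib
from typing import Any, Dict


def resolve_target(path: str) -> Any:
    """Import the object named by a dotted path such as 'fractions.Fraction'."""
    module_name, _, attr = path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr)


def simple_instantiate(config: Dict[str, Any]) -> Any:
    """Call the `_target_` with every other key passed as a keyword argument."""
    kwargs = {k: v for k, v in config.items() if k != "_target_"}
    return resolve_target(config["_target_"])(**kwargs)


obj = simple_instantiate(
    {"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4}
)
print(obj)  # 3/4
```

The same lookup applied to torch.nn.Linear would need in_features and out_features as keyword arguments, which is why the generated file lists them.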

Positional arguments are filled with ???, which marks a mandatory parameter. Replace each one with the value you want before instantiating the model.
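
Before instantiating, it can help to list which values are still unset. The sketch below walks a plain nested dict shaped like an excerpt of the generated file and collects the dotted paths that still hold ???. find_missing is a hypothetical helper written for illustration; with a loaded OmegaConf config, OmegaConf.is_missing(cfg, key) performs the equivalent per-key check.

```python
from typing import Any, List

MISSING = "???"  # marker used for mandatory values in the generated YAML


def find_missing(node: Any, prefix: str = "") -> List[str]:
    """Return dotted paths of all values still equal to the `???` marker."""
    if isinstance(node, dict):
        items = node.items()
    elif isinstance(node, list):
        items = enumerate(node)
    else:
        # Leaf value: report its path only if it is still the marker.
        return [prefix] if node == MISSING else []
    paths: List[str] = []
    for key, value in items:
        child = f"{prefix}.{key}" if prefix else str(key)
        paths.extend(find_missing(value, child))
    return paths


# Excerpt of the generated config, written out as a plain dict.
config = {
    "networks": {
        "net": {
            "args": {
                "_target_": "torch.nn.Linear",
                "out_features": MISSING,
                "in_features": MISSING,
            }
        }
    },
    "n_classes": MISSING,
}

for path in find_missing(config):
    print(path)
# networks.net.args.out_features
# networks.net.args.in_features
# n_classes
```

An empty result means every mandatory value in the checked structure has been filled in.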

3. Training routines definition

All that remains is to define your training routine in the usual PyTorch Lightning style.

import torch


class ToyModel(BaseModel):

    ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    def training_step(self, batch, *args, **kwargs) -> torch.Tensor:
        x, target = batch
        pred = self.forward(x)
        loss = self.criterion(pred, target)
        self.log('loss/train', loss)
        return loss

4. Model instantiation

The defined model can be instantiated from the configuration file. Let's train your models!

from hydra.utils import instantiate
from omegaconf import OmegaConf


def app():
    ...

    config = OmegaConf.load('config/model/toy.yaml')
    model = instantiate(config)

    print(model)
    # ToyModel(
    #     (net): Linear(in_features=1, out_features=1, bias=True)
    #     (criterion): CrossEntropyLoss()
    # )

    trainer.fit(model, ...)

5. Model loading

You can load trained models from the checkpoints that PyTorch Lightning saves with its standard checkpointing. Let's test your models!

from hiraishin.utils import load_from_checkpoint

model = load_from_checkpoint('path/to/model.ckpt')
print(model)
# ToyModel(
#     (net): Linear(in_features=1, out_features=1, bias=True)
#     (criterion): CrossEntropyLoss()
# )

License

Hiraishin is licensed under the Apache License, Version 2.0. See LICENSE for the full license text.