Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

integration-torchtune #653

Merged
merged 1 commit into from
Jul 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions swanlab/integration/torchtune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from torchtune.utils._distributed import get_world_size_and_rank
from torchtune.utils.metric_logging import MetricLoggerInterface
from typing import Mapping, Optional, Union

from numpy import ndarray
from omegaconf import DictConfig, OmegaConf
from torch import Tensor

Scalar = Union[Tensor, ndarray, int, float]


class SwanLabLogger(MetricLoggerInterface):
"""
Example:
In torchtune config:
>>> # Logging
>>> metric_logger:
>>> _component_: swanlab.integration.huggingface.SwanLabLogger
>>> project: "gemma-lora-finetune"
>>> experiment_name: "gemma-2b"
>>> log_dir: ${output_dir}

Raises:
ImportError: If ``swanlab`` package is not installed.

Note:
This logger requires the swanlab package to be installed.
You can install it with `pip install swanlab`.
In order to use the logger, you need to login to your SwanLab account.
You can do this by running `swanlab login` in your terminal.
"""

def __init__(
self,
project: str = "torchtune",
workspace: Optional[str] = None,
experiment_name: Optional[str] = None,
description: Optional[str] = None,
mode: Optional[str] = None,
log_dir: Optional[str] = None,
**kwargs,
):
try:
import swanlab
except ImportError as e:
raise ImportError(
"``swanlab`` package not found. Please install swanlab using `pip install swanlab` to use SwanLabLogger."
) from e
self._swanlab = swanlab

# Use dir if specified, otherwise use log_dir.
self.log_dir = kwargs.pop("dir", log_dir)

_, self.rank = get_world_size_and_rank()

if self._swanlab.get_run() is None and self.rank == 0:
run = self._swanlab.init(
project=project,
workspace=workspace,
name=experiment_name,
description=description,
mode=mode,
logdir=self.log_dir,
**kwargs,
)

self.config_allow_val_change = kwargs.get("allow_val_change", False)

def log_config(self, config: DictConfig) -> None:
if self._swanlab.get_run():
resolved = OmegaConf.to_container(config, resolve=True)
self._swanlab.config.update(resolved, allow_val_change=self.config_allow_val_change)

def log(self, name: str, data: Scalar, step: int) -> None:
if self._swanlab.get_run():
self._swanlab.log({name: data}, step=step)

def log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
if self._swanlab.get_run():
self._swanlab.log({**payload}, step=step)

def __del__(self) -> None:
if self._swanlab.get_run():
self._swanlab.finish()

def close(self) -> None:
if self._swanlab.get_run():
self._swanlab.finish()