Skip to content

Commit

Permalink
Add allegro.ai TRAINS experiment manager support
Browse files Browse the repository at this point in the history
  • Loading branch information
bmartinn committed Feb 25, 2020
1 parent c56ee8b commit 3a868d8
Show file tree
Hide file tree
Showing 9 changed files with 289 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def setup(app):
MOCK_REQUIRE_PACKAGES.append(pkg.rstrip())

# TODO: better parse from package since the import name and package name may differ
MOCK_MANUAL_PACKAGES = ['torch', 'torchvision', 'test_tube', 'mlflow', 'comet_ml', 'wandb', 'neptune']
MOCK_MANUAL_PACKAGES = ['torch', 'torchvision', 'test_tube', 'mlflow', 'comet_ml', 'wandb', 'neptune', 'trains']
autodoc_mock_imports = MOCK_REQUIRE_PACKAGES + MOCK_MANUAL_PACKAGES
# for mod_name in MOCK_REQUIRE_PACKAGES:
# sys.modules[mod_name] = mock.Mock()
Expand Down
28 changes: 28 additions & 0 deletions docs/source/experiment_logging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,34 @@ The Neptune.ai is available anywhere in your LightningModule
some_img = fake_image()
self.logger.experiment.add_image('generated_images', some_img, 0)
allegro.ai TRAINS
^^^^^^^^^^^^^^^^^

`allegro.ai <https://github.com/allegroai/trains/>`_ is a third-party logger.
To use TRAINS as your logger do the following.

.. note:: See: :ref:`trains` docs.

.. code-block:: python
from pytorch_lightning.loggers import TrainsLogger
trains_logger = TrainsLogger(
project_name="examples",
task_name="pytorch lightning test"
)
trainer = Trainer(logger=trains_logger)
The TrainsLogger is available anywhere in your LightningModule

.. code-block:: python
class MyModule(pl.LightningModule):
def __init__(self, ...):
some_img = fake_image()
self.logger.log_image('debug', 'generated_image_0', some_img, 0)
Tensorboard
^^^^^^^^^^^

Expand Down
2 changes: 1 addition & 1 deletion docs/source/experiment_reporting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ want to log using this trainer flag.
Log metrics
^^^^^^^^^^^

To plot metrics into whatever logger you passed in (tensorboard, comet, neptune, etc...)
To plot metrics into whatever logger you passed in (tensorboard, comet, neptune, TRAINS, etc...)

1. Training_end, validation_end, test_end will all log anything in the "log" key of the return dict.

Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ dependencies:
- comet_ml>=1.0.56
- wandb>=0.8.21
- neptune-client>=0.4.4
- trains>=0.13.3
6 changes: 6 additions & 0 deletions pytorch_lightning/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,9 @@ def any_lightning_module_function_or_hook(...):
__all__.append('WandbLogger')
except ImportError:
pass

try:
from .trains import TrainsLogger
__all__.append('TrainsLogger')
except ImportError:
pass
201 changes: 201 additions & 0 deletions pytorch_lightning/loggers/trains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""
Log using `allegro.ai TRAINS <https://github.com/allegroai/trains>`_
.. code-block:: python
from pytorch_lightning.loggers import TrainsLogger
trains_logger = TrainsLogger(
project_name="pytorch lightning",
task_name="default",
)
trainer = Trainer(logger=trains_logger)
Use the logger anywhere in your LightningModule as follows:
.. code-block:: python
def train_step(...):
# example
self.logger.experiment.whatever_trains_supports(...)
def any_lightning_module_function_or_hook(...):
self.logger.experiment.whatever_trains_supports(...)
"""

from logging import getLogger
import torch

try:
import trains
except ImportError:
raise ImportError('Missing TRAINS package.')

from .base import LightningLoggerBase, rank_zero_only

logger = getLogger(__name__)


class TrainsLogger(LightningLoggerBase):
    """Log experiments with the allegro.ai TRAINS experiment manager.

    Args:
        project_name (str): The name of the experiment's project
        task_name (str): The name of the experiment
        **kwargs: Additional keyword arguments passed to ``trains.Task.init``
    """

    def __init__(self, project_name=None, task_name=None, **kwargs):
        super().__init__()
        self._trains = trains.Task.init(project_name=project_name, task_name=task_name, **kwargs)

    @property
    def experiment(self):
        r"""
        Actual TRAINS object. To use TRAINS features do the following.
        Example::
            self.logger.experiment.some_trains_function()
        """
        return self._trains

    @property
    def id(self):
        """Unique TRAINS task id, or None if the task was already closed."""
        if not self._trains:
            return None
        return self._trains.id

    @rank_zero_only
    def log_hyperparams(self, params):
        """Connect the experiment hyper-parameters to the TRAINS task.

        :param params: dict or argparse.Namespace-like object of hyper-parameters
        """
        if not self._trains:
            return None
        if not params:
            return
        if isinstance(params, dict):
            self._trains.connect(params)
        else:
            # Namespace-like object: log its attribute dict
            self._trains.connect(vars(params))

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        """Log a dict of numeric metrics.

        Keys of the form ``"title/series"`` are split into a graph title and
        a series name; plain keys use the key for both.

        :param dict metrics: Mapping of metric name to value
        :param int|None step: Step number at which the metrics should be recorded
        """
        if not self._trains:
            return None
        # `step is None` (not falsiness): step 0 is a valid iteration number
        if step is None:
            step = self._trains.get_last_iteration()
        for k, v in metrics.items():
            if isinstance(v, str):
                logger.warning("Discarding metric with string value %s=%s", k, v)
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            parts = k.split('/')
            if len(parts) <= 1:
                series = title = k
            else:
                title = parts[0]
                # join the remainder back into a single series string
                # (a list here would not be a valid series name)
                series = '/'.join(parts[1:])
            self._trains.get_logger().report_scalar(title=title, series=series, value=v, iteration=step)

    @rank_zero_only
    def log_metric(self, title, series, value, step=None):
        """Log metrics (numeric values) in TRAINS experiments

        :param str title: The title of the graph to log, e.g. loss, accuracy.
        :param str series: The series name in the graph, e.g. classification, localization
        :param float value: The value to log
        :param int|None step: Step number at which the metrics should be recorded
        """
        if not self._trains:
            return None
        # `step is None` (not falsiness): step 0 is a valid iteration number
        if step is None:
            step = self._trains.get_last_iteration()
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._trains.get_logger().report_scalar(title=title, series=series, value=value, iteration=step)

    @rank_zero_only
    def log_text(self, text):
        """Log console text data in TRAINS experiment

        :param str text: The value of the log (data-point).
        """
        if not self._trains:
            return None
        self._trains.get_logger().report_text(text)

    @rank_zero_only
    def log_image(self, title, series, image, step=None):
        """Log Debug image in TRAINS experiment

        :param str title: The title of the debug image, i.e. "failed", "passed".
        :param str series: The series name of the debug image, i.e. "Image 0", "Image 1".
        :param str|Numpy|PIL.Image image: Debug image to log.
            Can be one of the following types: Numpy, PIL image, path to image file (str)
        :param int|None step: Step number at which the metrics should be recorded
        """
        if not self._trains:
            return None
        # `step is None` (not falsiness): step 0 is a valid iteration number
        if step is None:
            step = self._trains.get_last_iteration()
        if isinstance(image, str):
            self._trains.get_logger().report_image(title=title, series=series, local_path=image, iteration=step)
        else:
            self._trains.get_logger().report_image(title=title, series=series, image=image, iteration=step)

    @rank_zero_only
    def log_artifact(self, name, artifact, metadata=None, delete_after_upload=False):
        """Save an artifact (file/object) in TRAINS experiment storage.

        :param str name: Artifact name. Notice! it will override previous artifact if name already exists
        :param object artifact: Artifact object to upload. Currently supports:
            - string / pathlib2.Path are treated as path to artifact file to upload
              If wildcard or a folder is passed, zip file containing the local files will be created and uploaded
            - dict will be stored as .json file and uploaded
            - pandas.DataFrame will be stored as .csv.gz (compressed CSV file) and uploaded
            - numpy.ndarray will be stored as .npz and uploaded
            - PIL.Image will be stored to .png file and uploaded
        :param dict metadata: Simple key/value dictionary to store on the artifact
        :param bool delete_after_upload: If True local artifact will be deleted
            (only applies if artifact_object is a local file)
        """
        if not self._trains:
            return None
        self._trains.upload_artifact(name=name, artifact_object=artifact,
                                     metadata=metadata, delete_after_upload=delete_after_upload)

    def save(self):
        # TRAINS flushes automatically; nothing to persist explicitly.
        pass

    @rank_zero_only
    def finalize(self, status):
        """Close the TRAINS task; the logger becomes a no-op afterwards."""
        if not self._trains:
            return None
        self._trains.close()
        self._trains = None

    @property
    def name(self):
        """Name of the TRAINS task, or None if the task was already closed."""
        if not self._trains:
            return None
        return self._trains.name

    @property
    def version(self):
        """TRAINS task id used as the experiment version identifier."""
        if not self._trains:
            return None
        return self._trains.id

    def __getstate__(self):
        # Pickle only the task id; the Task object itself is not picklable.
        if not self._trains:
            return None
        return self._trains.id

    def __setstate__(self, state):
        self._rank = 0
        self._trains = None
        if state:
            # Re-attach to the existing task by its id
            self._trains = trains.Task.get_task(task_id=state)
5 changes: 5 additions & 0 deletions pytorch_lightning/logging/trains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""
.. warning:: `logging` package has been renamed to `loggers` since v0.6.1 and will be removed in v0.8.0
"""

from pytorch_lightning.loggers import trains # noqa: F403
3 changes: 2 additions & 1 deletion requirements-extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ neptune-client>=0.4.4
comet-ml>=1.0.56
mlflow>=1.0.0
test_tube>=0.7.5
wandb>=0.8.21
wandb>=0.8.21
trains>=0.13.3
45 changes: 44 additions & 1 deletion tests/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
MLFlowLogger,
CometLogger,
WandbLogger,
NeptuneLogger
NeptuneLogger,
TrainsLogger
)
from tests.models import LightningTestModel

Expand Down Expand Up @@ -236,6 +237,48 @@ def test_neptune_pickle(tmpdir):
trainer2.logger.log_metrics({"acc": 1.0})


def test_trains_logger(tmpdir):
    """Verify that basic functionality of TRAINS logger works."""
    tutils.reset_seed()

    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)
    logger = TrainsLogger(project_name="examples", task_name="pytorch lightning test")

    # Short training run: a fraction of one epoch is enough to exercise logging.
    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=0.05,
        logger=logger
    )
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1, "Training failed"


def test_trains_pickle(tmpdir):
    """Verify that pickling trainer with TRAINS logger works."""
    tutils.reset_seed()

    logger = TrainsLogger(project_name="examples", task_name="pytorch lightning test")

    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        logger=logger
    )

    # Round-trip the trainer through pickle and make sure the restored
    # logger can still report metrics (it re-attaches to the task by id).
    trainer = Trainer(**trainer_options)
    pkl_bytes = pickle.dumps(trainer)
    trainer2 = pickle.loads(pkl_bytes)
    trainer2.logger.log_metrics({"acc": 1.0})


def test_tensorboard_logger(tmpdir):
"""Verify that basic functionality of Tensorboard logger works."""

Expand Down

0 comments on commit 3a868d8

Please sign in to comment.