[air] New train.Checkpoint API: Update Train tests + examples (batch 3 - `test_*_trainer`) (ray-project#38759)

This PR mainly updates the trainer unit tests (`test_*_trainer`) to the new `train.Checkpoint` API.
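Concretely, the updated tests construct checkpoints from a local directory and report them through `train.report`, instead of the legacy dict-based `Checkpoint.from_dict` path used by the removed tests. A minimal sketch of that pattern, mirroring the test code in this diff (note that the new class is still imported from the internal `ray.train._checkpoint` module at this stage of the migration):

```python
import tempfile

from ray import train
# During this migration the new checkpoint class lives in an internal module.
from ray.train._checkpoint import Checkpoint


def training_loop(self):
    # Runs inside a trainer: write checkpoint files to a local directory,
    # attach metadata, then report the checkpoint together with metrics.
    with tempfile.TemporaryDirectory() as path:
        checkpoint = Checkpoint.from_directory(path)
        checkpoint.set_metadata({"b": 2, "c": 3})
        train.report(dict(my_metric=1), checkpoint=checkpoint)
```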

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Victor <vctr.y.m@example.com>
justinvyu authored and Victor committed Oct 11, 2023
1 parent dcfe6ba commit 9ec4646
Showing 6 changed files with 122 additions and 274 deletions.
14 changes: 7 additions & 7 deletions python/ray/train/BUILD
@@ -146,7 +146,7 @@ py_test(
size = "medium",
main = "examples/tf/tensorflow_regression_example.py",
srcs = ["examples/tf/tensorflow_regression_example.py"],
tags = ["team:ml", "exclusive", "no_new_storage"],
tags = ["team:ml", "exclusive", "new_storage"],
deps = [":train_lib"],
args = ["--smoke-test"]
)
@@ -240,7 +240,7 @@ py_test(
name = "test_base_trainer",
size = "medium",
srcs = ["tests/test_base_trainer.py"],
tags = ["team:ml", "exclusive", "ray_air", "no_new_storage"],
tags = ["team:ml", "exclusive", "ray_air", "new_storage"],
deps = [":train_lib", ":conftest"]
)

@@ -256,23 +256,23 @@ py_test(
name = "test_checkpoint_manager",
size = "small",
srcs = ["tests/test_checkpoint_manager.py"],
tags = ["team:ml", "exclusive"],
tags = ["team:ml", "exclusive", "new_storage"],
deps = [":train_lib"]
)

py_test(
name = "test_data_parallel_trainer",
size = "medium",
srcs = ["tests/test_data_parallel_trainer.py"],
tags = ["team:ml", "exclusive", "ray_air", "no_new_storage"],
tags = ["team:ml", "exclusive", "ray_air", "new_storage"],
deps = [":train_lib"]
)

py_test(
name = "test_data_parallel_trainer_checkpointing",
size = "medium",
srcs = ["tests/test_data_parallel_trainer_checkpointing.py"],
tags = ["team:ml", "exclusive", "ray_air", "no_new_storage"],
tags = ["team:ml", "exclusive", "ray_air", "new_storage"],
deps = [":train_lib"]
)

@@ -352,7 +352,7 @@ py_test(
name = "test_mosaic_trainer",
size = "medium",
srcs = ["tests/test_mosaic_trainer.py"],
tags = ["team:ml", "exclusive", "ray_air", "torch_1_11"],
tags = ["team:ml", "exclusive", "ray_air", "torch_1_11", "new_storage"],
deps = [":train_lib", ":conftest"]
)

@@ -496,7 +496,7 @@ py_test(
name = "test_tensorflow_trainer",
size = "medium",
srcs = ["tests/test_tensorflow_trainer.py"],
tags = ["team:ml", "exclusive", "ray_air", "no_new_storage"],
tags = ["team:ml", "exclusive", "ray_air", "new_storage"],
deps = [":train_lib"]
)

33 changes: 18 additions & 15 deletions python/ray/train/examples/mosaic_cifar10_example.py
@@ -1,11 +1,13 @@
import argparse
from filelock import FileLock
import os
import tempfile

import torch
import torch.utils.data

import torchvision
from torchvision import transforms, datasets


import ray
from ray import train
from ray.train import ScalingConfig
@@ -29,19 +31,20 @@ def trainer_init_per_worker(config):
[transforms.ToTensor(), transforms.Normalize(mean, std)]
)

data_directory = "~/data"
train_dataset = torch.utils.data.Subset(
datasets.CIFAR10(
data_directory, train=True, download=True, transform=cifar10_transforms
),
list(range(BATCH_SIZE * 10)),
)
test_dataset = torch.utils.data.Subset(
datasets.CIFAR10(
data_directory, train=False, download=True, transform=cifar10_transforms
),
list(range(BATCH_SIZE * 10)),
)
data_directory = tempfile.mkdtemp(prefix="cifar_data")
with FileLock(os.path.join(data_directory, "data.lock")):
train_dataset = torch.utils.data.Subset(
datasets.CIFAR10(
data_directory, train=True, download=True, transform=cifar10_transforms
),
list(range(BATCH_SIZE * 10)),
)
test_dataset = torch.utils.data.Subset(
datasets.CIFAR10(
data_directory, train=False, download=True, transform=cifar10_transforms
),
list(range(BATCH_SIZE * 10)),
)

batch_size_per_worker = BATCH_SIZE // train.get_context().get_world_size()
train_dataloader = torch.utils.data.DataLoader(
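For reference, a stripped-down version of the guarded-download pattern introduced above (assuming `filelock` and `torchvision` are installed; the directory prefix and lock-file name follow the example):

```python
import os
import tempfile

from filelock import FileLock
from torchvision import datasets, transforms

# Scratch directory for the dataset; the lock file inside it guards the
# download (it only coordinates processes that share this directory).
data_directory = tempfile.mkdtemp(prefix="cifar_data")
with FileLock(os.path.join(data_directory, "data.lock")):
    train_dataset = datasets.CIFAR10(
        data_directory,
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )
```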
91 changes: 6 additions & 85 deletions python/ray/train/tests/test_base_trainer.py
@@ -1,24 +1,18 @@
import io
import logging
import os
import time
from contextlib import redirect_stderr
import tempfile
from unittest.mock import patch

import numpy as np
import pytest

import ray
from ray import train, tune
from ray.train import Checkpoint, ScalingConfig
from ray.train import ScalingConfig
from ray.train._checkpoint import Checkpoint
from ray.air.constants import MAX_REPR_LENGTH
from ray.tune.impl import tuner_internal
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.train.gbdt_trainer import GBDTTrainer
from ray.train.trainer import BaseTrainer
from ray.util.placement_group import get_current_placement_group
from ray.train._internal.storage import _use_storage_context

logger = logging.getLogger(__name__)

@@ -261,44 +255,6 @@ def setup(self):
trainer.fit()


@patch.dict(os.environ, {"RAY_LOG_TO_STDERR": "1"})
def _is_trainable_name_overriden(trainer: BaseTrainer):
trainable = trainer.as_trainable()
output = io.StringIO()

def say(self):
logger.warning("say")

trainable.say = say
with redirect_stderr(output):
remote_trainable = ray.remote(trainable)
remote_actor = remote_trainable.remote()
ray.get(remote_actor.say.remote())
time.sleep(1) # make sure logging gets caught
output = output.getvalue()
print(output)
assert trainable().__repr__() in output


def test_trainable_name_is_overriden_data_parallel_trainer(ray_start_4_cpus):
trainer = DataParallelTrainer(
lambda x: x, scaling_config=ScalingConfig(num_workers=1)
)

_is_trainable_name_overriden(trainer)


def test_trainable_name_is_overriden_gbdt_trainer(ray_start_4_cpus):
trainer = DummyGBDTTrainer(
params={},
label_column="__values__",
datasets={"train": ray.data.from_items([1, 2, 3])},
scaling_config=ScalingConfig(num_workers=1),
)

_is_trainable_name_overriden(trainer)


def test_repr(ray_start_4_cpus):
def training_loop(self):
pass
@@ -316,38 +272,14 @@ def training_loop(self):
assert len(representation) < MAX_REPR_LENGTH


def test_large_params(ray_start_4_cpus):
"""Tests if large arguments are can be serialized by the Trainer."""
array_size = int(1e8)

def training_loop(self):
checkpoint = self.starting_checkpoint.to_dict()["ckpt"]
assert len(checkpoint) == array_size

checkpoint = Checkpoint.from_dict({"ckpt": np.zeros(shape=array_size)})
trainer = DummyTrainer(training_loop, resume_from_checkpoint=checkpoint)
trainer.fit()


def test_metadata_propagation_base(ray_start_4_cpus):
if not _use_storage_context():
print("Not implemented in old backend")
return

from ray.train._checkpoint import Checkpoint as NewCheckpoint

class MyTrainer(BaseTrainer):
def training_loop(self):
assert train.get_context().get_metadata() == {"a": 1, "b": 1}
with tempfile.TemporaryDirectory() as path:
checkpoint = NewCheckpoint.from_directory(path)
checkpoint = Checkpoint.from_directory(path)
checkpoint.set_metadata({"b": 2, "c": 3})
train.report(
dict(
my_metric=1,
),
checkpoint=checkpoint,
)
train.report(dict(my_metric=1), checkpoint=checkpoint)

trainer = MyTrainer(metadata={"a": 1, "b": 1})
result = trainer.fit()
@@ -356,23 +288,12 @@ def training_loop(self):


def test_metadata_propagation_data_parallel(ray_start_4_cpus):
if not _use_storage_context():
print("Not implemented in old backend")
return

from ray.train._checkpoint import Checkpoint as NewCheckpoint

def training_loop(self):
assert train.get_context().get_metadata() == {"a": 1, "b": 1}
with tempfile.TemporaryDirectory() as path:
checkpoint = NewCheckpoint.from_directory(path)
checkpoint = Checkpoint.from_directory(path)
checkpoint.set_metadata({"b": 2, "c": 3})
train.report(
dict(
my_metric=1,
),
checkpoint=checkpoint,
)
train.report(dict(my_metric=1), checkpoint=checkpoint)

trainer = DummyTrainer(training_loop, metadata={"a": 1, "b": 1})
result = trainer.fit()
(The remaining 3 changed files are not shown.)
