Commit
Allow explicitly plumbing through nics (#2605) (#2608)
tgaddair authored Oct 7, 2022
1 parent 55edae6 commit 76a9c16
Showing 4 changed files with 202 additions and 40 deletions.
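
The substantive change is in ludwig/backend/ray.py: get_trainer_kwargs() now accepts arbitrary trainer kwargs, pops an optional "nics" entry, and forwards it to Horovod's HorovodConfig, while RayTrainerV2 and the predictor plumb their trainer_kwargs through the new create_runner() helper. Below is a minimal sketch of the resulting behavior — not part of the diff; it assumes ludwig[ray] with Horovod is installed and that a local Ray instance is acceptable for illustration, and the interface name "lo" is only an example.

# Hedged sketch of the `nics` plumbing added by this commit.
import ray
from ludwig.backend.ray import get_trainer_kwargs

ray.init()  # get_trainer_kwargs() inspects cluster resources, so Ray must be running

# Restrict Horovod to the loopback interface; any NIC name(s) could be passed.
trainer_kwargs = get_trainer_kwargs(nics=["lo"], num_workers=2, use_gpu=False)

print(type(trainer_kwargs["backend"]).__name__)  # HorovodConfig
print(trainer_kwargs["backend"].nics)            # {'lo'}
print(trainer_kwargs["num_workers"])             # 2 (explicit values override the computed defaults)

ray.shutdown()

In practice the same kwargs arrive via the Ray backend's trainer configuration, since RayTrainerV2 calls create_runner(**self.trainer_kwargs), so a "nics" entry there reaches Horovod without further changes.
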
4 changes: 2 additions & 2 deletions .github/workflows/pytest.yml
@@ -49,7 +49,7 @@ jobs:

name: py${{ matrix.python-version }}, torch-${{ matrix.pytorch-version }}, ${{ matrix.test-markers }}, ${{ matrix.os }}

timeout-minutes: 60
timeout-minutes: 80
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
@@ -156,7 +156,7 @@ jobs:

- name: Tests
run: |
RUN_PRIVATE=1 pytest -v --timeout 300 --durations 100 -m "$MARKERS" --junitxml pytest.xml tests
RUN_PRIVATE=1 LUDWIG_TEST_SUITE_TIMEOUT_S=3600 pytest -v --timeout 300 --durations 100 -m "$MARKERS" --junitxml pytest.xml tests
- name: Upload Unit Test Results
if: always()
96 changes: 58 additions & 38 deletions ludwig/backend/ray.py
@@ -17,20 +17,24 @@
import contextlib
import copy
import logging
import os
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import dask
import numpy as np
import pandas as pd
import ray
import ray.train as rt
import torch
import tqdm
from fsspec.config import conf
from packaging import version
from pyarrow.fs import FSSpecHandler, PyFileSystem
from ray import ObjectRef
from ray.data.dataset_pipeline import DatasetPipeline
from ray.train.constants import TRAIN_ENABLE_WORKER_SPREAD_ENV
from ray.train.horovod import HorovodConfig
from ray.train.trainer import Trainer
from ray.util.dask import ray_dask_get
from ray.util.placement_group import placement_group, remove_placement_group

@@ -57,11 +61,6 @@
from ludwig.utils.torch_utils import get_torch_device, initialize_pytorch
from ludwig.utils.types import Series

_ray112 = version.parse("1.12") <= version.parse(ray.__version__) < version.parse("1.13")

import ray.train as rt # noqa: E402
from ray.train.trainer import Trainer # noqa: E402

logger = logging.getLogger(__name__)

try:
@@ -70,10 +69,6 @@
logger.warn(f"ImportError (ray.py) from horovod.ray import RayExecutor failed with error: \n\t{e}")
RayExecutor = None

if _ray112:
from ludwig.backend._ray112_compat import HorovodConfig
else:
from ray.train.horovod import HorovodConfig
RAY_DEFAULT_PARALLELISM = 200
FIFTEEN_MINS_IN_S = 15 * 60

@@ -95,31 +90,38 @@ def get_horovod_kwargs(use_gpu=None):
)


def get_trainer_kwargs(use_gpu=None):
def _num_nodes() -> int:
node_resources = [node["Resources"] for node in ray.nodes()]
return len(node_resources)


def get_trainer_kwargs(**kwargs) -> Dict[str, Any]:
kwargs = copy.deepcopy(kwargs)

# Our goal is to have a worker per resource used for training.
# The priority is GPUs, but can fall back to CPUs if there are no
# GPUs available.
if use_gpu is None:
use_gpu = int(ray.cluster_resources().get("GPU", 0)) > 0

use_gpu = kwargs.get("use_gpu", int(ray.cluster_resources().get("GPU", 0)) > 0)
if use_gpu:
num_workers = int(ray.cluster_resources().get("GPU", 0))
else:
# TODO: use placement groups or otherwise spread across nodes
node_resources = [node["Resources"] for node in ray.nodes()]
num_workers = len(node_resources)
num_workers = _num_nodes()

return dict(
# TODO travis: replace backend here once ray 1.8 released
# backend='horovod',
backend=HorovodConfig(),
# Explicitly override network interfaces Horovod will attempt to use
nics = kwargs.pop("nics", None)
if nics is not None:
nics = set(nics)

defaults = dict(
backend=HorovodConfig(nics=nics),
num_workers=num_workers,
use_gpu=use_gpu,
resources_per_worker={
"CPU": 0 if use_gpu else 1,
"GPU": 1 if use_gpu else 0,
},
)
return {**defaults, **kwargs}


def _create_dask_engine(**kwargs):
@@ -331,6 +333,37 @@ def process_results(self, results: List[Dict], **info) -> None:
self.progess_bars[_id].update(update_by)


@contextlib.contextmanager
def spread_env(use_gpu: bool = False, num_workers: int = 1, **kwargs):
if TRAIN_ENABLE_WORKER_SPREAD_ENV in os.environ:
# User set this explicitly, so honor their selection
yield
return

try:
if not use_gpu and num_workers > 1:
# When doing CPU-only training, default to a SPREAD policy to avoid
# packing too many workers on a single machine
os.environ[TRAIN_ENABLE_WORKER_SPREAD_ENV] = "1"
yield
finally:
if TRAIN_ENABLE_WORKER_SPREAD_ENV in os.environ:
del os.environ[TRAIN_ENABLE_WORKER_SPREAD_ENV]


@contextlib.contextmanager
def create_runner(**kwargs):
trainer_kwargs = get_trainer_kwargs(**kwargs)
with spread_env(**trainer_kwargs):
trainer = Trainer(**trainer_kwargs)

trainer.start()
try:
yield trainer
finally:
trainer.shutdown()


@register_ray_trainer("trainer", MODEL_ECD, default=True)
class RayTrainerV2(BaseTrainer):
def __init__(
@@ -352,15 +385,6 @@ def __init__(
def get_schema_cls():
return ECDTrainerConfig

@contextlib.contextmanager
def create_runner(self):
trainer = Trainer(**{**get_trainer_kwargs(), **self.trainer_kwargs})
trainer.start()
try:
yield trainer
finally:
trainer.shutdown()

def train(
self,
training_set: RayDataset,
@@ -382,7 +406,7 @@ def train(
if test_set is not None:
dataset["test"] = test_set.pipeline(shuffle=False, **self.data_loader_kwargs)

with self.create_runner() as runner:
with create_runner(**self.trainer_kwargs) as runner:
results, self._validation_field, self._validation_metric = runner.run(
lambda config: train_fn(**config),
config={"executable_kwargs": executable_kwargs, "model_ref": ray.put(self.model), **kwargs},
@@ -465,7 +489,7 @@ def eval_batch_size(self, value: int):

@property
def resources_per_worker(self) -> Dict[str, Any]:
trainer_kwargs = {**get_trainer_kwargs(), **self.trainer_kwargs}
trainer_kwargs = get_trainer_kwargs(**self.trainer_kwargs)
return trainer_kwargs.get("resources_per_worker", {})

@property
@@ -633,7 +657,7 @@ def __init__(
self.df_engine = df_engine

def get_trainer_kwargs(self) -> Dict[str, Any]:
return {**get_trainer_kwargs(), **self.trainer_kwargs}
return get_trainer_kwargs(**self.trainer_kwargs)

def get_resources_per_worker(self) -> Tuple[int, int]:
trainer_kwargs = self.get_trainer_kwargs()
@@ -697,9 +721,7 @@ def batch_evaluation(
# communication ops. However, Horovod is not suitable for transforming one big dataset to another. For that
# we will use Ray Datasets. Therefore, we break this up into two separate steps, and two passes over the
# dataset. In the future, we can explore ways to combine these into a single step to reduce IO.
runner = Trainer(**{**get_trainer_kwargs(), **self.trainer_kwargs})
runner.start()
try:
with create_runner(**self.trainer_kwargs) as runner:
# Collect eval metrics by distributing work across nodes / gpus with Horovod
datasets = {"eval": dataset.pipeline(shuffle=False, **self.data_loader_kwargs)}
predictor_kwargs = {
Expand All @@ -717,8 +739,6 @@ def batch_evaluation(
},
dataset=datasets,
)[0]
finally:
runner.shutdown()

predictions = None
if collect_predictions:
14 changes: 14 additions & 0 deletions tests/conftest.py
@@ -15,6 +15,7 @@
import contextlib
import os
import tempfile
import time
import uuid
from unittest import mock

@@ -24,6 +25,19 @@
from ludwig.hyperopt.run import hyperopt
from tests.integration_tests.utils import category_feature, generate_data, text_feature

TEST_SUITE_TIMEOUT_S = int(os.environ.get("LUDWIG_TEST_SUITE_TIMEOUT_S", 3600))


def pytest_sessionstart(session):
session.start_time = time.time()


@pytest.fixture(autouse=True)
def check_session_time(request):
elapsed = time.time() - request.session.start_time
if elapsed > TEST_SUITE_TIMEOUT_S:
request.session.shouldstop = "time limit reached: %0.2f seconds" % elapsed


@pytest.fixture(autouse=True)
def setup_tests(request):
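
For completeness, a hedged sketch (not part of the commit) of reproducing the CI behavior locally: the check_session_time fixture above reads LUDWIG_TEST_SUITE_TIMEOUT_S, which the workflow change earlier in this diff exports as 3600; the 1800-second budget and the marker expression below are illustrative stand-ins for the CI matrix values.

# Hypothetical local invocation mirroring the CI step in .github/workflows/pytest.yml.
import os
import subprocess

env = {**os.environ, "LUDWIG_TEST_SUITE_TIMEOUT_S": "1800"}  # illustrative suite budget
subprocess.run(
    ["pytest", "-v", "--timeout", "300", "--durations", "100", "-m", "not distributed", "tests"],
    env=env,
    check=False,  # a session stopped by the suite budget is reported, not raised here
)
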
128 changes: 128 additions & 0 deletions tests/ludwig/backend/test_ray.py
@@ -0,0 +1,128 @@
import copy
import os
from unittest.mock import patch

import pytest

# Skip these tests if Ray is not installed
ray = pytest.importorskip("ray") # noqa

from ray.train.constants import TRAIN_ENABLE_WORKER_SPREAD_ENV # noqa
from ray.train.horovod import HorovodConfig # noqa

from ludwig.backend.ray import get_trainer_kwargs, spread_env # noqa

# Mark the entire module as distributed
pytestmark = pytest.mark.distributed


@pytest.mark.parametrize(
"trainer_config,cluster_resources,num_nodes,expected_kwargs",
[
# Prioritize using the GPU when available over multi-node
(
{},
{"CPU": 4, "GPU": 1},
2,
dict(
backend=HorovodConfig(),
num_workers=1,
use_gpu=True,
resources_per_worker={
"CPU": 0,
"GPU": 1,
},
),
),
# Use one worker per node for CPU, check NIC override
(
{"nics": [""]},
{"CPU": 4, "GPU": 0},
2,
dict(
backend=HorovodConfig(nics={""}),
num_workers=2,
use_gpu=False,
resources_per_worker={
"CPU": 1,
"GPU": 0,
},
),
),
# Allow explicitly setting GPU usage for autoscaling clusters
(
{"use_gpu": True, "num_workers": 2},
{"CPU": 4, "GPU": 0},
1,
dict(
backend=HorovodConfig(),
num_workers=2,
use_gpu=True,
resources_per_worker={
"CPU": 0,
"GPU": 1,
},
),
),
# Allow overriding resources_per_worker
(
{"resources_per_worker": {"CPU": 2, "GPU": 1}},
{"CPU": 4, "GPU": 2},
2,
dict(
backend=HorovodConfig(),
num_workers=2,
use_gpu=True,
resources_per_worker={
"CPU": 2,
"GPU": 1,
},
),
),
],
)
def test_get_trainer_kwargs(trainer_config, cluster_resources, num_nodes, expected_kwargs):
with patch("ludwig.backend.ray.ray.cluster_resources", return_value=cluster_resources):
with patch("ludwig.backend.ray._num_nodes", return_value=num_nodes):
trainer_config_copy = copy.deepcopy(trainer_config)
actual_kwargs = get_trainer_kwargs(**trainer_config_copy)

# Function should not modify the original input
assert trainer_config_copy == trainer_config

actual_backend = actual_kwargs.pop("backend")
expected_backend = expected_kwargs.pop("backend")

assert type(actual_backend) == type(expected_backend)
assert actual_backend.nics == expected_backend.nics
assert actual_kwargs == expected_kwargs


@pytest.mark.parametrize(
"trainer_kwargs,current_env_value,expected_env_value",
[
({"use_gpu": False, "num_workers": 2}, None, "1"),
({"use_gpu": False, "num_workers": 1}, None, None),
({"use_gpu": True, "num_workers": 2}, None, None),
({"use_gpu": True, "num_workers": 2}, "1", "1"),
({"use_gpu": True, "num_workers": 2}, "", ""),
],
)
def test_spread_env(trainer_kwargs, current_env_value, expected_env_value):
prev_env = os.environ.get(TRAIN_ENABLE_WORKER_SPREAD_ENV)

# Set environment to state prior to override
if current_env_value is not None:
os.environ[TRAIN_ENABLE_WORKER_SPREAD_ENV] = current_env_value
elif TRAIN_ENABLE_WORKER_SPREAD_ENV in os.environ:
del os.environ[TRAIN_ENABLE_WORKER_SPREAD_ENV]

with spread_env(**trainer_kwargs):
assert os.environ.get(TRAIN_ENABLE_WORKER_SPREAD_ENV) == expected_env_value
assert os.environ.get(TRAIN_ENABLE_WORKER_SPREAD_ENV) == current_env_value

# Return environment to original state
if prev_env is not None:
os.environ[TRAIN_ENABLE_WORKER_SPREAD_ENV] = prev_env
elif TRAIN_ENABLE_WORKER_SPREAD_ENV in os.environ:
del os.environ[TRAIN_ENABLE_WORKER_SPREAD_ENV]
