Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Train] Fix use_gpu with HuggingFacePredictor #32333

Merged
merged 6 commits into from
Feb 10, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions python/ray/train/huggingface/huggingface_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def from_checkpoint(
checkpoint: Checkpoint,
*,
pipeline_cls: Optional[Type[Pipeline]] = None,
use_gpu: bool = False,
**pipeline_kwargs,
) -> "HuggingFacePredictor":
"""Instantiate the predictor from a Checkpoint.
Expand All @@ -104,6 +105,8 @@ def from_checkpoint(
pipeline_cls: A ``transformers.pipelines.Pipeline`` class to use.
If not specified, will use the ``pipeline`` abstraction
wrapper.
use_gpu: If set, the model will be moved to GPU on instantiation and
prediction happens on GPU.
**pipeline_kwargs: Any kwargs to pass to the pipeline
initialization. If ``pipeline`` is None, this must contain
the 'task' argument. Cannot contain 'model'. Can be used
Expand All @@ -114,6 +117,9 @@ def from_checkpoint(
raise ValueError(
"If `pipeline_cls` is not specified, 'task' must be passed as a kwarg."
)
if use_gpu:
# default to using the GPU with the first index
pipeline_kwargs.setdefault("device", 0)
pipeline_cls = pipeline_cls or pipeline_factory
preprocessor = checkpoint.get_preprocessor()
with checkpoint.as_directory() as checkpoint_path:
Expand All @@ -123,14 +129,12 @@ def from_checkpoint(
return cls(
pipeline=pipeline,
preprocessor=preprocessor,
use_gpu=use_gpu,
)

def _predict(
self, data: Union[list, pd.DataFrame], **pipeline_call_kwargs
) -> pd.DataFrame:
if self.use_gpu:
# default to using the GPU with the first index
pipeline_call_kwargs.setdefault("device", 0)
ret = self.pipeline(data, **pipeline_call_kwargs)
# Remove unnecessary lists
try:
Expand Down Expand Up @@ -180,8 +184,7 @@ def predict(
data to use as features to predict on. If None, use all
columns.
**pipeline_call_kwargs: additional kwargs to pass to the
``pipeline`` object. If ``use_gpu`` is True, 'device'
will be set to 0 by default.
``pipeline`` object.
Examples:
>>> import pandas as pd
Expand Down
22 changes: 14 additions & 8 deletions python/ray/train/tests/test_huggingface_gpu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import tempfile

import numpy as np
Expand Down Expand Up @@ -31,13 +30,25 @@ def create_checkpoint():
return HuggingFaceCheckpoint.from_dict(checkpoint.to_dict())


class AssertingHuggingFacePredictor(HuggingFacePredictor):
def __init__(self, pipeline=None, preprocessor=None, use_gpu: bool = False):
super().__init__(pipeline, preprocessor, use_gpu)
assert use_gpu
assert "cuda" in str(pipeline.device)


# TODO(ml-team): Add np.ndarray to batch_type
@pytest.mark.parametrize("batch_type", [pd.DataFrame])
@pytest.mark.parametrize("device", [None, 0])
def test_predict_batch(ray_start_4_cpus, caplog, batch_type, device):
checkpoint = create_checkpoint()
kwargs = {}

if device is not None:
kwargs["device"] = device

predictor = BatchPredictor.from_checkpoint(
checkpoint, HuggingFacePredictor, task="text-generation"
checkpoint, AssertingHuggingFacePredictor, task="text-generation", **kwargs
)

# Todo: Ray data does not support numpy string arrays well
Expand All @@ -50,12 +61,7 @@ def test_predict_batch(ray_start_4_cpus, caplog, batch_type, device):
else:
raise RuntimeError("Invalid batch_type")

kwargs = {}
if device:
kwargs["device"] = device
with caplog.at_level(logging.WARNING):
predictions = predictor.predict(dataset, num_gpus_per_worker=1, **kwargs)
assert "enable GPU prediction" not in caplog.text
predictions = predictor.predict(dataset, num_gpus_per_worker=1)

assert predictions.count() == 3

Expand Down