Skip to content

Commit

Permalink
[Train] Fix use_gpu with HuggingFacePredictor (ray-project#32333)
Browse files Browse the repository at this point in the history
HuggingFacePredictor's use_gpu was set in the wrong method, causing it to not really work correctly. This PR fixes that.

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
  • Loading branch information
Yard1 authored and edoakes committed Mar 22, 2023
1 parent acbee19 commit 6c96901
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
13 changes: 8 additions & 5 deletions python/ray/train/huggingface/huggingface_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def from_checkpoint(
checkpoint: Checkpoint,
*,
pipeline_cls: Optional[Type[Pipeline]] = None,
use_gpu: bool = False,
**pipeline_kwargs,
) -> "HuggingFacePredictor":
"""Instantiate the predictor from a Checkpoint.
Expand All @@ -104,6 +105,8 @@ def from_checkpoint(
pipeline_cls: A ``transformers.pipelines.Pipeline`` class to use.
If not specified, will use the ``pipeline`` abstraction
wrapper.
use_gpu: If set, the model will be moved to GPU on instantiation and
prediction happens on GPU.
**pipeline_kwargs: Any kwargs to pass to the pipeline
initialization. If ``pipeline`` is None, this must contain
the 'task' argument. Cannot contain 'model'. Can be used
Expand All @@ -114,6 +117,9 @@ def from_checkpoint(
raise ValueError(
"If `pipeline_cls` is not specified, 'task' must be passed as a kwarg."
)
if use_gpu:
# default to using the GPU with the first index
pipeline_kwargs.setdefault("device", 0)
pipeline_cls = pipeline_cls or pipeline_factory
preprocessor = checkpoint.get_preprocessor()
with checkpoint.as_directory() as checkpoint_path:
Expand All @@ -123,14 +129,12 @@ def from_checkpoint(
return cls(
pipeline=pipeline,
preprocessor=preprocessor,
use_gpu=use_gpu,
)

def _predict(
self, data: Union[list, pd.DataFrame], **pipeline_call_kwargs
) -> pd.DataFrame:
if self.use_gpu:
# default to using the GPU with the first index
pipeline_call_kwargs.setdefault("device", 0)
ret = self.pipeline(data, **pipeline_call_kwargs)
# Remove unnecessary lists
try:
Expand Down Expand Up @@ -180,8 +184,7 @@ def predict(
data to use as features to predict on. If None, use all
columns.
**pipeline_call_kwargs: additional kwargs to pass to the
``pipeline`` object. If ``use_gpu`` is True, 'device'
will be set to 0 by default.
``pipeline`` object.
Examples:
>>> import pandas as pd
Expand Down
22 changes: 14 additions & 8 deletions python/ray/train/tests/test_huggingface_gpu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import tempfile

import numpy as np
Expand Down Expand Up @@ -31,13 +30,25 @@ def create_checkpoint():
return HuggingFaceCheckpoint.from_dict(checkpoint.to_dict())


class AssertingHuggingFacePredictor(HuggingFacePredictor):
def __init__(self, pipeline=None, preprocessor=None, use_gpu: bool = False):
super().__init__(pipeline, preprocessor, use_gpu)
assert use_gpu
assert "cuda" in str(pipeline.device)


# TODO(ml-team): Add np.ndarray to batch_type
@pytest.mark.parametrize("batch_type", [pd.DataFrame])
@pytest.mark.parametrize("device", [None, 0])
def test_predict_batch(ray_start_4_cpus, caplog, batch_type, device):
checkpoint = create_checkpoint()
kwargs = {}

if device is not None:
kwargs["device"] = device

predictor = BatchPredictor.from_checkpoint(
checkpoint, HuggingFacePredictor, task="text-generation"
checkpoint, AssertingHuggingFacePredictor, task="text-generation", **kwargs
)

# Todo: Ray data does not support numpy string arrays well
Expand All @@ -50,12 +61,7 @@ def test_predict_batch(ray_start_4_cpus, caplog, batch_type, device):
else:
raise RuntimeError("Invalid batch_type")

kwargs = {}
if device:
kwargs["device"] = device
with caplog.at_level(logging.WARNING):
predictions = predictor.predict(dataset, num_gpus_per_worker=1, **kwargs)
assert "enable GPU prediction" not in caplog.text
predictions = predictor.predict(dataset, num_gpus_per_worker=1)

assert predictions.count() == 3

Expand Down

0 comments on commit 6c96901

Please sign in to comment.