Skip to content

Commit

Permalink
Check for CUDA devices before specifying which one to use
Browse files Browse the repository at this point in the history
Signed-off-by: serhatgktp <efkan@ibm.com>
  • Loading branch information
serhatgktp committed Jul 3, 2023
1 parent 40889af commit ae11980
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions examples/llm/hf_pipeline_dolly/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from functools import lru_cache

from langchain import HuggingFacePipeline
from torch.cuda import device_count

from nemoguardrails.llm.helpers import get_llm_instance_wrapper
from nemoguardrails.llm.providers import register_llm_provider
Expand All @@ -25,8 +26,8 @@ def get_dolly_v2_3b_llm():
repo_id = "databricks/dolly-v2-3b"
params = {"temperature": 0, "max_length": 1024}

# Using the first GPU
device = 0
# Using the first CUDA-enabled GPU, if any
device = 0 if device_count() else -1

llm = HuggingFacePipeline.from_model_id(
model_id=repo_id, device=device, task="text-generation", model_kwargs=params
Expand Down

0 comments on commit ae11980

Please sign in to comment.