diff --git a/examples/llm/hf_pipeline_dolly/config.py b/examples/llm/hf_pipeline_dolly/config.py index 4a80a7116..eca03bec9 100644 --- a/examples/llm/hf_pipeline_dolly/config.py +++ b/examples/llm/hf_pipeline_dolly/config.py @@ -15,6 +15,7 @@ from functools import lru_cache from langchain import HuggingFacePipeline +from torch.cuda import device_count from nemoguardrails.llm.helpers import get_llm_instance_wrapper from nemoguardrails.llm.providers import register_llm_provider @@ -25,8 +26,8 @@ def get_dolly_v2_3b_llm(): repo_id = "databricks/dolly-v2-3b" params = {"temperature": 0, "max_length": 1024} - # Using the first GPU - device = 0 + # Use the first CUDA-enabled GPU, if any + device = 0 if device_count() else -1 llm = HuggingFacePipeline.from_model_id( model_id=repo_id, device=device, task="text-generation", model_kwargs=params