From ae1198054d5050a3d265f75d26094aea293f3f2a Mon Sep 17 00:00:00 2001
From: serhatgktp
Date: Mon, 3 Jul 2023 15:53:32 -0400
Subject: [PATCH] Check for CUDA devices before specifying which one to use

Signed-off-by: serhatgktp
---
 examples/llm/hf_pipeline_dolly/config.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/examples/llm/hf_pipeline_dolly/config.py b/examples/llm/hf_pipeline_dolly/config.py
index 4a80a7116..5987fb4ee 100644
--- a/examples/llm/hf_pipeline_dolly/config.py
+++ b/examples/llm/hf_pipeline_dolly/config.py
@@ -15,6 +15,7 @@
 from functools import lru_cache
 
 from langchain import HuggingFacePipeline
+from torch.cuda import device_count
 
 from nemoguardrails.llm.helpers import get_llm_instance_wrapper
 from nemoguardrails.llm.providers import register_llm_provider
@@ -25,8 +26,8 @@ def get_dolly_v2_3b_llm():
     repo_id = "databricks/dolly-v2-3b"
     params = {"temperature": 0, "max_length": 1024}
 
-    # Using the first GPU
-    device = 0
+    # Using the first CUDA-enabled GPU, if any
+    device = 0 if device_count() else -1
 
     llm = HuggingFacePipeline.from_model_id(
         model_id=repo_id, device=device, task="text-generation", model_kwargs=params
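
For context, a minimal sketch of how `get_dolly_v2_3b_llm()` reads after this
patch is applied. The diff's final hunk is truncated after the
`model_kwargs=params` line, so the `@lru_cache` decorator, the closing
parenthesis, and the `return` statement below are assumed from the surrounding
imports and the usual shape of this example, not shown in the diff itself.
With `device=-1`, the underlying transformers pipeline falls back to CPU, so
the example also runs on machines without a CUDA-capable GPU.

    from functools import lru_cache

    from langchain import HuggingFacePipeline
    from torch.cuda import device_count


    @lru_cache
    def get_dolly_v2_3b_llm():
        repo_id = "databricks/dolly-v2-3b"
        params = {"temperature": 0, "max_length": 1024}

        # Using the first CUDA-enabled GPU, if any; device=-1 asks the
        # underlying transformers pipeline to run on CPU instead.
        device = 0 if device_count() else -1

        llm = HuggingFacePipeline.from_model_id(
            model_id=repo_id, device=device, task="text-generation", model_kwargs=params
        )
        return llm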