diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index cf96461a549f3..3eb7d186eb009 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -40,6 +40,8 @@ def maybe_backend_fallback( guided_params.backend = "outlines" if guided_params.backend == "xgrammar": + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( + xgr_installed) # xgrammar only has x86 wheels for linux, fallback to outlines from vllm.platforms import current_platform if current_platform.get_cpu_architecture() is not CpuArchEnum.X86: @@ -77,6 +79,13 @@ def maybe_backend_fallback( "Falling back to use outlines instead.") guided_params.backend = "outlines" + # If the xgrammar module cannot be imported successfully, + # we should still allow users to use guided decoding with a fallback. + elif not xgr_installed: + logger.warning("xgrammar module cannot be imported successfully. " + "Falling back to use outlines instead.") + guided_params.backend = "outlines" + if (guided_params.backend == "outlines" and guided_params.json_object is not None): # outlines doesn't support json_object, fallback to xgrammar diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index c01bd3af1d5b9..fc3a4cd4bebc8 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -14,7 +14,9 @@ try: import xgrammar as xgr from xgrammar.base import _core as xgr_core + xgr_installed = True except ImportError: + xgr_installed = False pass from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,