diff --git a/truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py b/truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py
index 2e6093977..9b19866b5 100644
--- a/truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py
+++ b/truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py
@@ -14,12 +14,19 @@
     setup_environment_variables_and_secrets,
 )
 
+# NB(aghilan): Transformers recently changed to saving the chat template as a separate
+# chat_template.jinja file instead of embedding it in tokenizer_config.json. Older models
+# will not have this file, so we check for it and use it if it exists. vLLM will not resolve
+# chat_template.jinja automatically, so we must pass it to the start command. This logic is
+# needed for any model trained with Transformers v4.51.3 or later.
 VLLM_FULL_START_COMMAND = Template(
-    'sh -c "{%if envvars %}{{ envvars }} {% endif %}vllm serve {{ model_path }}'
-    + " --port 8000"
-    + " --tensor-parallel-size {{ specify_tensor_parallelism }}"
-    + " --dtype bfloat16"
-    + '"'
+    "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
+    'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
+    "if [ -f {{ model_path }}/chat_template.jinja ]; then "
+    "vllm serve {{ model_path }} --chat-template {{ model_path }}/chat_template.jinja "
+    "--port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
+    "else "
+    "vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
+    "fi'"
 )
 
 
diff --git a/truss/tests/cli/train/test_deploy_checkpoints.py b/truss/tests/cli/train/test_deploy_checkpoints.py
index e049fd636..336b112e6 100644
--- a/truss/tests/cli/train/test_deploy_checkpoints.py
+++ b/truss/tests/cli/train/test_deploy_checkpoints.py
@@ -505,8 +505,16 @@ def test_render_vllm_full_truss_config():
     )
 
     result = render_vllm_full_truss_config(deploy_config)
-
-    expected_vllm_command = 'sh -c "HF_TOKEN=$(cat /secrets/hf_token) vllm serve /tmp/training_checkpoints/job123/rank-0/checkpoint-1 --port 8000 --tensor-parallel-size 2 --dtype bfloat16"'
+    expected_vllm_command = (
+        "sh -c 'HF_TOKEN=$(cat /secrets/hf_token) "
+        'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
+        "if [ -f /tmp/training_checkpoints/job123/rank-0/checkpoint-1/chat_template.jinja ]; then "
+        "vllm serve /tmp/training_checkpoints/job123/rank-0/checkpoint-1 "
+        "--chat-template /tmp/training_checkpoints/job123/rank-0/checkpoint-1/chat_template.jinja "
+        "--port 8000 --tensor-parallel-size 2 --dtype bfloat16; else "
+        "vllm serve /tmp/training_checkpoints/job123/rank-0/checkpoint-1 "
+        "--port 8000 --tensor-parallel-size 2 --dtype bfloat16; fi'"
+    )
 
     assert isinstance(result, truss_config.TrussConfig)
     assert result.model_name == "test-full-model"
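
For context, here is a minimal sketch (not part of the diff) of how the updated template renders. It assumes `Template` is `jinja2.Template`, which the `{% if %}` syntax implies, and uses the same sample values as the test. Two points it illustrates: the `if [ -f ... ]` check is evaluated by the shell at container start, not at render time, so a single rendered command handles both old checkpoints (template embedded in tokenizer_config.json) and new ones; and the doubled `$$` is preserved verbatim in the rendered string, exactly as the updated test asserts.

```python
# Sketch only: renders the new VLLM_FULL_START_COMMAND with sample values.
# The sample envvars and checkpoint path are illustrative, mirroring the test.
from jinja2 import Template

VLLM_FULL_START_COMMAND = Template(
    "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
    'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
    "if [ -f {{ model_path }}/chat_template.jinja ]; then "
    "vllm serve {{ model_path }} --chat-template {{ model_path }}/chat_template.jinja "
    "--port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
    "else "
    "vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
    "fi'"
)

rendered = VLLM_FULL_START_COMMAND.render(
    envvars="HF_TOKEN=$(cat /secrets/hf_token)",  # sample extra env var, as in the test
    model_path="/tmp/training_checkpoints/job123/rank-0/checkpoint-1",
    specify_tensor_parallelism=2,
)
# Jinja2 does not treat "$" specially, so "$$" survives into the rendered command;
# the shell branch below runs inside the container when the command executes.
print(rendered)
```

Rendering with these values reproduces the `expected_vllm_command` string from the test verbatim, which is a quick way to sanity-check edits to the template against the test expectation.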