17 changes: 12 additions & 5 deletions truss/cli/train/deploy_checkpoints/deploy_full_checkpoints.py
@@ -14,12 +14,19 @@
     setup_environment_variables_and_secrets,
 )

+# NB(aghilan): Transformers was recently changed to save a chat_template.jinja file instead of embedding the chat template in tokenizer_config.json.

Contributor: what version? might be helpful to know when it's safe to remove this logic / change this logic?

+# Old models will not have this file, so we check for it and use it if it exists.
+# vLLM will not automatically resolve the chat_template.jinja file, so we need to pass it to the start command.
+# This logic is needed for any models trained using Transformers v4.51.3 or later.
 VLLM_FULL_START_COMMAND = Template(
-    'sh -c "{%if envvars %}{{ envvars }} {% endif %}vllm serve {{ model_path }}'
-    + " --port 8000"
-    + " --tensor-parallel-size {{ specify_tensor_parallelism }}"
-    + " --dtype bfloat16"
-    + '"'
+    "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
+    'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
+    "if [ -f {{ model_path }}/chat_template.jinja ]; then "
+    " vllm serve {{ model_path }} --chat-template {{ model_path }}/chat_template.jinja "
+    " --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
+    "else "
+    " vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
+    "fi'"
 )
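Since the change is easier to review in rendered form, here is a minimal sketch of how the new template expands, assuming `Template` is `jinja2.Template` (consistent with the `{% if %}`/`{{ }}` syntax) and using hypothetical values for `envvars`, `model_path`, and `specify_tensor_parallelism`:

```python
# Sketch only: a lightly reformatted copy of VLLM_FULL_START_COMMAND,
# rendered with illustrative values. Assumes jinja2.
from jinja2 import Template

tmpl = Template(
    "sh -c '{% if envvars %}{{ envvars }} {% endif %}"
    'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
    "if [ -f {{ model_path }}/chat_template.jinja ]; then "
    "vllm serve {{ model_path }} --chat-template {{ model_path }}/chat_template.jinja "
    "--port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
    "else "
    "vllm serve {{ model_path }} --port 8000 --tensor-parallel-size {{ specify_tensor_parallelism }} --dtype bfloat16; "
    "fi'"
)

print(
    tmpl.render(
        envvars="HF_TOKEN=$(cat /secrets/hf_token)",  # hypothetical env-var string
        model_path="/tmp/checkpoints/checkpoint-1",   # hypothetical checkpoint path
        specify_tensor_parallelism=2,
    )
)
```

Note that `$$` passes through Jinja unchanged; it looks like escaping for a later `$`-expansion step in the serving environment rather than anything Jinja-level (an assumption, not verified here).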


12 changes: 10 additions & 2 deletions truss/tests/cli/train/test_deploy_checkpoints.py
@@ -505,8 +505,16 @@ def test_render_vllm_full_truss_config():
     )

     result = render_vllm_full_truss_config(deploy_config)

-    expected_vllm_command = 'sh -c "HF_TOKEN=$(cat /secrets/hf_token) vllm serve /tmp/training_checkpoints/job123/rank-0/checkpoint-1 --port 8000 --tensor-parallel-size 2 --dtype bfloat16"'
+    expected_vllm_command = (
+        "sh -c 'HF_TOKEN=$(cat /secrets/hf_token) "
+        'HF_TOKEN="$$(cat /secrets/hf_access_token)" && export HF_TOKEN && '
+        "if [ -f /tmp/training_checkpoints/job123/rank-0/checkpoint-1/chat_template.jinja ]; then "
+        "vllm serve /tmp/training_checkpoints/job123/rank-0/checkpoint-1 "
+        "--chat-template /tmp/training_checkpoints/job123/rank-0/checkpoint-1/chat_template.jinja "
+        "--port 8000 --tensor-parallel-size 2 --dtype bfloat16; else "
+        "vllm serve /tmp/training_checkpoints/job123/rank-0/checkpoint-1 "
+        "--port 8000 --tensor-parallel-size 2 --dtype bfloat16; fi'"
+    )

     assert isinstance(result, truss_config.TrussConfig)
     assert result.model_name == "test-full-model"
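The doubled HF_TOKEN assignment in the expected string appears intentional: the first comes from the rendered `{{ envvars }}` block, the second from the export hardcoded into the template. For readers tracing the branch logic the test exercises, a plain-Python sketch of the fallback the start command encodes (the helper name and path are illustrative, not part of this PR):

```python
# Sketch of the fallback: pass --chat-template only when chat_template.jinja
# exists next to the checkpoint weights; otherwise rely on vLLM's defaults.
from pathlib import Path
from typing import List

def build_vllm_args(model_path: str, tensor_parallelism: int) -> List[str]:
    args = ["vllm", "serve", model_path]
    chat_template = Path(model_path) / "chat_template.jinja"
    if chat_template.exists():  # older checkpoints won't ship this file
        args += ["--chat-template", str(chat_template)]
    args += [
        "--port", "8000",
        "--tensor-parallel-size", str(tensor_parallelism),
        "--dtype", "bfloat16",
    ]
    return args

print(build_vllm_args("/tmp/training_checkpoints/job123/rank-0/checkpoint-1", 2))
```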