15 changes: 6 additions & 9 deletions tpu_inference/platforms/tpu_jax.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
+import torch
 from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
@@ -28,10 +29,12 @@
 
 logger = init_logger(__name__)
 
-_DTYPE: dict[str, jnp.dtype] = {
+_DTYPE: dict[str | torch.dtype, jnp.dtype] = {
+    torch.bfloat16: jnp.bfloat16,
     "bfloat16": jnp.bfloat16,
     "float": jnp.float32,
     "float32": jnp.float32,
+    torch.float32: jnp.float32,
 }
 
 
@@ -146,14 +149,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # For mm model preprocessors, it may need the output dtype to be torch.
         # In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
         if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
+            vllm_config.model_config.dtype = _DTYPE.get(
+                vllm_config.model_config.dtype, jnp.bfloat16)
 
         if impl == "vllm":
             vllm_config.model_config.dtype = j2t_dtype(
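With this change, tpu_jax.py no longer warns and force-overwrites non-string dtypes; both the string spellings and the torch dtypes resolve through the extended _DTYPE table, with jnp.bfloat16 as the fallback. A minimal sketch of that lookup, assuming only the table shown in the diff (the loop and sample inputs are illustrative, not part of the PR):

import torch
import jax.numpy as jnp

# Mirrors the updated _DTYPE table from tpu_jax.py.
_DTYPE = {
    torch.bfloat16: jnp.bfloat16,
    "bfloat16": jnp.bfloat16,
    "float": jnp.float32,
    "float32": jnp.float32,
    torch.float32: jnp.float32,
}

# Strings and torch dtypes both resolve to a jnp dtype; anything the table
# does not know (e.g. "float16" here) falls back to jnp.bfloat16.
for dtype in (torch.bfloat16, "float32", "float16"):
    print(dtype, "->", _DTYPE.get(dtype, jnp.bfloat16))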
16 changes: 7 additions & 9 deletions tpu_inference/worker/tpu_worker_jax.py
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-
 import os
 import tempfile
 from typing import Callable, Dict, Optional, Tuple, Union
@@ -8,6 +7,7 @@
 import jax.numpy as jnp
 import jaxlib
 import jaxtyping
+import torch
 import vllm.envs as envs
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
@@ -38,10 +38,12 @@
 
 logger = init_logger(__name__)
 
-_DTYPE: dict[str, jnp.dtype] = {
+_DTYPE: dict[str | torch.dtype, jnp.dtype] = {
     "bfloat16": jnp.bfloat16,
+    torch.bfloat16: jnp.bfloat16,
     "float": jnp.float32,
     "float32": jnp.float32,
+    torch.float32: jnp.float32,
 }
 
 
@@ -61,14 +63,10 @@ def __init__(self,
         # with torch version of the dtype.
         impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
         if impl != "vllm":  # vllm-pytorch implementation does not need this conversion
-
             # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
+            # dtype can be converted to jnp.dtype in tpu_jax.py
+            if not jax.dtypes.issubdtype(vllm_config.model_config.dtype,
+                                         jnp.generic):
                 vllm_config.model_config.dtype = _DTYPE.get(
                     vllm_config.model_config.dtype, jnp.bfloat16)
 
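The worker now trusts dtypes that tpu_jax.py has already normalised: if model_config.dtype is already a JAX/NumPy dtype it is left alone, otherwise it goes through the same _DTYPE lookup. A minimal sketch of that guard as a standalone helper; the helper name to_jnp_dtype and the sample calls are illustrative, not part of the PR, and it assumes jax.dtypes.issubdtype recognises jnp.bfloat16 as a subtype of jnp.generic, as the diff relies on:

import jax
import jax.numpy as jnp

# Subset of the worker's _DTYPE table, for illustration.
_DTYPE = {
    "bfloat16": jnp.bfloat16,
    "float": jnp.float32,
    "float32": jnp.float32,
}

def to_jnp_dtype(dtype):
    # Dtypes that tpu_jax.py already wrote back (e.g. jnp.bfloat16) are
    # recognised as JAX/NumPy dtypes and returned unchanged.
    if jax.dtypes.issubdtype(dtype, jnp.generic):
        return dtype
    # Everything else is mapped through the table, defaulting to bfloat16.
    return _DTYPE.get(dtype, jnp.bfloat16)

print(to_jnp_dtype(jnp.bfloat16))            # already a JAX dtype, unchanged
print(_DTYPE.get("bfloat16", jnp.bfloat16))  # string configs map via the table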