diff --git a/tpu_inference/platforms/tpu_jax.py b/tpu_inference/platforms/tpu_jax.py
index c6030fe1b..47a6984c8 100644
--- a/tpu_inference/platforms/tpu_jax.py
+++ b/tpu_inference/platforms/tpu_jax.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
+import torch
 from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
@@ -28,10 +29,12 @@
 
 logger = init_logger(__name__)
 
-_DTYPE: dict[str, jnp.dtype] = {
+_DTYPE: dict[str | torch.dtype, jnp.dtype] = {
+    torch.bfloat16: jnp.bfloat16,
     "bfloat16": jnp.bfloat16,
     "float": jnp.float32,
     "float32": jnp.float32,
+    torch.float32: jnp.float32,
 }
 
 
@@ -146,14 +149,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # For mm model preprocessors, it may need the output dtype to be torch.
         # In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
         if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
+            vllm_config.model_config.dtype = _DTYPE.get(
+                vllm_config.model_config.dtype, jnp.bfloat16)
 
         if impl == "vllm":
             vllm_config.model_config.dtype = j2t_dtype(
diff --git a/tpu_inference/worker/tpu_worker_jax.py b/tpu_inference/worker/tpu_worker_jax.py
index a7d2ddda5..defa6d841 100644
--- a/tpu_inference/worker/tpu_worker_jax.py
+++ b/tpu_inference/worker/tpu_worker_jax.py
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-
 import os
 import tempfile
 from typing import Callable, Dict, Optional, Tuple, Union
@@ -8,6 +7,7 @@
 import jax.numpy as jnp
 import jaxlib
 import jaxtyping
+import torch
 import vllm.envs as envs
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
@@ -38,10 +38,12 @@
 
 logger = init_logger(__name__)
 
-_DTYPE: dict[str, jnp.dtype] = {
+_DTYPE: dict[str | torch.dtype, jnp.dtype] = {
     "bfloat16": jnp.bfloat16,
+    torch.bfloat16: jnp.bfloat16,
     "float": jnp.float32,
     "float32": jnp.float32,
+    torch.float32: jnp.float32,
 }
 
 
@@ -61,14 +63,10 @@ def __init__(self,
         # with torch version of the dtype.
         impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
         if impl != "vllm":  # vllm-pytorch implementation does not need this conversion
-            # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
+            # dtype may already have been converted to a jnp dtype in tpu_jax.py
+            if not jax.dtypes.issubdtype(vllm_config.model_config.dtype,
+                                         jnp.generic):
                 vllm_config.model_config.dtype = _DTYPE.get(
                     vllm_config.model_config.dtype, jnp.bfloat16)
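
Reviewer note: below is a minimal, self-contained sketch of the lookup behavior the updated _DTYPE tables share; it is not part of the patch. The table contents and the jnp.bfloat16 fallback are copied from the diff above; the sample inputs (e.g. torch.float16 as an unmapped dtype) are illustrative assumptions.

import torch
import jax
import jax.numpy as jnp

# Mirror of the updated table: keys may now be strings or torch dtypes.
_DTYPE: dict[str | torch.dtype, jnp.dtype] = {
    "bfloat16": jnp.bfloat16,
    torch.bfloat16: jnp.bfloat16,
    "float": jnp.float32,
    "float32": jnp.float32,
    torch.float32: jnp.float32,
}

# String keys (the pre-existing path) and torch.dtype keys (the new path)
# resolve through the same lookup; anything unrecognized falls back to
# jnp.bfloat16, mirroring _DTYPE.get(dtype, jnp.bfloat16) in the diff.
assert _DTYPE.get("float32", jnp.bfloat16) is jnp.float32
assert _DTYPE.get(torch.float32, jnp.bfloat16) is jnp.float32
assert _DTYPE.get(torch.float16, jnp.bfloat16) is jnp.bfloat16  # fallback

# Worker-side guard: a dtype that is already a jnp dtype (e.g. converted
# earlier by tpu_jax.py) is detected and left untouched.
assert jax.dtypes.issubdtype(jnp.bfloat16, jnp.generic)

This single-table lookup is what lets tpu_jax.py drop the isinstance(..., str) branch: both string and torch dtype inputs are handled by one _DTYPE.get call, and the worker only converts when the dtype is not already a jnp dtype.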