diff --git a/.buildkite/features/runai_model_streamer_loader.yml b/.buildkite/features/runai_model_streamer_loader.yml
new file mode 100644
index 000000000..086bb1a75
--- /dev/null
+++ b/.buildkite/features/runai_model_streamer_loader.yml
@@ -0,0 +1,40 @@
+# runai_model_streamer_loader
+steps:
+  - label: "Correctness tests for runai_model_streamer_loader"
+    key: "runai_model_streamer_loader_CorrectnessTest"
+    soft_fail: true
+    agents:
+      queue: tpu_v6e_queue
+    commands:
+      - .buildkite/scripts/run_in_docker.sh python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_runai_model_streamer_loader.py::test_correctness
+  - label: "Record correctness test result for runai_model_streamer_loader"
+    key: "record_runai_model_streamer_loader_CorrectnessTest"
+    depends_on: "runai_model_streamer_loader_CorrectnessTest"
+    env:
+      CI_TARGET: "runai_model_streamer_loader"
+      CI_STAGE: "CorrectnessTest"
+    agents:
+      queue: cpu
+    commands:
+      - |
+        .buildkite/scripts/record_step_result.sh runai_model_streamer_loader_CorrectnessTest
+
+  - label: "Performance tests for runai_model_streamer_loader"
+    key: "runai_model_streamer_loader_PerformanceTest"
+    depends_on: "record_runai_model_streamer_loader_CorrectnessTest"
+    soft_fail: true
+    agents:
+      queue: tpu_v6e_queue
+    commands:
+      - .buildkite/scripts/run_in_docker.sh python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_runai_model_streamer_loader.py::test_performance
+  - label: "Record performance test result for runai_model_streamer_loader"
+    key: "record_runai_model_streamer_loader_PerformanceTest"
+    depends_on: "runai_model_streamer_loader_PerformanceTest"
+    env:
+      CI_TARGET: "runai_model_streamer_loader"
+      CI_STAGE: "PerformanceTest"
+    agents:
+      queue: cpu
+    commands:
+      - |
+        .buildkite/scripts/record_step_result.sh runai_model_streamer_loader_PerformanceTest
diff --git a/requirements.txt b/requirements.txt
index 9a48874cb..a60b1ee43 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,3 +16,4 @@ torchvision==0.23.0
 pathwaysutils
 parameterized
 numba==0.62.1
+runai-model-streamer[s3,gcs]==0.15.0
diff --git a/tests/e2e/test_runai_model_streamer_loader.py b/tests/e2e/test_runai_model_streamer_loader.py
new file mode 100644
index 000000000..a1f434a17
--- /dev/null
+++ b/tests/e2e/test_runai_model_streamer_loader.py
@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+import time
+
+import pytest
+from vllm import LLM, SamplingParams
+
+
+@pytest.fixture
+def sampling_config():
+    return SamplingParams(temperature=0,
+                          max_tokens=10,
+                          ignore_eos=True)
+
+
+# TODO(amacaskill): Replace with a GKE-owned GCS bucket.
+@pytest.fixture
+def gcs_model_name():
+    return "gs://vertex-model-garden-public-us/llama3/llama3-8b-hf"
+
+
+@pytest.fixture
+def hf_model_name():
+    return "meta-llama/Meta-Llama-3-8B"
+
+
+@pytest.fixture
+def prompt():
+    return "Hello, my name is"
+
+
+def test_correctness(
+    sampling_config: SamplingParams,
+    gcs_model_name: str,
+    hf_model_name: str,
+    prompt: str,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    '''
+    Compare the outputs of a model loaded from GCS via runai_model_streamer
+    and a model loaded from Hugging Face. The outputs should be the same.
+    This test uses tensor_parallel_size=1; the model is 16 GB and a v6e chip
+    has 32 GB of HBM, so it fits on a single chip.
+    '''
+    # Set env variables so that runai_model_streamer uses anonymous GCS access.
+    monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "fake-project")
+    monkeypatch.setenv("RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS", "true")
+    monkeypatch.setenv(
+        "CLOUD_STORAGE_EMULATOR_ENDPOINT", "https://storage.googleapis.com"
+    )
+    gcs_llm = LLM(model=gcs_model_name,
+                  load_format="runai_streamer",
+                  max_model_len=128,
+                  max_num_seqs=16,
+                  max_num_batched_tokens=256)
+    gcs_outputs = gcs_llm.generate([prompt], sampling_config)
+    gcs_output_text = gcs_outputs[0].outputs[0].text
+    del gcs_llm
+    time.sleep(10)  # Wait for TPUs to be released
+
+    # Test with the Hugging Face model
+    hf_llm = LLM(model=hf_model_name,
+                 max_model_len=128,
+                 max_num_seqs=16,
+                 max_num_batched_tokens=256)
+    hf_outputs = hf_llm.generate([prompt], sampling_config)
+    hf_output_text = hf_outputs[0].outputs[0].text
+    del hf_llm
+    time.sleep(10)  # Wait for TPUs to be released
+
+    assert gcs_output_text == hf_output_text, (
+        f"Outputs do not match! "
+        f"GCS output: {gcs_output_text}, HF output: {hf_output_text}"
+    )
+
diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py
index 32d7335c2..97b024a63 100644
--- a/tpu_inference/models/common/model_loader.py
+++ b/tpu_inference/models/common/model_loader.py
@@ -10,6 +10,9 @@
 from transformers import PretrainedConfig
 from vllm.config import VllmConfig
 from vllm.utils.func_utils import supports_kw
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.runai_streamer_loader import RunaiModelStreamerLoader
+
 from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 
@@ -177,7 +180,23 @@
     # the model creation again, otherwise the model forward will have
     # non-trivial overhead in PjitFunction.
     with mesh:
-        model.load_weights(rng)
+        loader = get_model_loader(vllm_config.load_config)
+        if isinstance(loader, RunaiModelStreamerLoader):
+            model_weights = vllm_config.model_config.model
+            if hasattr(vllm_config.model_config, "model_weights"):
+                model_weights = vllm_config.model_config.model_weights
+            weights_iterator = loader._get_weights_iterator(
+                model_weights, vllm_config.model_config.revision)
+            # Stash the weights iterator on the model config at runtime instead
+            # of changing every model's load_weights signature. This also avoids
+            # a TypeError when RunaiModelStreamerLoader is used with a flax_nnx
+            # model whose load_weights does not accept a weights_iterator
+            # keyword argument.
+            vllm_config.model_config.model_weights_iterator = weights_iterator
+            model.load_weights(rng)
+            del vllm_config.model_config.model_weights_iterator
+        else:
+            model.load_weights(rng)
     jit_model = create_jit_model(
         model,
         use_qwix_on_abstract_model=should_apply_qwix_on_abstract_model)
diff --git a/tpu_inference/models/jax/utils/weight_utils.py b/tpu_inference/models/jax/utils/weight_utils.py
index 64f026dae..fe606575a 100644
--- a/tpu_inference/models/jax/utils/weight_utils.py
+++ b/tpu_inference/models/jax/utils/weight_utils.py
@@ -13,10 +13,12 @@
 import jax
 import jax.numpy as jnp
 import torch
+import torchax
 from flax import nnx
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from safetensors import safe_open
+from vllm.config import VllmConfig
 
 from tpu_inference import utils
 from tpu_inference.logger import init_logger
@@ -266,14 +268,16 @@ def get_default_maps(vllm_config, mesh: Mesh,
                      bias_pad_map=bias_pad_keys)
 
 
-def _load_hf_weights_on_thread(vllm_config,
-                               params: nnx.State,
-                               metadata_map: MetadataMap,
-                               mesh: Mesh,
-                               weights_file: str,
-                               filter_regex: str | None = None,
-                               keep_original_dtype_keys_regex: list[str]
-                               | None = None):
+def _load_and_shard_weight(
+    vllm_config,
+    params: nnx.State,
+    shardings: Any,
+    metadata_map: MetadataMap,
+    mesh: Mesh,
+    hf_key: str,
+    hf_weight: jax.Array,
+    keep_original_dtype_keys_regex: list[str] | None = None
+):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map
@@ -290,154 +294,203 @@
     head_dim = utils.get_padded_head_dim(head_dim_original)
     head_dim_pad = head_dim - head_dim_original
 
+    # Check if the key should retain its original dtype
+    keep_original_dtype = False
+    if keep_original_dtype_keys_regex:
+        for pattern in keep_original_dtype_keys_regex:
+            if re.match(pattern, hf_key):
+                keep_original_dtype = True
+                break
+
+    # Converting to config's dtype
+    if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
+        logger.warning(
+            f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
+        )
+        hf_weight = hf_weight.astype(model_config.dtype)
+
+    if hf_key.endswith(".weight"):
+        hf_key = hf_key.removesuffix(".weight")
+
+    # Find the corresponding model key using the HF key
+    if "layers" in hf_key:
+        layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
+        layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
+        model_key = name_map[layer_key]
+        model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
+    elif "blocks" in hf_key:
+        layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
+        layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
+        model_key = name_map[layer_key]
+        model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
+    else:
+        if hf_key not in name_map and hf_key == "lm_head":
+            logger.warning(f"Skip loading {hf_key} due to tie_word_embeddings")
+            return
+        if hf_key not in name_map and "t2d" in hf_key:
+            logger.warning(f"Skip loading {hf_key} as it's not used in eagle-3 for now")
+            return
+        model_key = name_map.get(hf_key, hf_key)
+
+    model_weight, model_sharding = get_param_and_sharding(
+        params, shardings, model_key
+    )
+
+    logger.debug(
+        "before transform | "
+        f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} 
{model_sharding}" + ) + + if hf_key.endswith(".bias"): + for key in bias_reshape_keys: + if key in hf_key: + hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key]) + if head_dim_pad > 0: + hf_weight = jnp.pad(hf_weight, ((0, 0), (0, head_dim_pad))) + break + else: + for key in reshape_keys: + if key in hf_key: + hf_weight = jnp.reshape(hf_weight, reshape_keys[key]) + if head_dim_pad > 0: + if "o_proj" in key: + hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0), (0, head_dim_pad))) + else: + hf_weight = jnp.pad(hf_weight, ((0, 0), (0, head_dim_pad), (0, 0))) + break + for key in transpose_keys: + if key in hf_key: + hf_weight = jnp.transpose(hf_weight, transpose_keys[key]) + break + + # Pad num-kv-heads + if hf_key.endswith(".bias"): + for key, value in bias_pad_keys.items(): + dim = value[0] + dim_size = value[1] + if key in hf_key and dim_size != 0: + hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim) + break + else: + for key, value in pad_keys.items(): + dim = value[0] + dim_size = value[1] + if key in hf_key and dim_size != 0: + hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim) + break + + logger.debug( + "after transform | " + f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}" + ) + + if head_dim_pad == 0: + assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}" + + # Update the model weight + spec = model_weight.sharding.spec if isinstance( + model_weight.sharding, NamedSharding) else model_weight.sharding + model_weight.value = shard(hf_weight, spec) + + +def _load_hf_weights_on_thread( + vllm_config: VllmConfig, + params: nnx.State, + metadata_map: "MetadataMap", + mesh: Mesh, + weights_file: str, + filter_regex: Optional[str] = None, + keep_original_dtype_keys_regex: Optional[list[str]] = None, +): + """Loads weights from a single weights file.""" try: shardings = nnx.get_named_sharding(params, mesh) except TypeError: shardings = params for hf_key, hf_weight in model_weights_single_file_generator( - weights_file, framework="flax", filter_regex=filter_regex): - - # Check if the key should retain its original dtype - keep_original_dtype = False - if keep_original_dtype_keys_regex: - for pattern in keep_original_dtype_keys_regex: - if re.match(pattern, hf_key): - keep_original_dtype = True - break - - # Converting to config's dtype - if not keep_original_dtype and hf_weight.dtype != model_config.dtype: - logger.warning( - f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}" - ) - hf_weight = hf_weight.astype(model_config.dtype) - - if hf_key.endswith(".weight"): - hf_key = hf_key.removesuffix(".weight") - - # Find the corresponding model key using the HF key - if "layers" in hf_key: - layer_num = re.search(r"layers\.(\d+)", hf_key).group(1) - layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key) - model_key = name_map[layer_key] - model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key) - elif "blocks" in hf_key: - layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1) - layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key) - model_key = name_map[layer_key] - model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key) - else: - if hf_key not in name_map and hf_key == "lm_head": - logger.warning( - f"Skip loading {hf_key} due to tie_word_embeddings") - continue - if hf_key not in name_map and "t2d" in hf_key: - logger.warning( - f"Skip loading {hf_key} as it's not used in eagle-3 for now" - ) - continue - model_key = 
name_map.get(hf_key, hf_key) - model_weight, model_sharding = get_param_and_sharding( - params, shardings, model_key) - - logger.debug( - "before transform | " - f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}" + weights_file, framework="flax", filter_regex=filter_regex + ): + _load_and_shard_weight( + vllm_config, + params, + shardings, + metadata_map, + mesh, + hf_key, + hf_weight, + keep_original_dtype_keys_regex, ) - if hf_key.endswith(".bias"): - for key in bias_reshape_keys: - if key in hf_key: - hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key]) - if head_dim_pad > 0: - hf_weight = jnp.pad(hf_weight, - ((0, 0), (0, head_dim_pad))) - break - else: - for key in reshape_keys: - if key in hf_key: - hf_weight = jnp.reshape(hf_weight, reshape_keys[key]) - if head_dim_pad > 0: - if "o_proj" in key: - hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0), - (0, head_dim_pad))) - else: - hf_weight = jnp.pad(hf_weight, - ((0, 0), (0, head_dim_pad), - (0, 0))) - break - for key in transpose_keys: - if key in hf_key: - hf_weight = jnp.transpose(hf_weight, transpose_keys[key]) - break - - # Pad num-kv-heads - if hf_key.endswith(".bias"): - for key, value in bias_pad_keys.items(): - dim = value[0] - dim_size = value[1] - if key in hf_key and dim_size != 0: - hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim) - break - else: - for key, value in pad_keys.items(): - dim = value[0] - dim_size = value[1] - if key in hf_key and dim_size != 0: - hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim) - break - - logger.debug( - "after transform | " - f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}" - ) - if head_dim_pad == 0: - assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}" - - # Update the model weight - spec = model_weight.sharding.spec if isinstance( - model_weight.sharding, NamedSharding) else model_weight.sharding - model_weight.value = shard(hf_weight, spec) - - -def load_hf_weights(vllm_config, - model: nnx.Module, - metadata_map: MetadataMap, - mesh: Mesh, - filter_regex: str | None = None, - is_draft_model: bool = False, - keep_original_dtype_keys_regex: list[str] | None = None): - """Load weights from all model weights files to the model, run in multi threads.""" - if is_draft_model: - model_path = vllm_config.speculative_config.draft_model_config.model - else: - model_path = vllm_config.model_config.model - weights_files = get_model_weights_files( - model_path, vllm_config.load_config.download_dir) +def load_hf_weights( + vllm_config: VllmConfig, + model: nnx.Module, + metadata_map: "MetadataMap", + mesh: Mesh, + filter_regex: Optional[str] = None, + is_draft_model: bool = False, + keep_original_dtype_keys_regex: Optional[list[str]] = None, +): + """Load weights into a JAX model from either an iterator or files.""" params = nnx.state(model) - max_workers = min(64, len(weights_files)) - # NOTE(xiang): Disable multi-threading mode if running on multi-host. - # Because multi-threading would cause different JAX processes to load - # different weights at the same time. 
- if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray": - max_workers = 1 - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [ - executor.submit( - _load_hf_weights_on_thread, + try: + shardings = nnx.get_named_sharding(params, mesh) + except TypeError: + shardings = params + weights_iterator = None + if hasattr(vllm_config.model_config, "model_weights_iterator"): + weights_iterator = vllm_config.model_config.model_weights_iterator + env = torchax.default_env() + # The weights_iterator is used in RunAI model streamer integration. + if weights_iterator is not None: + for hf_key, hf_weight in weights_iterator: + if filter_regex and not re.match(filter_regex, hf_key): + continue + + # Since the weights_iterator yields Pytorch tensors (torch.Tensor), + # we need to convert them to JAX arrays (jax.Array). + hf_weight_jax = env.t2j_copy(hf_weight) + + _load_and_shard_weight( vllm_config, params, + shardings, metadata_map, mesh, - weights_file, - filter_regex=filter_regex, - keep_original_dtype_keys_regex=keep_original_dtype_keys_regex) - for weights_file in weights_files - ] - for future in futures: - future.result() + hf_key, + hf_weight_jax, + keep_original_dtype_keys_regex, + ) + else: + # File-based path (multi-threaded) + if is_draft_model: + model_path = vllm_config.speculative_config.draft_model_config.model + else: + model_path = vllm_config.model_config.model + weights_files = get_model_weights_files( + model_path, vllm_config.load_config.download_dir + ) + max_workers = min(64, len(weights_files)) + if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray": + max_workers = 1 + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [ + executor.submit( + _load_hf_weights_on_thread, + vllm_config, + params, + metadata_map, + mesh, + weights_file, + filter_regex=filter_regex, + keep_original_dtype_keys_regex=keep_original_dtype_keys_regex, + ) + for weights_file in weights_files + ] + for future in futures: + future.result() + check_all_loaded(params) nnx.update(model, params)
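
For context on the `model_loader.py` change: vLLM resolves the loader class from the load config, and the new branch only activates when that loader is the Run:ai streamer. Below is a minimal, illustrative sketch of that selection, assuming a vLLM build that exposes `LoadConfig`, `get_model_loader`, and `RunaiModelStreamerLoader` as imported in this diff; it is not part of the change itself.

```python
# Sketch: how load_format="runai_streamer" maps to the loader class checked
# by the new branch in tpu_inference/models/common/model_loader.py.
from vllm.config import LoadConfig
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.runai_streamer_loader import (
    RunaiModelStreamerLoader)

# Build the loader the same way load_model() does, from the load config.
loader = get_model_loader(LoadConfig(load_format="runai_streamer"))
assert isinstance(loader, RunaiModelStreamerLoader)
```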
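The streamer's `weights_iterator` yields `torch.Tensor` objects, while the JAX models expect `jax.Array` values; `load_hf_weights` bridges the two with torchax. A standalone sketch of that conversion step follows, assuming torchax is installed; `t2j_copy` is the same call the diff relies on, and the tensor shown here is only a placeholder.

```python
# Sketch of the tensor hand-off used in load_hf_weights: a host torch.Tensor
# is copied into a jax.Array via torchax before sharding onto the mesh.
import torch
import torchax

env = torchax.default_env()
cpu_tensor = torch.arange(12, dtype=torch.float32).reshape(3, 4)
jax_array = env.t2j_copy(cpu_tensor)  # torch.Tensor -> jax.Array (host copy)
print(type(jax_array), jax_array.shape, jax_array.dtype)
```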
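End to end, the new e2e test exercises the loader through vLLM's offline `LLM` API. The sketch below mirrors the test's settings outside pytest; the public bucket, anonymous-access environment variables, and sizes are copied from the test and are placeholders rather than a recommended configuration.

```python
# Trimmed-down sketch of the e2e flow: stream weights from GCS with the
# Run:ai model streamer load format and generate greedily on TPU.
import os

from vllm import LLM, SamplingParams

# Anonymous GCS access, as configured in the e2e test.
os.environ["GOOGLE_CLOUD_PROJECT"] = "fake-project"
os.environ["RUNAI_STREAMER_GCS_USE_ANONYMOUS_CREDENTIALS"] = "true"
os.environ["CLOUD_STORAGE_EMULATOR_ENDPOINT"] = "https://storage.googleapis.com"

llm = LLM(model="gs://vertex-model-garden-public-us/llama3/llama3-8b-hf",
          load_format="runai_streamer",
          max_model_len=128,
          max_num_seqs=16,
          max_num_batched_tokens=256)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0, max_tokens=10))
print(outputs[0].outputs[0].text)
```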