From 66fc8f730e002fb8b42f84cab2b1a6950870650e Mon Sep 17 00:00:00 2001 From: Joe G Date: Mon, 15 Jul 2024 11:46:49 -0700 Subject: [PATCH 01/12] Add head_size of 120 --- benchmarks/kernels/benchmark_paged_attention.py | 2 +- benchmarks/kernels/benchmark_rope.py | 2 +- csrc/attention/attention_kernels.cu | 6 ++++++ csrc/cpu/attention.cpp | 6 ++++++ tests/kernels/test_attention.py | 2 +- tests/kernels/test_cache.py | 2 +- tests/kernels/test_pos_encoding.py | 2 +- vllm/attention/ops/ipex_attn.py | 2 +- vllm/attention/ops/paged_attn.py | 2 +- 9 files changed, 19 insertions(+), 7 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 16de60477c305..fd6e8856bfd55 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -173,7 +173,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--head-size", type=int, - choices=[64, 80, 96, 112, 128, 192, 256], + choices=[64, 80, 96, 112, 120, 128, 192, 256], default=128) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--use-alibi", action="store_true") diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 78736c7a7ba6f..f542684a9a2a9 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -94,7 +94,7 @@ def benchmark_rope_kernels_multi_lora( parser.add_argument("--num-heads", type=int, default=8) parser.add_argument("--head-size", type=int, - choices=[64, 80, 96, 112, 128, 192, 256], + choices=[64, 80, 96, 112, 120, 128, 192, 256], default=128) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) parser.add_argument("--dtype", diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 91083481705cb..f824231b8e4ed 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -751,6 +751,9 @@ void paged_attention_v1_launcher( case 112: LAUNCH_PAGED_ATTENTION_V1(112); break; + case 120: + LAUNCH_PAGED_ATTENTION_V1(120); + break; case 128: LAUNCH_PAGED_ATTENTION_V1(128); break; @@ -912,6 +915,9 @@ void paged_attention_v2_launcher( case 112: LAUNCH_PAGED_ATTENTION_V2(112); break; + case 120: + LAUNCH_PAGED_ATTENTION_V2(120); + break; case 128: LAUNCH_PAGED_ATTENTION_V2(128); break; diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 8367093325314..54d16d170b5b7 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -387,6 +387,9 @@ void paged_attention_v1_impl_launcher( case 112: LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); break; + case 120: + LAUNCH_V1_ATTENTION_KERNEL(T, 120, BLOCK_SIZE); + break; case 128: LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); break; @@ -704,6 +707,9 @@ void paged_attention_v2_impl_launcher( case 112: LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); break; + case 120: + LAUNCH_V2_ATTENTION_KERNEL(T, 120, BLOCK_SIZE); + break; case 128: LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); break; diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index f848ad51c7014..8c4f728258106 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -28,7 +28,7 @@ # FlashAttention forward only supports head dimension at most 128 # 
https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256 +HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256 ] if not is_hip() else [64, 80, 96, 112, 128] BLOCK_SIZES = [16, 32] diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 23b6baa60c05b..9cd5ca5e1bfe5 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -11,7 +11,7 @@ NUM_TOKENS = [42] # Arbitrary values for testing NUM_LAYERS = [1] # Arbitrary values for testing NUM_HEADS = [8] # Arbitrary values for testing -HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256] +HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256] BLOCK_SIZES = [8, 16, 32] # Arbitrary values for testing diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 4c83659929d41..4a7ad6e0fa21d 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -10,7 +10,7 @@ IS_NEOX_STYLE = [True, False] DTYPES = [torch.half, torch.bfloat16, torch.float] -HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256] +HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256] ROTARY_DIMS = [None, 32] # None means rotary dim == head size NUM_HEADS = [7, 17] # Arbitrary values for testing BATCH_SIZES = [1, 5] # Arbitrary values for testing diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 5a5317b65004e..683aa57fce493 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -10,7 +10,7 @@ class PagedAttention: @staticmethod def get_supported_head_sizes() -> List[int]: - return [64, 80, 96, 112, 128, 256] + return [64, 80, 96, 112, 120, 128, 256] @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index a214f40d16514..4e5dfd36cc2c1 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -31,7 +31,7 @@ class PagedAttention: @staticmethod def get_supported_head_sizes() -> List[int]: - return [64, 80, 96, 112, 128, 192, 256] + return [64, 80, 96, 112, 120, 128, 192, 256] @staticmethod def get_kv_cache_shape( From 8fb8fa0af81260a04d22c9270c950c225f43db35 Mon Sep 17 00:00:00 2001 From: Joe G Date: Thu, 18 Jul 2024 03:56:25 -0700 Subject: [PATCH 02/12] Add Danube3 models to tests --- tests/models/test_big_models.py | 1 + tests/models/test_models.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index c3e48b56ee58f..a7f104f87ec06 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -17,6 +17,7 @@ "EleutherAI/gpt-j-6b", # "mosaicml/mpt-7b", # Broken # "Qwen/Qwen1.5-0.5B" # Broken, + "h2oai/h2o-danube3-4b-base", ] #TODO: remove this after CPU float16 support ready diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 4cd2cb665c8f0..fc26679357080 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -20,6 +20,7 @@ # "allenai/OLMo-1B", # Broken "bigcode/starcoder2-3b", "google/gemma-1.1-2b-it", + "h2oai/h2o-danube3-500m-base", ] From 2622935f15e39f507edde292f5b2e711bba872cf Mon Sep 17 00:00:00 2001 From: Joe G Date: Thu, 18 Jul 2024 13:43:06 -0700 Subject: [PATCH 03/12] Remove CPU support and test fix --- csrc/cpu/attention.cpp | 6 ------ vllm/utils.py | 4 +++- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 
3e02e9c6a51ad..abb4e3bea14bb 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -387,9 +387,6 @@ void paged_attention_v1_impl_launcher( case 112: LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); break; - case 120: - LAUNCH_V1_ATTENTION_KERNEL(T, 120, BLOCK_SIZE); - break; case 128: LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); break; @@ -707,9 +704,6 @@ void paged_attention_v2_impl_launcher( case 112: LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); break; - case 120: - LAUNCH_V2_ATTENTION_KERNEL(T, 120, BLOCK_SIZE); - break; case 128: LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); break; diff --git a/vllm/utils.py b/vllm/utils.py index 8be1528230b5f..84e2e06cef7b6 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -521,7 +521,9 @@ def create_kv_caches_with_random( scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=torch_dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_cache_shape = (num_blocks, num_heads, + head_size // x + int(bool(head_size % x)), block_size, + x) key_caches: List[torch.Tensor] = [] for _ in range(num_layers): key_cache = torch.empty(size=key_cache_shape, From b6c1fcfa5aef8e222e06e169c2a36842690fc1dc Mon Sep 17 00:00:00 2001 From: Joe G Date: Thu, 18 Jul 2024 15:05:05 -0700 Subject: [PATCH 04/12] Revert test fix --- vllm/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 84e2e06cef7b6..8be1528230b5f 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -521,9 +521,7 @@ def create_kv_caches_with_random( scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=torch_dtype).element_size() - key_cache_shape = (num_blocks, num_heads, - head_size // x + int(bool(head_size % x)), block_size, - x) + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_caches: List[torch.Tensor] = [] for _ in range(num_layers): key_cache = torch.empty(size=key_cache_shape, From d2eb2cc665f73a4d065d9f63e95058a0ab3cc8ba Mon Sep 17 00:00:00 2001 From: Joe G Date: Fri, 19 Jul 2024 23:49:53 +0000 Subject: [PATCH 05/12] Remove CPU support --- vllm/attention/ops/ipex_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 340083a850d05..81d308c4d4e22 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -10,7 +10,7 @@ class PagedAttention: @staticmethod def get_supported_head_sizes() -> List[int]: - return [64, 80, 96, 112, 120, 128, 256] + return [64, 80, 96, 112, 128, 256] @staticmethod def get_kv_cache_shape( From 009ceccedff162e355350af157a88c8b600fc6aa Mon Sep 17 00:00:00 2001 From: Joe G Date: Mon, 22 Jul 2024 07:20:00 +0000 Subject: [PATCH 06/12] Test fix for fp8 failures fp8 kv cache values are encoded as uint8. The element size of uint8 is 1. 16 divided by any int over 1 is going to be 8 or less which is compatible with head size 120. But that doesn't happen with fp8, which leads to test failures for head size 120. 
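For illustration only (not part of the diff below): a minimal sketch of the shape arithmetic, using torch.uint8 as a stand-in for the fp8 cache dtype and torch.half as a 2-byte example. With a 2-byte dtype, x = 16 // 2 = 8 and 120 % 8 == 0, so head size 120 fits the key cache shape; with the 1-byte uint8 encoding, x = 16 and 120 % 16 == 8, which is what trips the fp8 tests.

    import torch

    head_size = 120
    for dtype in (torch.half, torch.uint8):  # uint8 stands in for the fp8 cache dtype
        # same expression used for x in create_kv_caches_with_random
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        print(dtype, x, head_size % x)  # half: x=8, remainder 0; uint8: x=16, remainder 8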
--- vllm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/utils.py b/vllm/utils.py index 8be1528230b5f..111dd9cdb6604 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -520,7 +520,7 @@ def create_kv_caches_with_random( torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) scale = head_size**-0.5 - x = 16 // torch.tensor([], dtype=torch_dtype).element_size() + x = 8 // torch.tensor([], dtype=torch_dtype).element_size() key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_caches: List[torch.Tensor] = [] for _ in range(num_layers): From 3978100eb1b8cb92792b90a5b79c5a622975a721 Mon Sep 17 00:00:00 2001 From: Joe G Date: Mon, 22 Jul 2024 14:24:59 +0000 Subject: [PATCH 07/12] Better test fix for uint8 kv cache dtype --- vllm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/utils.py b/vllm/utils.py index 111dd9cdb6604..9af26bf88bb08 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -520,7 +520,7 @@ def create_kv_caches_with_random( torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) scale = head_size**-0.5 - x = 8 // torch.tensor([], dtype=torch_dtype).element_size() + x = 16 // max(torch.tensor([], dtype=torch_dtype).element_size(), 2) key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_caches: List[torch.Tensor] = [] for _ in range(num_layers): From 662c9251032a72d9823656f52fb4c4b35b1c217a Mon Sep 17 00:00:00 2001 From: Joe G Date: Mon, 22 Jul 2024 20:47:07 +0000 Subject: [PATCH 08/12] Disable CPU tests for danube3-4b --- .buildkite/run-cpu-test.sh | 2 +- tests/models/test_big_models.py | 1 - tests/models/test_danube3_4b.py | 52 +++++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 2 deletions(-) create mode 100644 tests/models/test_danube3_4b.py diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index a7678aae54644..d492418a96c8a 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -25,4 +25,4 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" docker exec cpu-test bash -c "cd tests; pip install pytest Pillow protobuf cd ../ - pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported + pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba on CPU is not supported diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index a7f104f87ec06..c3e48b56ee58f 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -17,7 +17,6 @@ "EleutherAI/gpt-j-6b", # "mosaicml/mpt-7b", # Broken # "Qwen/Qwen1.5-0.5B" # Broken, - "h2oai/h2o-danube3-4b-base", ] #TODO: remove this after CPU float16 support ready diff --git a/tests/models/test_danube3_4b.py b/tests/models/test_danube3_4b.py new file mode 100644 index 0000000000000..736282a3e2570 --- /dev/null +++ b/tests/models/test_danube3_4b.py @@ -0,0 +1,52 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling. + +This tests danube3 separately because its head size isn't supported on CPU yet. + +Run `pytest tests/models/test_danube3_4b.py`. 
+""" +import pytest + +from .utils import check_outputs_equal + +MODELS = ["h2oai/h2o-danube3-4b-base"] + +target_dtype = "bfloat16" + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [32]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", [target_dtype]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, dtype=dtype) as vllm_model: + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) From 5ac2f9c83bea5d77b6dfb48f03f88ca517c641c4 Mon Sep 17 00:00:00 2001 From: Joe G Date: Tue, 23 Jul 2024 19:01:29 +0000 Subject: [PATCH 09/12] Disable fp8 tests if head size not divisible by 16 --- tests/kernels/test_cache.py | 6 ++++++ tests/models/test_danube3_4b.py | 2 +- vllm/utils.py | 8 +++++++- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 4aa18330f2f22..cdd068510eda0 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -52,6 +52,8 @@ def test_copy_blocks( kv_cache_dtype: str, device: str, ) -> None: + if kv_cache_dtype == "fp8" and head_size % 16: + pytest.skip() random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -124,6 +126,8 @@ def test_reshape_and_cache( device: str, kv_cache_dtype: str, ) -> None: + if kv_cache_dtype == "fp8" and head_size % 16: + pytest.skip() random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -299,6 +303,8 @@ def test_swap_blocks( ) -> None: if kv_cache_dtype == "fp8" and "cpu" in direction: pytest.skip() + if kv_cache_dtype == "fp8" and head_size % 16: + pytest.skip() random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): diff --git a/tests/models/test_danube3_4b.py b/tests/models/test_danube3_4b.py index 736282a3e2570..bfaa275f73c19 100644 --- a/tests/models/test_danube3_4b.py +++ b/tests/models/test_danube3_4b.py @@ -10,7 +10,7 @@ MODELS = ["h2oai/h2o-danube3-4b-base"] -target_dtype = "bfloat16" +target_dtype = "half" @pytest.mark.parametrize("model", MODELS) diff --git a/vllm/utils.py b/vllm/utils.py index 1b0f95906821f..26df8fe625002 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -524,6 +524,12 @@ def create_kv_caches_with_random( seed: int = 0, device: Optional[str] = "cuda", ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + + if cache_dtype == "fp8" and head_size % 16: + raise ValueError( + f"Does not support key cache of type fp8 with head_size {head_size}" + ) + torch.random.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) @@ -531,7 +537,7 @@ def create_kv_caches_with_random( torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) scale = head_size**-0.5 - x = 16 // max(torch.tensor([], dtype=torch_dtype).element_size(), 2) + x = 16 // torch.tensor([], 
dtype=torch_dtype).element_size() key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_caches: List[torch.Tensor] = [] for _ in range(num_layers): From 20b187839673d557d34fe9be314472213189ca3f Mon Sep 17 00:00:00 2001 From: Joe G Date: Tue, 23 Jul 2024 19:08:24 +0000 Subject: [PATCH 10/12] Remove test of h2o-danube3-500m --- tests/models/test_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index fc26679357080..4cd2cb665c8f0 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -20,7 +20,6 @@ # "allenai/OLMo-1B", # Broken "bigcode/starcoder2-3b", "google/gemma-1.1-2b-it", - "h2oai/h2o-danube3-500m-base", ] From d54dafff83dad273a731735866e6b88f3a922601 Mon Sep 17 00:00:00 2001 From: Joe G Date: Tue, 23 Jul 2024 20:29:08 +0000 Subject: [PATCH 11/12] Fix paged attention test --- tests/kernels/test_attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 5855e0b942dd1..c7c6707461c3e 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -134,6 +134,8 @@ def test_paged_attention( seed: int, device: str, ) -> None: + if kv_cache_dtype == "fp8" and head_size % 16: + pytest.skip() random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): From d96f9b149a4213b60a21e6c17eb784c66b33f278 Mon Sep 17 00:00:00 2001 From: Joe G Date: Fri, 26 Jul 2024 23:08:21 +0000 Subject: [PATCH 12/12] Fix merge error --- .buildkite/run-cpu-test.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 7e78c2d3007d9..45bc8eb2f8477 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -23,7 +23,6 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" # Run basic model test docker exec cpu-test bash -c " pip install pytest Pillow protobuf - cd ../ pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported # online inference @@ -38,4 +37,4 @@ docker exec cpu-test bash -c " --model facebook/opt-125m \ --num-prompts 20 \ --endpoint /v1/completions \ - --tokenizer facebook/opt-125m" \ No newline at end of file + --tokenizer facebook/opt-125m"