Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] H2O Danube3-4b #6451

Merged
merged 19 commits into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .buildkite/run-cpu-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
# Run basic model test
docker exec cpu-test bash -c "
pip install pytest Pillow protobuf
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported

# online inference
docker exec cpu-test bash -c "
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size",
type=int,
choices=[64, 80, 96, 112, 128, 192, 256],
choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--use-alibi", action="store_true")
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def benchmark_rope_kernels_multi_lora(
parser.add_argument("--num-heads", type=int, default=8)
parser.add_argument("--head-size",
type=int,
choices=[64, 80, 96, 112, 128, 192, 256],
choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128)
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
parser.add_argument("--dtype",
Expand Down
6 changes: 6 additions & 0 deletions csrc/attention/attention_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,9 @@ void paged_attention_v1_launcher(
case 112:
LAUNCH_PAGED_ATTENTION_V1(112);
break;
case 120:
LAUNCH_PAGED_ATTENTION_V1(120);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V1(128);
break;
Expand Down Expand Up @@ -912,6 +915,9 @@ void paged_attention_v2_launcher(
case 112:
LAUNCH_PAGED_ATTENTION_V2(112);
break;
case 120:
LAUNCH_PAGED_ATTENTION_V2(120);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V2(128);
break;
Expand Down
4 changes: 3 additions & 1 deletion tests/kernels/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
] if not is_hip() else [64, 80, 96, 112, 128]

BLOCK_SIZES = [16, 32]
Expand Down Expand Up @@ -134,6 +134,8 @@ def test_paged_attention(
seed: int,
device: str,
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
Expand Down
8 changes: 7 additions & 1 deletion tests/kernels/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
NUM_TOKENS = [42] # Arbitrary values for testing
NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
BLOCK_SIZES = [8, 16, 32]

# Arbitrary values for testing
Expand Down Expand Up @@ -52,6 +52,8 @@ def test_copy_blocks(
kv_cache_dtype: str,
device: str,
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
Expand Down Expand Up @@ -124,6 +126,8 @@ def test_reshape_and_cache(
device: str,
kv_cache_dtype: str,
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
Expand Down Expand Up @@ -325,6 +329,8 @@ def test_swap_blocks(
) -> None:
if kv_cache_dtype == "fp8" and "cpu" in direction:
pytest.skip()
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/test_pos_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [7, 17] # Arbitrary values for testing
BATCH_SIZES = [1, 5] # Arbitrary values for testing
Expand Down
52 changes: 52 additions & 0 deletions tests/models/test_danube3_4b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.

This tests danube3 separately because its head size isn't supported on CPU yet.

Run `pytest tests/models/test_danube3_4b.py`.
"""
import pytest

from .utils import check_outputs_equal

MODELS = ["h2oai/h2o-danube3-4b-base"]

target_dtype = "half"


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [32])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [target_dtype])
def test_model_print(
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
2 changes: 1 addition & 1 deletion vllm/attention/ops/paged_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class PagedAttention:

@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 128, 192, 256]
return [64, 80, 96, 112, 120, 128, 192, 256]

@staticmethod
def get_kv_cache_shape(
Expand Down
6 changes: 6 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,12 @@ def create_kv_caches_with_random(
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:

if cache_dtype == "fp8" and head_size % 16:
raise ValueError(
f"Does not support key cache of type fp8 with head_size {head_size}"
)

torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
Expand Down
Loading