Turned back on the Marlin tests (vllm-project#121)
SUMMARY:
Turns the Marlin tests back on. The issue was that vLLM was not tearing
itself down properly; calling the garbage collector explicitly seems to
have resolved this in the short term.

In general, we should get to the bottom of why vLLM does not shut down
cleanly.

TEST PLAN:
Automation
robertgshaw2-redhat authored Mar 14, 2024
1 parent 66863b4 commit ac9c9c8
Showing 1 changed file with 10 additions and 7 deletions.
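
For reference, here is the teardown pattern the diff below settles on, pulled out as a standalone sketch. The vllm_runner_nm wrapper and the model.llm_engine.driver_worker attribute path are the ones used in the test, not a documented vLLM interface, so treat this as an illustration of the workaround rather than a supported shutdown API:

import gc

import torch


def shutdown_runner(runner) -> None:
    # Detach the worker that owns the GPU allocations; deleting only the
    # wrapper has not reliably freed memory (attribute path as in the test).
    del runner.model.llm_engine.driver_worker
    # An explicit collection pass is what actually reclaims the engine objects
    # in the short term; emptying the cache then returns the freed CUDA blocks
    # to the driver.
    gc.collect()
    torch.cuda.empty_cache()

The test body additionally drops its own reference (del marlin_model, del gptq_model) before collecting, which a helper like this cannot do on the caller's behalf.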
17 changes: 10 additions & 7 deletions tests/models/test_marlin.py
@@ -17,6 +17,7 @@

 import pytest
 import torch
+import gc
 from compare_utils import check_logprobs_close
 from dataclasses import dataclass
 from vllm.model_executor.layers.quantization import _QUANTIZATION_CONFIG_REGISTRY
@@ -45,7 +46,6 @@ class ModelPair:
 ]


-@pytest.mark.skip(reason="out of memory")
 @pytest.mark.flaky(reruns=2)
 @pytest.mark.skipif(marlin_not_supported,
                     reason="Marlin is not supported on this GPU type.")
@@ -67,24 +67,27 @@ def test_models(
     marlin_outputs = marlin_model.generate_greedy_logprobs(
         example_prompts, max_tokens, num_logprobs)

-    # Note: not sure why, but deleting just the model on Ada Lovelace
-    # does not free the GPU memory. On Ampere, deleting the just model
-    # frees the memory.
+    # vllm memory cleanup is poor. This seems to fix things.
+    # NOTE: upstream sync should use downstream version.
     del marlin_model.model.llm_engine.driver_worker
     del marlin_model
+
+    gc.collect()
     torch.cuda.empty_cache()

     gptq_model = vllm_runner_nm(model_pair.model_gptq,
                                 dtype=dtype,
                                 max_model_len=MAX_MODEL_LEN)
     gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
                                                        max_tokens,
                                                        num_logprobs)

-    # Note: not sure why, but deleting just the model on Ada Lovelace
-    # does not free the GPU memory. On Ampere, deleting the just model
-    # frees the memory.
+    # vllm memory cleanup is poor. This seems to fix things.
+    # NOTE: upstream sync should use downstream version.
     del gptq_model.model.llm_engine.driver_worker
     del gptq_model
+    gc.collect()
     torch.cuda.empty_cache()

     # loop through the prompts
     # use logprobs or else this will consistently run out of memory
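
Since the same teardown block now appears twice in the test, one possible follow-up (not part of this commit; sketched here only to illustrate the pattern) would be to wrap each model load in a small context manager so the workaround runs between models automatically:

import gc
from contextlib import contextmanager

import torch


@contextmanager
def vllm_model(runner_factory, *args, **kwargs):
    # runner_factory stands in for a fixture such as vllm_runner_nm; the
    # teardown mirrors the workaround added in the diff above.
    runner = runner_factory(*args, **kwargs)
    try:
        yield runner
    finally:
        del runner.model.llm_engine.driver_worker
        gc.collect()
        torch.cuda.empty_cache()

Each model would then be loaded as "with vllm_model(vllm_runner_nm, model_pair.model_marlin, dtype=dtype, max_model_len=MAX_MODEL_LEN) as marlin_model:" and its GPU memory released before the next model is created.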
