Conversation

@bwasti (Contributor) commented Oct 11, 2025

This PR replaces #26136 with a far more rigorous set of tests and implementation choices. It is, somewhat unfortunately, quite big. I will be trying to make this smaller over the weekend but would appreciate some initial eyes! (also I'll do a writeup because this was quite a journey to get working and I think folks would benefit from that)

Adds batch-invariant support for FLASH_ATTN, rms_norm, batched matmul, linear, fused_moe (the actual Triton impl, not the native one), FLASH_ATTN_MLA, TRITON_MLA, and allreduce (via NCCL, not the custom all-reduce).

It also attempts to configure all relevant flags across the stack (including environment variables) so users don't need to specify things like "disable_custom_ar" and "enforce_eager".
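
For illustration, here is a minimal usage sketch of what that is meant to enable: set the batch-invariant override and construct the engine with no manual disable_custom_ar or enforce_eager flags. The model name, TP size, prompt, and sampling values are illustrative only, and the assumption that the env var is read at engine startup is noted in the comments.

import os

# Assumption: the override is read at engine initialization, so set it before
# importing/constructing vLLM.
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"

from vllm import LLM, SamplingParams

# No enforce_eager / disable_custom_ar needed; the PR aims to configure those
# code paths automatically when the override is on.
llm = LLM(model="deepseek-ai/DeepSeek-V3", tensor_parallel_size=8)
params = SamplingParams(temperature=0.6, seed=12345, max_tokens=64)
print(llm.generate(["There once was a "], params)[0].outputs[0].text)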

Purpose

Fully support DeepSeek-V3 batch invariance on 8xH100s. This has a large impact on mainstream models (including things like full multi-GPU support for Qwen3-30B-A3B).

Test Plan

The biggest test is this:

VLLM_ATTENTION_BACKEND=FLASH_ATTN_MLA VLLM_TEST_TP_SIZE=8 VLLM_TEST_MODEL="deepseek-ai/DeepSeek-V3" VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 pytest -s -v tests/v1/generation/test_batch_invariance.py -k test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[FLASH_ATTN]

This runs hundreds of queries individually and in batched form, achieving exact bitwise alignment across every generated token (including sampling with temperature=0.6).
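
A minimal sketch of the property the test checks (the helper name is illustrative, and the logprob access assumes logprobs were requested via SamplingParams(logprobs=...)): the same prompt run at bs=1 and inside a larger batch must produce identical token IDs and bit-for-bit identical logprobs.

def assert_bitwise_equal(out_bs1, out_bsN):
    """Require exact equality between two vLLM RequestOutputs for one prompt."""
    a, b = out_bs1.outputs[0], out_bsN.outputs[0]
    assert list(a.token_ids) == list(b.token_ids), "sampled tokens diverged"
    for step_a, step_b in zip(a.logprobs, b.logprobs, strict=True):
        for tok_id, lp_a in step_a.items():
            # Bitwise invariance means exact float equality, not a tolerance.
            assert lp_a.logprob == step_b[tok_id].logprob, (
                tok_id, lp_a.logprob, step_b[tok_id].logprob)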

Test Result

Pass.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

mergify bot commented Oct 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @bwasti.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@gemini-code-assist bot (Contributor) left a comment

Code Review

This is an impressive and comprehensive pull request that systematically tackles the challenge of achieving bitwise batch invariance for Deepseek-v3. The changes are well-thought-out, spanning from low-level kernel modifications and environment variable settings to high-level configuration overrides. The approach of centralizing the batch invariance logic and providing deterministic implementations for key operations like matmul, softmax, and RMSNorm is excellent. The significantly improved and more rigorous test suite is also a major contribution that will help ensure correctness and prevent regressions.

I have found one critical issue in the moe_align_block_size implementation where a deterministic path seems to be unintentionally disabled. Please see the specific comment for details.

Overall, this is a high-quality contribution that will be very valuable for reproducible research and production use cases. Great work on this complex feature!
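
For context on why batch invariance requires dedicated kernels at all, here is a small standalone illustration (not part of the PR): with default cuBLAS matmuls, the same row can come out with different low-order bits depending on the batch size it was computed in, because the kernel's tiling and reduction order changes with the problem shape.

import torch

torch.manual_seed(0)
x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")
w = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

# Row 0 computed as part of a 128-row batch vs. on its own. cuBLAS may pick
# different split-K/tiling strategies for different M, so the accumulation
# order (and hence the result's low-order bits) can differ.
row_in_batch = (x @ w)[0]
row_alone = (x[:1] @ w)[0]
print(torch.equal(row_in_batch, row_alone))  # frequently False without batch-invariant kernels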

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


@mergify mergify bot removed the needs-rebase label Oct 11, 2025
@bwasti bwasti force-pushed the det_deepseek_wip branch 3 times, most recently from 81c3ff6 to f6f8399 Compare October 11, 2025 02:40
@bwasti bwasti force-pushed the det_deepseek_wip branch 8 times, most recently from 755b05e to 355bda6 Compare October 16, 2025 01:06
Commit: chunked prefill setting (Signed-off-by: Bram Wasti <bwasti@meta.com>)
@zhuohan123 zhuohan123 enabled auto-merge (squash) October 16, 2025 04:44
@zhuohan123 zhuohan123 disabled auto-merge October 16, 2025 05:05
@zhuohan123 zhuohan123 merged commit 7d8975d into vllm-project:main Oct 16, 2025
66 of 68 checks passed
mandy-li pushed a commit to mandy-li/vllm that referenced this pull request Oct 16, 2025
Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 16, 2025
Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
BoyuanFeng pushed a commit to BoyuanFeng/vllm that referenced this pull request Oct 17, 2025
Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
@SmartManoj (Contributor) commented Oct 22, 2025

Context: #24583 (comment)
@freedom-cui, related commit: < 2d330a7

@freedom-cui commented
Context: #24583 (comment) @freedom-cui, related commit: < 2d330a7

Hi @SmartManoj ~ Yes, I noticed this PR. As I mentioned when consulting with you earlier, my vLLM was updated to the latest version on the main branch yesterday and already includes this commit. However, after reviewing Batch-invariant Inference (view), I found that this target does not yet appear to be in a Done state.

@SmartManoj (Contributor) commented
Initially, A800 was supported?

Summary of that commit: [screenshot attached]

@freedom-cui commented
Initially, A800 was supported?

I made changes to the code and got it running.

skip_unsupported = pytest.mark.skipif(
    not (current_platform.is_cuda() # and current_platform.has_device_capability(90)
         ),
    reason="Requires CUDA and >= Hopper (SM90)",
)
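
A slightly safer variant of that local edit, purely as a sketch: rather than deleting the capability check, relax it to Ampere-class devices (SM80, which includes the A800). Whether all of the batch-invariant kernels actually behave on SM80 is exactly what this run is probing, so treat this as an experiment, not a supported configuration.

# Hypothetical relaxation of the gate for local experimentation on A800 (SM80).
skip_unsupported = pytest.mark.skipif(
    not (current_platform.is_cuda()
         and current_platform.has_device_capability(80)),
    reason="Requires CUDA and compute capability >= 8.0 (Ampere)",
)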

I am using Qwen3-30B-A3B with TP=8. The results are as follows:

VLLM_GPU_MEMORY_UTILIZATION=0.45  VLLM_BATCH_INVARIANT=1 VLLM_TEST_TP_SIZE=8 VLLM_TEST_MODEL="Qwen3-30B-A3B" VLLM_TEST_SEED=12345 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7  pytest -s -v tests/v1/generation/test_batch_invariance.py
===================================================================== FAILURES ======================================================================
________________________________________ test_v1_generation_is_deterministic_across_batch_sizes_with_needle _________________________________________

    @skip_unsupported
    @pytest.mark.timeout(1000)
    def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
        """
        Ensures that the same request (the 'needle' prompt) yields identical output
        whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
        using the high-level v1 LLM() API only (no manual batching).
    
        Strategy:
        - Create two LLM engines with identical config except max_num_seqs: 1 vs N.
        - Compute a baseline output for the needle prompt with the bs=1 engine.
        - For many trials, generate a batch (size N) where the needle appears at a
          random position among random filler prompts using the bs=N engine.
        - Track how many trials match vs mismatch, and report totals at the end.
          The test fails if any mismatches occur, but we still dump pass/fail
          counts.
    
        Notes:
        - Use seeded stochastic sampling with a fixed seed to test determinism.
        - Outputs are intentionally longer and sampled at higher temperature/top_p
          to produce a more random-sounding phrase, yet remain deterministic by
          seed.
        - Keep max_tokens and max_model_len bounded for speed and memory use.
        """
        seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
        random.seed(seed)
    
        # Allow overrides from environment (useful for CI tuning)
        # "facebook/opt-125m" is too small, doesn't reliably test determinism
        model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
        num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
        max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
        min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
        max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
        assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle."
    
        # Keep GPU memory usage low to avoid startup allocation failures.
        gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
        max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
    
        # Sampling parameters: longer outputs with a more random-sounding
        # continuation,but still deterministic due to fixed seed.
        temperature = float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0"))
        top_p = float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95"))
        max_tokens = int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "128"))
    
        sampling = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            seed=20240919,
        )
    
        needle_prompt = "There once was a "
    
        llm_bs1 = None
        llm_bsN = None
        try:
            # Engine with bs=1 behavior
            llm_bs1 = LLM_with_max_seqs(
                model=model,
                max_num_seqs=max_batch_size,
                gpu_memory_utilization=gpu_mem_util,
                max_model_len=max_model_len,
            )
    
            # Baseline generation for the needle prompt alone.
            baseline_out = llm_bs1.generate([needle_prompt], sampling)
            assert len(baseline_out) == 1
            assert len(baseline_out[0].outputs) >= 1
            baseline_text = baseline_out[0].outputs[0].text
    
            # Engine with larger batch limit (e.g., 64)
>           llm_bsN = LLM_with_max_seqs(
                model=model,
                max_num_seqs=max_batch_size,
                gpu_memory_utilization=gpu_mem_util,
                max_model_len=max_model_len,
            )

tests/v1/generation/test_batch_invariance.py:151: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/v1/generation/test_batch_invariance.py:990: in LLM_with_max_seqs
    return LLM(
vllm/entrypoints/llm.py:324: in __init__
    self.llm_engine = LLMEngine.from_engine_args(
vllm/v1/engine/llm_engine.py:188: in from_engine_args
    return cls(
vllm/v1/engine/llm_engine.py:122: in __init__
    self.engine_core = EngineCoreClient.make_client(
vllm/v1/engine/core_client.py:93: in make_client
    return SyncMPClient(vllm_config, executor_class, log_stats)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm/v1/engine/core_client.py:639: in __init__
    super().__init__(
vllm/v1/engine/core_client.py:468: in __init__
    with launch_core_engines(vllm_config, executor_class, log_stats) as (
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/root/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/contextlib.py:144: in __exit__
    next(self.gen)
vllm/v1/engine/utils.py:880: in launch_core_engines
    wait_for_engine_startup(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

handshake_socket = <zmq.Socket(zmq.ROUTER) at 0x7fc8c93fe580 closed>
addresses = EngineZmqAddresses(inputs=['ipc:///tmp/b5663e29-f854-4a3c-97b9-c2906d4b0912'], outputs=['ipc:///tmp/dc09a81f-b1c6-4a0b-b8d7-4505600b60bb'], coordinator_input=None, coordinator_output=None, frontend_stats_publish_address=None)
core_engines = [<vllm.v1.engine.utils.CoreEngine object at 0x7fc8c93d6630>]
parallel_config = ParallelConfig(pipeline_parallel_size=1, tensor_parallel_size=1, data_parallel_size=1, data_parallel_size_local=1, dat... rank=0, _data_parallel_master_port_list=[], decode_context_parallel_size=1, _api_process_count=1, _api_process_rank=0)
cache_config = CacheConfig(block_size=16, gpu_memory_utilization=0.9, swap_space=4.0, cache_dtype='auto', is_attention_free=False, nu...ache_dtype='auto', num_gpu_blocks=None, num_cpu_blocks=None, kv_sharing_fast_prefill=False, kv_cache_memory_bytes=None)
proc_manager = <vllm.v1.engine.utils.CoreEngineProcManager object at 0x7fc8c93d6f90>, coord_process = None

    def wait_for_engine_startup(
        handshake_socket: zmq.Socket,
        addresses: EngineZmqAddresses,
        core_engines: list[CoreEngine],
        parallel_config: ParallelConfig,
        cache_config: CacheConfig,
        proc_manager: CoreEngineProcManager | None,
        coord_process: Process | None,
    ):
        # Wait for engine core process(es) to send ready messages.
        local_count = parallel_config.data_parallel_size_local
        remote_count = len(core_engines) - local_count
        # [local, remote] counts
        conn_pending, start_pending = [local_count, remote_count], [0, 0]
        poller = zmq.Poller()
        poller.register(handshake_socket, zmq.POLLIN)
    
        remote_should_be_headless = (
            not parallel_config.data_parallel_hybrid_lb
            and not parallel_config.data_parallel_external_lb
        )
    
        if proc_manager is not None:
            for sentinel in proc_manager.sentinels():
                poller.register(sentinel, zmq.POLLIN)
        if coord_process is not None:
            poller.register(coord_process.sentinel, zmq.POLLIN)
        while any(conn_pending) or any(start_pending):
            events = poller.poll(STARTUP_POLL_PERIOD_MS)
            if not events:
                if any(conn_pending):
                    logger.debug(
                        "Waiting for %d local, %d remote core engine proc(s) to connect.",
                        *conn_pending,
                    )
                if any(start_pending):
                    logger.debug(
                        "Waiting for %d local, %d remote core engine proc(s) to start.",
                        *start_pending,
                    )
                continue
            if len(events) > 1 or events[0][0] != handshake_socket:
                # One of the local core processes exited.
                finished = proc_manager.finished_procs() if proc_manager else {}
                if coord_process is not None and coord_process.exitcode is not None:
                    finished[coord_process.name] = coord_process.exitcode
>               raise RuntimeError(
                    "Engine core initialization failed. "
                    "See root cause above. "
                    f"Failed core proc(s): {finished}"
                )
E               RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}

vllm/v1/engine/utils.py:937: RuntimeError
___________________________________________ test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[FLASH_ATTN] ___________________________________________

backend = 'FLASH_ATTN'

    @skip_unsupported
    @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
    @pytest.mark.forked
    def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
        backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
        os.environ["VLLM_ATTENTION_BACKEND"] = backend
    
        seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
        random.seed(seed)
        model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
        tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
    
        # For batch invariance, disable custom all-reduce to ensure deterministic
        # all-reduce operations (custom all-reduce may not be deterministic)
        from vllm.model_executor.layers.batch_invariant import (
            vllm_is_batch_invariant,
        )
    
        disable_custom_ar = vllm_is_batch_invariant()
    
        if disable_custom_ar:
            print(f"\n{'=' * 80}")
            print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
            print(f"{'=' * 80}\n")
    
>       llm = LLM(
            model=model_name,
            tensor_parallel_size=tp_size,
            enable_prefix_caching=False,
            max_num_seqs=32,
            max_model_len=8192,
            dtype="bfloat16",  # not everything is supported
        )

tests/v1/generation/test_batch_invariance.py:248: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
vllm/entrypoints/llm.py:324: in __init__
    self.llm_engine = LLMEngine.from_engine_args(
vllm/v1/engine/llm_engine.py:188: in from_engine_args
    return cls(
vllm/v1/engine/llm_engine.py:122: in __init__
    self.engine_core = EngineCoreClient.make_client(
vllm/v1/engine/core_client.py:93: in make_client
    return SyncMPClient(vllm_config, executor_class, log_stats)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
vllm/v1/engine/core_client.py:639: in __init__
    super().__init__(
vllm/v1/engine/core_client.py:468: in __init__
    with launch_core_engines(vllm_config, executor_class, log_stats) as (
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/root/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/contextlib.py:144: in __exit__
    next(self.gen)
vllm/v1/engine/utils.py:880: in launch_core_engines
    wait_for_engine_startup(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

handshake_socket = <zmq.Socket(zmq.ROUTER) at 0x7fc87b417e00 closed>
addresses = EngineZmqAddresses(inputs=['ipc:///tmp/6b5c1e7c-9706-4dbc-9906-16ec2a6a5a4d'], outputs=['ipc:///tmp/d899d292-2f74-4e55-b97f-a0993aba2169'], coordinator_input=None, coordinator_output=None, frontend_stats_publish_address=None)
core_engines = [<vllm.v1.engine.utils.CoreEngine object at 0x7fc87b160680>]
parallel_config = ParallelConfig(pipeline_parallel_size=1, tensor_parallel_size=8, data_parallel_size=1, data_parallel_size_local=1, dat... rank=0, _data_parallel_master_port_list=[], decode_context_parallel_size=1, _api_process_count=1, _api_process_rank=0)
cache_config = CacheConfig(block_size=16, gpu_memory_utilization=0.9, swap_space=4.0, cache_dtype='auto', is_attention_free=False, nu...ache_dtype='auto', num_gpu_blocks=None, num_cpu_blocks=None, kv_sharing_fast_prefill=False, kv_cache_memory_bytes=None)
proc_manager = <vllm.v1.engine.utils.CoreEngineProcManager object at 0x7fc8c829b260>, coord_process = None

    def wait_for_engine_startup(
        handshake_socket: zmq.Socket,
        addresses: EngineZmqAddresses,
        core_engines: list[CoreEngine],
        parallel_config: ParallelConfig,
        cache_config: CacheConfig,
        proc_manager: CoreEngineProcManager | None,
        coord_process: Process | None,
    ):
        # Wait for engine core process(es) to send ready messages.
        local_count = parallel_config.data_parallel_size_local
        remote_count = len(core_engines) - local_count
        # [local, remote] counts
        conn_pending, start_pending = [local_count, remote_count], [0, 0]
        poller = zmq.Poller()
        poller.register(handshake_socket, zmq.POLLIN)
    
        remote_should_be_headless = (
            not parallel_config.data_parallel_hybrid_lb
            and not parallel_config.data_parallel_external_lb
        )
    
        if proc_manager is not None:
            for sentinel in proc_manager.sentinels():
                poller.register(sentinel, zmq.POLLIN)
        if coord_process is not None:
            poller.register(coord_process.sentinel, zmq.POLLIN)
        while any(conn_pending) or any(start_pending):
            events = poller.poll(STARTUP_POLL_PERIOD_MS)
            if not events:
                if any(conn_pending):
                    logger.debug(
                        "Waiting for %d local, %d remote core engine proc(s) to connect.",
                        *conn_pending,
                    )
                if any(start_pending):
                    logger.debug(
                        "Waiting for %d local, %d remote core engine proc(s) to start.",
                        *start_pending,
                    )
                continue
            if len(events) > 1 or events[0][0] != handshake_socket:
                # One of the local core processes exited.
                finished = proc_manager.finished_procs() if proc_manager else {}
                if coord_process is not None and coord_process.exitcode is not None:
                    finished[coord_process.name] = coord_process.exitcode
>               raise RuntimeError(
                    "Engine core initialization failed. "
                    "See root cause above. "
                    f"Failed core proc(s): {finished}"
                )
E               RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}

vllm/v1/engine/utils.py:937: RuntimeError
================================================================= warnings summary ==================================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

tests/v1/generation/test_batch_invariance.py:79
  /home/cuizhisheng/precision_dir/vllm/tests/v1/generation/test_batch_invariance.py:79: PytestUnknownMarkWarning: Unknown pytest.mark.timeout - is this a typo?  You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
    @pytest.mark.timeout(1000)

tests/v1/generation/test_batch_invariance.py:225
  /home/cuizhisheng/precision_dir/vllm/tests/v1/generation/test_batch_invariance.py:225: PytestUnknownMarkWarning: Unknown pytest.mark.forked - is this a typo?  You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
    @pytest.mark.forked

tests/v1/generation/test_batch_invariance.py:486
  /home/cuizhisheng/precision_dir/vllm/tests/v1/generation/test_batch_invariance.py:486: PytestUnknownMarkWarning: Unknown pytest.mark.forked - is this a typo?  You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
    @pytest.mark.forked

tests/v1/generation/test_batch_invariance.py:713
  /home/cuizhisheng/precision_dir/vllm/tests/v1/generation/test_batch_invariance.py:713: PytestUnknownMarkWarning: Unknown pytest.mark.forked - is this a typo?  You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
    @pytest.mark.forked

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================================== short test summary info ==============================================================
FAILED tests/v1/generation/test_batch_invariance.py::test_v1_generation_is_deterministic_across_batch_sizes_with_needle - RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}
FAILED tests/v1/generation/test_batch_invariance.py::test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[FLASH_ATTN] - RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}
================================================ 2 failed, 5 passed, 6 warnings in 331.84s (0:05:31) ================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute
/root/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

@SmartManoj (Contributor) commented
Are there any tracebacks before that?

Could you also run invariance_test.py?

alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
Labels: ci/build, deepseek, gpt-oss, ready, v1
