@@ -707,6 +707,282 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
     os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
 
 
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90),
+    reason="Batch invariance tests only supported on Hopper (SM90)",
+)
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="Requires CUDA to match production inference path.",
+)
+@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
+@pytest.mark.forked
+def test_decode_logprobs_match_prefill_logprobs(backend):
721+ """
722+ Test that verifies decode logprobs match prefill logprobs.
723+
724+ For each decoded token at position i:
725+ 1. Run decode to generate N tokens and collect their logprobs
726+ 2. For each position i in [0, N):
727+ - Take prefix = prompt + tokens[0:i]
728+ - Run prefill(prefix + tokens[i]) to get logprob of tokens[i]
729+ - Verify prefill logprob matches decode logprob bitwise
730+
731+ This ensures that the logprobs from decode are consistent with what
732+ we would get if we ran prefill on each prefix.
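+
+    Tunable via environment variables read below: VLLM_ATTENTION_BACKEND,
+    VLLM_TEST_SEED, VLLM_TEST_MODEL, VLLM_TEST_TP_SIZE,
+    VLLM_DECODE_PREFILL_NUM_PROMPTS, and VLLM_DECODE_PREFILL_MAX_TOKENS.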
733+ """
+    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
+    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+
+    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+    random.seed(seed)
+    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
+
+    from vllm.model_executor.layers.batch_invariant import (
+        vllm_kernel_override_batch_invariant,
+    )
+
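+    # vllm_kernel_override_batch_invariant() reports whether the
+    # VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT env flag is set; when it is, this
+    # test disables custom all-reduce so TP reductions stay deterministic.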
+    disable_custom_ar = vllm_kernel_override_batch_invariant()
+
+    if disable_custom_ar:
+        print(f"\n{'=' * 80}")
+        print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
+        print(f"{'=' * 80}\n")
+
+    llm = LLM(
+        model=model_name,
+        tensor_parallel_size=tp_size,
+        enable_prefix_caching=False,
+        # Disable custom all-reduce when batch-invariant mode is on (see above).
+        disable_custom_all_reduce=disable_custom_ar,
+        max_num_seqs=32,
+        max_model_len=8192,
+        dtype="bfloat16",
+    )
+
+    # Use a few test prompts
+    num_test_prompts = int(os.getenv("VLLM_DECODE_PREFILL_NUM_PROMPTS", "4"))
+    prompts = [_random_prompt(10, 50) for _ in range(num_test_prompts)]
+
+    # Generate longer sequences to test multiple decode steps
+    max_tokens = int(os.getenv("VLLM_DECODE_PREFILL_MAX_TOKENS", "16"))
+
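+    # logprobs=5 asks vLLM to return per-step logprobs, which
+    # _extract_step_logprobs reads off each RequestOutput below.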
+    sp = SamplingParams(
+        temperature=0.0,  # Greedy for determinism
+        max_tokens=max_tokens,
+        logprobs=5,
+    )
+
+    print("\n" + "=" * 80)
+    print("STEP 1: Running decode to generate tokens and collect logprobs")
+    print("=" * 80 + "\n")
+
+    # Step 1: Run decode and collect logprobs
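+    # Note: step 2 below reconstructs prefixes by re-running greedy
+    # generation, so it relies on temperature=0.0 reproducing these tokens.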
+    decode_outputs = llm.generate(prompts, sp, use_tqdm=False)
+
+    failed_comparisons = []
+
+    for prompt_idx, (prompt, decode_output) in enumerate(zip(prompts, decode_outputs)):
+        print(f"\n[Prompt {prompt_idx}] Testing: {prompt[:80]}...")
+
+        # Extract decode logprobs and tokens
+        decode_logprobs, token_ids = _extract_step_logprobs(decode_output)
+        if decode_logprobs is None:
+            pytest.skip(
+                "Logprobs are not available on RequestOutput; "
+                "enable logprobs return to run this test."
+            )
+
+        print(f"[Prompt {prompt_idx}] Generated {len(token_ids)} tokens: {token_ids}")
+        print(f"[Prompt {prompt_idx}] Decode logprobs: {decode_logprobs.tolist()}")
+
+        # Step 2: For each token position, run prefill and compare
+        print(f"\n[Prompt {prompt_idx}] Verifying each token via prefill...")
+
+        for token_idx in range(len(token_ids)):
+            # Construct the prefix up to (but not including) this token
+            current_token = token_ids[token_idx]
+
+            # The LLM API does not expose the tokenizer or per-token output
+            # text, so we cannot detokenize tokens[0:token_idx] directly.
+            # Instead, reconstruct the text prefix by re-running greedy
+            # generation with max_tokens=token_idx and appending its output
+            # text to the prompt. This is approximate but works for
+            # verification.
+            if token_idx == 0:
+                prefix_prompt = prompt
+            else:
+                prefix_sp = SamplingParams(
+                    temperature=0.0,
+                    max_tokens=token_idx,
+                    logprobs=1,
+                )
+                prefix_output = llm.generate([prompt], prefix_sp, use_tqdm=False)[0]
+                prefix_prompt = prompt + prefix_output.outputs[0].text
+
+            # Now run prefill with max_tokens=1 to get the logprob of the next token
+            prefill_sp = SamplingParams(
+                temperature=0.0,
+                max_tokens=1,
+                logprobs=5,
+            )
+
+            print(
+                f"  [Token {token_idx}] Running prefill for prefix "
+                f"(len={len(prefix_prompt)})..."
+            )
+            prefill_output = llm.generate([prefix_prompt], prefill_sp, use_tqdm=False)[
+                0
+            ]
+            prefill_logprobs, prefill_token_ids = _extract_step_logprobs(prefill_output)
+
+            if prefill_logprobs is None:
+                print(f"  [Token {token_idx}] Warning: No prefill logprobs available")
+                continue
+
+            # The first token from prefill should match the current token
+            prefill_token = prefill_token_ids[0]
+            prefill_logprob = prefill_logprobs[0].item()
+            decode_logprob = decode_logprobs[token_idx].item()
+
+            print(
+                f"  [Token {token_idx}] Decode token: {current_token}, "
+                f"logprob: {decode_logprob:.8f}"
+            )
+            print(
+                f"  [Token {token_idx}] Prefill token: {prefill_token}, "
+                f"logprob: {prefill_logprob:.8f}"
+            )
+
+            # Check if tokens match
+            if current_token != prefill_token:
+                failed_comparisons.append(
+                    {
+                        "prompt_idx": prompt_idx,
+                        "token_idx": token_idx,
+                        "reason": "Token mismatch",
+                        "decode_token": current_token,
+                        "prefill_token": prefill_token,
+                        "decode_logprob": decode_logprob,
+                        "prefill_logprob": prefill_logprob,
+                        "prompt_text": prompt[:100],
+                        "prefix_text": prefix_prompt[:100],
+                    }
+                )
+                print(f"  [Token {token_idx}] ✗ TOKEN MISMATCH!")
+                continue
+
+            # Check if logprobs match bitwise
+            if decode_logprob != prefill_logprob:
+                diff = abs(decode_logprob - prefill_logprob)
+                failed_comparisons.append(
+                    {
+                        "prompt_idx": prompt_idx,
+                        "token_idx": token_idx,
+                        "reason": "Logprob mismatch",
+                        "decode_token": current_token,
+                        "prefill_token": prefill_token,
+                        "decode_logprob": decode_logprob,
+                        "prefill_logprob": prefill_logprob,
+                        "diff": diff,
+                        "prompt_text": prompt[:100],
+                        "prefix_text": prefix_prompt[:100],
+                        "decode_all_tokens": token_ids,
+                        "decode_all_logprobs": decode_logprobs.tolist(),
+                    }
+                )
+                print(f"  [Token {token_idx}] ✗ LOGPROB MISMATCH! diff={diff:.8e}")
+            else:
+                print(f"  [Token {token_idx}] ✓ Match (bitwise equal)")
+
+    # Print summary
+    print(f"\n{'=' * 80}")
+    if failed_comparisons:
+        print(f"DECODE-PREFILL MISMATCH: {len(failed_comparisons)} failures detected")
+        print(f"{'=' * 80}")
+
+        # Group failures by prompt for better readability
+        failures_by_prompt: dict[int, list[dict]] = {}
+        for fail in failed_comparisons:
+            failures_by_prompt.setdefault(fail["prompt_idx"], []).append(fail)
+
+        for prompt_idx, failures in failures_by_prompt.items():
+            print(f"\n{'=' * 80}")
+            print(f"PROMPT {prompt_idx}: {failures[0]['prompt_text']}...")
+            print(f"{'=' * 80}")
+            print(f"Total failures for this prompt: {len(failures)}")
+
+            # Show where mismatches occur (which token positions)
+            mismatch_positions = [f["token_idx"] for f in failures]
+            print(f"Mismatch at token positions: {mismatch_positions}")
+
+            # Show first few failures in detail
+            for i, fail in enumerate(failures[:5]):  # Show first 5 failures per prompt
+                print(f"\n[Failure {i + 1}] Token position {fail['token_idx']}:")
+                print(f"  Reason: {fail['reason']}")
+                print(f"  Prefix text: '{fail['prefix_text']}...'")
+                print(
+                    f"  Decode: token={fail['decode_token']}, "
+                    f"logprob={fail['decode_logprob']:.10f}"
+                )
+                print(
+                    f"  Prefill: token={fail['prefill_token']}, "
+                    f"logprob={fail['prefill_logprob']:.10f}"
+                )
944+ if "diff" in fail :
945+ print (f" Difference: { fail ['diff' ]:.10e} " )
946+ # Show in hex to see bitwise difference
947+ import struct
948+
949+ decode_hex = struct .pack ("f" , fail ["decode_logprob" ]).hex ()
950+ prefill_hex = struct .pack ("f" , fail ["prefill_logprob" ]).hex ()
951+ print (f" Decode logprob (hex): 0x{ decode_hex } " )
952+ print (f" Prefill logprob (hex): 0x{ prefill_hex } " )
+
+                # If we have all tokens/logprobs, show the context
+                if "decode_all_tokens" in fail and "decode_all_logprobs" in fail:
+                    token_idx = fail["token_idx"]
+                    all_tokens = fail["decode_all_tokens"]
+                    all_logprobs = fail["decode_all_logprobs"]
+
+                    # Show context: 2 tokens before and after
+                    start = max(0, token_idx - 2)
+                    end = min(len(all_tokens), token_idx + 3)
+
+                    print(f"  Context (tokens {start} to {end - 1}):")
+                    for j in range(start, end):
+                        marker = " <-- MISMATCH" if j == token_idx else ""
+                        print(
+                            f"    [{j}] token={all_tokens[j]}, "
+                            f"logprob={all_logprobs[j]:.8f}{marker}"
+                        )
+
+            if len(failures) > 5:
+                print(f"\n... and {len(failures) - 5} more failures for this prompt")
+
+        print(f"\n{'=' * 80}\n")
+
+        pytest.fail(
+            "Decode logprobs do not match prefill logprobs: "
+            f"{len(failed_comparisons)} mismatches found."
+        )
+    else:
+        print("✓ SUCCESS: All decode logprobs match prefill logprobs bitwise!")
+        print(f"{'=' * 80}\n")
+
+
 def LLM_with_max_seqs(
     model: str,
     max_num_seqs: int,