@@ -71,11 +71,11 @@ class CustomAllreduce:
     # TODO: We should set a reasonable range for FP.
     MB = 1024 * 1024
     _QR_MIN_SIZE = {
-        (torch.float16, 2): [16 * MB, 2 * MB, 2 * MB, 1 * MB],
-        (torch.float16, 4): [16 * MB, 64 * MB, 4 * MB, 2 * MB],
+        (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
+        (torch.float16, 4): [1 * MB, 64 * MB, 4 * MB, 2 * MB],
         (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB],
-        (torch.bfloat16, 2): [16 * MB, 8 * MB, 8 * MB, 8 * MB],
-        (torch.bfloat16, 4): [16 * MB, 128 * MB, 128 * MB, 16 * MB],
+        (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB],
+        (torch.bfloat16, 4): [8 * MB, 128 * MB, 128 * MB, 16 * MB],
         (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB],
     }
 
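For orientation: each `_QR_MIN_SIZE` entry is keyed by (dtype, world size) and stores one minimum message size per quantization regime, indexed later by `qr_quant_level.value`. A minimal sketch of that lookup, using only the two fp16 rows from the new table (illustrative, not part of the commit):

import torch

MB = 1024 * 1024
# One minimum size per quantization regime, indexed by the regime's
# integer value; rows copied from the updated table above.
_QR_MIN_SIZE = {
    (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB],
    (torch.float16, 4): [1 * MB, 64 * MB, 4 * MB, 2 * MB],
}


def qr_min_size(dtype: torch.dtype, world_size: int, quant_level: int) -> int:
    """Smallest message size (bytes) at which quick allreduce is considered."""
    return _QR_MIN_SIZE[(dtype, world_size)][quant_level]


# With the lowered first entry, a 1 MB fp16 tensor on 2 ranks now meets the
# full-precision minimum (the previous table required 16 MB).
assert qr_min_size(torch.float16, 2, 0) == 1 * MB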
@@ -256,40 +256,43 @@ def init_custom_quick_allreduce(self):
         Initialize a custom quick allreduce implementation for AMD
         based on quick reduce (https://github.com/mk1-project/quickreduce).
         """
+        if not self._QR_SHOULD_INIT:
+            return
+        self.use_fp16_kernels = envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16
+        regime_str = envs.VLLM_ROCM_QR_QUANT_REGIME
+        if regime_str not in QuickReduceRegime.__members__:
+            logger.warning(
+                "Custom quick allreduce:",
+                f"Invalid quantization level: {regime_str}. "
+                "Supported levels: "
+                f"{list(QuickReduceRegime.__members__.keys())}")
+            return
+
+        if regime_str == "NONE":
+            logger.debug("Custom quick allreduce is disabled based "
+                         "on env variable VLLM_ROCM_QR_QUANT_REGIME")
+            return
+
         vllm_config = get_current_vllm_config()
-        # for test mode
-        if vllm_config is not None and hasattr(vllm_config, "model_config"):
+        if vllm_config is not None and \
+                hasattr(vllm_config, "model_config") and \
+                hasattr(vllm_config.model_config, "dtype"):
             dtype = vllm_config.model_config.dtype
             if dtype not in [torch.float16, torch.bfloat16]:
                 self._QR_SHOULD_INIT = False
-            # On RocM bfloat16 kernels are slower than fp16
-            # due to slower match operations
-            # If environment is not set to 1 we convert input to fp16
-            self.use_fp16_kernels: bool = envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16
-            regime_str = envs.VLLM_ROCM_QR_QUANT_REGIME
-            if self._QR_SHOULD_INIT:
-                if regime_str not in QuickReduceRegime.__members__:
-                    logger.warning(
-                        "Custom quick allreduce:",
-                        f"Invalid quantization level: {regime_str}. "
-                        "Supported levels: "
-                        f"{list(QuickReduceRegime.__members__.keys())}")
-                    return
-
-                if regime_str == "NONE":
-                    logger.debug("Custom quick allreduce is disabled based "
-                                 "on env variable VLLM_ROCM_QR_QUANT_REGIME")
-                    return
-
-                self.qr_quant_level = QuickReduceRegime[regime_str]
-                self._qr_ptr = ops.init_custom_qr(self.rank, self.world_size)
-                self.create_qr_shared_buffer()
+            # On RocM bfloat16 kernels are slower than fp16
+            # due to slower match operations
+            # If environment variable is not set to 1 we convert input to fp16
             if dtype == torch.bfloat16 and not self.use_fp16_kernels:
                 logger.info(
                     "Custom quick allreduce: converting bf16 inputs to "
                     "fp16 can improve performance"
                     "set envs.VLLM_ROCM_QR_CAST_BF16_TO_FP16=1 to turn on.")
-        self.qr_disabled = False
+
+        self.qr_quant_level = QuickReduceRegime[regime_str]
+        self._qr_ptr = ops.init_custom_qr(self.rank, self.world_size)
+        self.create_qr_shared_buffer()
+        self.qr_disabled = False
 
     @contextmanager
     def capture(self):
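The initializer now reads VLLM_ROCM_QR_QUANT_REGIME, validates it against `QuickReduceRegime.__members__`, and returns early on an invalid value or "NONE" before touching any buffers. A standalone sketch of that env-driven selection pattern; the member names and the "FP" default below are assumptions, not taken from this diff:

import os
from enum import Enum
from typing import Optional


class QuickReduceRegime(Enum):
    # Stand-in members for illustration only.
    FP = 0
    INT8 = 1
    INT6 = 2
    INT4 = 3
    NONE = 4


def select_regime() -> Optional[QuickReduceRegime]:
    """Map the env variable to a regime, or None when quick allreduce
    should stay disabled (unknown value or explicit NONE)."""
    regime_str = os.environ.get("VLLM_ROCM_QR_QUANT_REGIME", "FP")
    if regime_str not in QuickReduceRegime.__members__:
        print(f"Invalid quantization level: {regime_str}. "
              f"Supported levels: {list(QuickReduceRegime.__members__)}")
        return None
    if regime_str == "NONE":
        return None
    return QuickReduceRegime[regime_str]


# Example: VLLM_ROCM_QR_QUANT_REGIME=INT8 selects QuickReduceRegime.INT8.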
@@ -346,7 +349,7 @@ def should_quick_allreduce(self, inp: torch.Tensor):
         if self.use_fp16_kernels:
             dtype = torch.float16
         return inp_size <= self.qr_max_size and \
-            inp_size > self._QR_MIN_SIZE[(dtype, self.world_size)]\
+            inp_size >= self._QR_MIN_SIZE[(dtype, self.world_size)]\
             [self.qr_quant_level.value]
 
     def should_custom_allreduce(self, inp: torch.Tensor):
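Combined with the table change, the switch from `>` to `>=` lets a tensor whose size equals the tabulated minimum take the quick-reduce path. A size-gate sketch (illustrative only; `qr_max_size` is not shown in this excerpt, so the maximum below is arbitrary):

MB = 1024 * 1024


def size_gate(inp_size: int, min_size: int, max_size: int) -> bool:
    # Quick allreduce is used only when the message is large enough for the
    # selected quantization level (boundary now included) and still fits in
    # the pre-registered quick-reduce buffer.
    return min_size <= inp_size <= max_size


# A 1 MB fp16 tensor on 2 ranks now passes at the new full-precision minimum.
assert size_gate(1 * MB, 1 * MB, 64 * MB)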
@@ -369,7 +372,7 @@ def should_custom_ar(self, inp: torch.Tensor):
         return self.should_quick_allreduce(
             inp) or self.should_custom_allreduce(inp)
 
-    def cr_all_reduce(self,
+    def ca_all_reduce(self,
                       inp: torch.Tensor,
                       *,
                       out: torch.Tensor = None,
@@ -411,7 +414,7 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
         if self.should_custom_allreduce(input):
             if self._IS_CAPTURING:
                 if torch.cuda.is_current_stream_capturing():
-                    return self.cr_all_reduce(input, registered=True)
+                    return self.ca_all_reduce(input, registered=True)
                 else:
                     # If warm up, mimic the allocation pattern since custom
                     # allreduce is out-of-place.
@@ -421,7 +424,7 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
                 # incurs a cost of cudaMemcpy, which should be small
                 # (<=1% of overall latency) compared to the performance
                 # gain of using custom kernels
-                return self.cr_all_reduce(input, registered=False)
+                return self.ca_all_reduce(input, registered=False)
 
         return None
 
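For completeness, `should_custom_ar` (earlier hunk) admits a tensor if either predicate passes, while `ca_all_reduce` covers the regular custom-allreduce path. A schematic dispatcher along those lines; the quick-reduce entry point name and the order of preference are assumptions, since neither appears in this excerpt:

from typing import Callable, Optional

import torch


def dispatch_all_reduce(
    inp: torch.Tensor,
    should_quick: Callable[[torch.Tensor], bool],
    should_custom: Callable[[torch.Tensor], bool],
    qr_all_reduce: Callable[[torch.Tensor], torch.Tensor],
    ca_all_reduce: Callable[[torch.Tensor], torch.Tensor],
) -> Optional[torch.Tensor]:
    # Try the quantized quick-reduce path first, fall back to the regular
    # custom allreduce, and return None so the caller can hand the tensor
    # to another backend (e.g. NCCL) when neither applies.
    if should_quick(inp):
        return qr_all_reduce(inp)
    if should_custom(inp):
        return ca_all_reduce(inp)
    return None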