3 files changed: +22 −0 lines changed

@@ -370,6 +370,8 @@ def forward(  # type: ignore
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: DualChunkFlashAttentionMetadata,
+        output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with DualChunkFlashAttention.
         Args:
@@ -383,6 +385,13 @@ def forward(  # type: ignore
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is None, "Output tensor not supported for DualChunk"
+
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for FlashAttentionImpl")
+
         (
             query,
             query_succ,
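The same guard is added to each of the three backends in this diff: `forward` gains an optional `output_scale` argument for interface compatibility and rejects it until fused output quantization is implemented. Below is a minimal stand-alone sketch of that pattern; `ToyAttentionImpl` and its simplified SDPA body are illustrative assumptions, not vLLM's actual backend interface.

```python
# Illustrative sketch of the guard pattern only, not vLLM code: the class
# name, signature details, and the plain SDPA body are assumptions.
from typing import Optional

import torch


class ToyAttentionImpl:
    """Stand-in for an attention backend implementation."""

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # A fused kernel would apply output_scale while writing its result;
        # none of these backends do that yet, so the argument is rejected
        # loudly instead of being silently ignored.
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported")

        # Plain (unfused) attention as a placeholder for the real kernel.
        result = torch.nn.functional.scaled_dot_product_attention(
            query, key, value)
        if output is not None:
            output.copy_(result)
            return output
        return result


# Usage: shapes are [batch, heads, seq_len, head_dim] for SDPA.
q = k = v = torch.randn(1, 2, 16, 8)
impl = ToyAttentionImpl()
out = impl.forward(q, k, v)  # fine
# impl.forward(q, k, v, output_scale=torch.tensor(1.0))  # raises NotImplementedError
```

Note the asymmetry in the actual diff: the DualChunk backend asserts that no preallocated `output` buffer is passed, while the FlashInfer and FlexAttention backends below require one.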
@@ -547,6 +547,7 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashInfer.

@@ -561,6 +562,11 @@ def forward(
         """
         assert output is not None, "Output tensor must be provided."

+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for FlashInferImpl")
+
         if attn_metadata is None:
             # Profiling run.
             return output
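For context, `output_scale` is the scale a fused kernel would apply when writing quantized attention output directly. Until fusion is supported, that work has to happen in a separate pass after the attention call, roughly as in the sketch below. The per-tensor FP8-style scheme and the `quantize_output` helper are assumptions for illustration, not vLLM code.

```python
# Illustration only: an *unfused* output-quantization step, i.e. the work a
# future fused kernel would fold into the attention call itself.
import torch


def quantize_output(attn_out: torch.Tensor,
                    output_scale: torch.Tensor) -> torch.Tensor:
    """Scale the attention output and cast it to FP8 in a separate pass."""
    scaled = attn_out / output_scale       # per-tensor scale (assumed scheme)
    return scaled.to(torch.float8_e4m3fn)  # extra kernel launch + memory traffic


attn_out = torch.randn(4, 8, dtype=torch.bfloat16)
scale = torch.tensor(0.5)
q_out = quantize_output(attn_out, scale)
print(q_out.dtype)  # torch.float8_e4m3fn
```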
@@ -414,6 +414,7 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: FlexAttentionMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlexAttention.

@@ -427,6 +428,12 @@ def forward(
             shape = [num_tokens, num_heads * head_size]
         """
         assert output is not None, "Output tensor must be provided."
+
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for FlexAttentionImpl")
+
         enable_gqa = self.num_kv_heads != self.num_heads

         if attn_metadata is None:
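A quick way to exercise the new guard is shown below. This is a hedged test sketch using a minimal stand-in rather than the real backend classes, which need full model-runner setup to construct; `_GuardedImpl` and the test name are hypothetical.

```python
# Hedged test sketch: checks the output_scale guard behaviour added by the
# diff, against a stand-in that mirrors only that guard.
from typing import Optional

import pytest
import torch


class _GuardedImpl:
    """Stand-in mirroring only the output_scale guard from the diff."""

    def forward(self, output: torch.Tensor,
                output_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported")
        return output


def test_output_scale_rejected():
    impl = _GuardedImpl()
    out = torch.zeros(2, 4)
    # Passing any scale tensor must fail loudly until fusion is implemented.
    with pytest.raises(NotImplementedError):
        impl.forward(out, output_scale=torch.tensor(1.0))
    # Without a scale, the call goes through unchanged.
    assert impl.forward(out) is out
```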