@@ -326,7 +326,54 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
     return kv_cache


+# We can move this function to a common utils file if it's also useful for other
+# hardware.
+def dtype_bits(dtype: torch.dtype):
+    if dtype.is_floating_point:
+        try:
+            return torch.finfo(dtype).bits
+        except TypeError:
+            pass
+    elif dtype.is_complex:
+        if dtype is torch.complex32:
+            return 32
+        elif dtype is torch.complex64:
+            return 64
+        elif dtype is torch.complex128:
+            return 128
+    else:
+        try:
+            return torch.iinfo(dtype).bits
+        # torch.iinfo cannot support int4, int2, bits8...
+        except TypeError:
+            pass
+    str_dtype = str(dtype)
+    # support torch.int4, torch.int5, torch.uint5...
+    if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"):
+        return int(str_dtype[-1])
+    raise TypeError(f"Getting the bit width of {dtype} is not supported")
+
+
+def get_dtype_packing(dtype):
+    bits = dtype_bits(dtype)
+    if 32 % bits != 0:
+        raise ValueError(
+            f"The bit width must evenly divide 32, but got bits={bits}, "
+            f"dtype={dtype}")
+    return 32 // bits
+
+
 def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                         kv_cache_dtype: torch.dtype) -> int:
     """Returns the size in bytes of one page of the KV cache."""
-    return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize
+    padded_head_size = cdiv(head_size,
+                            TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+    num_combined_kv_heads = num_kv_heads * 2
+
+    # NOTE: pad for the implicit packing XLA applies to sub-32-bit dtypes
+    packing = get_dtype_packing(kv_cache_dtype)
+    num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing
+
+    kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
+    return (block_size * num_combined_kv_heads * padded_head_size *
+            kv_cache_dtype_bits // 8)
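For reference, here is a minimal sketch of the bit widths and packing factors the new `dtype_bits` / `get_dtype_packing` helpers are expected to report for a few common KV-cache dtypes. It reproduces the lookup with `torch.finfo` / `torch.iinfo` directly rather than importing the helpers (their module path is outside this hunk), and it assumes `torch.float8_e4m3fn` is available in the installed torch build.

```python
import torch

# Hand-check of the expected bit widths and 32-bit packing factors
# (reproduces the formula, not the actual helpers from this diff).
for dtype in (torch.float32, torch.bfloat16, torch.float8_e4m3fn, torch.int8):
    if dtype.is_floating_point:
        bits = torch.finfo(dtype).bits
    else:
        bits = torch.iinfo(dtype).bits
    packing = 32 // bits  # elements packed into one 32-bit word by XLA
    print(f"{dtype}: {bits} bits, packing={packing}")

# Expected output:
#   torch.float32: 32 bits, packing=1
#   torch.bfloat16: 16 bits, packing=2
#   torch.float8_e4m3fn: 8 bits, packing=4
#   torch.int8: 8 bits, packing=4
```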
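And a worked check of the new page-size formula under an assumed configuration (block_size=32, num_kv_heads=8, head_size=128, bfloat16 KV cache) and an assumed `TPU_HEAD_SIZE_ALIGNMENT` of 128, since the constant's value is not shown in this hunk. This is a standalone sketch of the arithmetic, not a call into the modified module.

```python
import torch

TPU_HEAD_SIZE_ALIGNMENT = 128  # assumed value; defined elsewhere in the file


def cdiv(a: int, b: int) -> int:
    """Ceiling division, matching the cdiv used in the diff."""
    return -(-a // b)


# Assumed example configuration (not taken from the PR).
block_size, num_kv_heads, head_size = 32, 8, 128
kv_cache_dtype = torch.bfloat16

# Head size is padded up to the TPU alignment.
padded_head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT  # 128

# K and V are stored together, so the head count doubles and is then padded
# to a multiple of the 32-bit packing factor.
num_combined_kv_heads = num_kv_heads * 2          # 16
bits = torch.finfo(kv_cache_dtype).bits           # 16
packing = 32 // bits                              # 2
num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing  # still 16

page_size_bytes = block_size * num_combined_kv_heads * padded_head_size * bits // 8
print(page_size_bytes)  # 32 * 16 * 128 * 2 bytes = 131072 (128 KiB per page)
```

For comparison, the removed formula `block_size * num_kv_heads * head_size * itemsize` gives 65536 bytes for this configuration; the new one doubles the head count to cover the combined K/V layout and rounds the head size and head count up to TPU-friendly multiples.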