@@ -258,10 +258,13 @@ def create_kv_caches_with_random(
         key_cache = torch.empty(size=key_cache_shape,
                                 dtype=torch_dtype,
                                 device=device)
-        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
-            key_cache.uniform_(-scale, scale)
-        elif cache_dtype == 'fp8_e5m2':
+        if cache_dtype == 'fp8_e5m2':
             _generate_random_fp8_e5m2(key_cache, -scale, scale)
+        elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
+            key_cache.uniform_(-scale, scale)
+        else:
+            raise ValueError(
+                f"Does not support key cache of type {cache_dtype}")
         key_caches.append(key_cache)

     value_cache_shape = (num_blocks, num_heads, head_size, block_size)
@@ -270,9 +273,12 @@ def create_kv_caches_with_random(
         value_cache = torch.empty(size=value_cache_shape,
                                   dtype=torch_dtype,
                                   device=device)
-        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
-            value_cache.uniform_(-scale, scale)
-        elif cache_dtype == 'fp8_e5m2':
+        if cache_dtype == 'fp8_e5m2':
             _generate_random_fp8_e5m2(value_cache, -scale, scale)
+        elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
+            value_cache.uniform_(-scale, scale)
+        else:
+            raise ValueError(
+                f"Does not support value cache of type {cache_dtype}")
         value_caches.append(value_cache)
     return key_caches, value_caches
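
Both hunks apply the same reordering: the fp8_e5m2 case is dispatched on the cache_dtype string first, floating-point caches are then filled based on the actual torch_dtype rather than a string allow-list, and any other combination now raises instead of silently leaving the cache unfilled. Below is a minimal standalone sketch of that dispatch pattern. The wrapper fill_cache_random and the stub body of _generate_random_fp8_e5m2 are hypothetical illustrations, and it assumes the fp8_e5m2 cache is stored in a uint8 carrier tensor; only the branch ordering mirrors the diff.

    import torch

    def _generate_random_fp8_e5m2(tensor: torch.Tensor, low: float, high: float) -> None:
        # Hypothetical stub: the real helper packs uniform floats into fp8 e5m2
        # bit patterns; here we just fill the uint8 carrier with random bytes.
        tensor.random_(0, 256)

    def fill_cache_random(cache: torch.Tensor, cache_dtype: str, scale: float) -> None:
        # Mirror the diff's ordering: handle the fp8 string case first, because
        # uniform_ raises on integer tensors and would fail on the uint8 carrier.
        if cache_dtype == 'fp8_e5m2':
            _generate_random_fp8_e5m2(cache, -scale, scale)
        # The diff checks torch_dtype; cache.dtype is the same value here, since
        # the cache tensor was allocated with dtype=torch_dtype.
        elif cache.dtype in [torch.half, torch.bfloat16, torch.float]:
            cache.uniform_(-scale, scale)
        else:
            raise ValueError(f"Does not support cache of type {cache_dtype}")

    # Usage: an fp8 cache lives in a uint8 tensor and takes the first branch.
    key_cache = torch.empty(16, 8, 64, dtype=torch.uint8)
    fill_cache_random(key_cache, 'fp8_e5m2', scale=0.1)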