@@ -199,6 +199,10 @@ def _get_quantized_weights_iterator(
 
         if self.pre_quant:
             if self.load_8bit:
+                if current_platform.is_hpu():
+                    raise ValueError(
+                        "currently hpu supports 4bit quantization only")
+
                 return self._quantized_8bit_generator(
                     hf_weights_files, use_safetensors,
                     quant_state_dict), quant_state_dict
@@ -302,6 +306,10 @@ def _parse_quant_state(param_name: str,
                         in temp_state_dict):
                     quant_state = _parse_quant_state(mapped_weight_name,
                                                      temp_state_dict)
+                    if current_platform.is_hpu():
+                        assert quant_state.quant_type == "nf4", (
+                            "currently hpu supports nf4 quant_type only")
+
                     quant_state_dict[mapped_weight_name] = quant_state
                     yield org_weight_name, weight_tensor
                 else:
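
Read together, the two new checks enforce the HPU constraint from both directions: an 8-bit load request fails fast with a ValueError, and any pre-quantized state that reaches the 4-bit path must carry the nf4 quant_type. A minimal sketch of a quantization config that satisfies both checks, assuming the standard transformers BitsAndBytesConfig API (illustrative only, not part of this commit):

    from transformers import BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,          # load_in_8bit=True would hit the new ValueError on HPU
        bnb_4bit_quant_type="nf4",  # the only quant_type the HPU path asserts
    )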
@@ -372,10 +380,12 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
                                                   ...]
 
                 # bitsandbytes requires data in GPU
-                if weight_sub_tensor.is_cuda:
+                if (weight_sub_tensor.is_cuda
+                        or weight_sub_tensor.device.type == "hpu"):
                     loaded_weight = weight_sub_tensor
                 else:
-                    loaded_weight = weight_sub_tensor.cuda()
+                    loaded_weight = weight_sub_tensor.to(
+                        device=current_platform.device_type)
 
                 # remove the following after the issue is fixed:
                 # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
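
The last hunk replaces the hard-coded .cuda() with a platform-driven .to() call, so the same branch now serves both CUDA and HPU. A minimal sketch of that device-selection pattern in isolation, assuming current_platform.device_type resolves to "cuda" or "hpu" as in vllm.platforms (the helper name is hypothetical):

    import torch

    from vllm.platforms import current_platform

    def _to_accelerator(weight_sub_tensor: torch.Tensor) -> torch.Tensor:
        # bitsandbytes needs the data on an accelerator: keep tensors that
        # already live on CUDA or HPU, otherwise copy to the platform device.
        if (weight_sub_tensor.is_cuda
                or weight_sub_tensor.device.type == "hpu"):
            return weight_sub_tensor
        return weight_sub_tensor.to(device=current_platform.device_type)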