@@ -231,15 +231,15 @@ def create_column_parallel_packed_layer():
231231 if repeats == 2 :
232232 # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading.
233233 linear = MergedColumnParallelLinear (
234- 256 , # input_size
235- [256 ] * repeats , # output_size
234+ 64 , # input_size
235+ [64 ] * repeats , # output_size
236236 bias = False ,
237237 params_dtype = torch .float16 )
238238 linear .weight .data = torch .rand_like (linear .weight .data )
239239
240240 base_linear = MergedColumnParallelLinear (
241- 256 , # input_size
242- [256 ] * repeats , # output_size
241+ 64 , # input_size
242+ [64 ] * repeats , # output_size
243243 bias = False ,
244244 params_dtype = torch .float16 )
245245 base_linear .weight .data = linear .weight .data
@@ -303,13 +303,13 @@ def create_column_parallel_packed_layer():
303303 repeats = repeats ,
304304 )
305305
306- # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to a request which has several token of shape=[num_tokens, 256 ].
306+ # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to a request which has several tokens of shape=[num_tokens, 64 ].
307307 # index_mapping: list[int]
308308 # prompt_mapping: list[int]
309309 inputs , index_mapping , prompt_mapping = create_random_inputs (
310310 active_lora_ids = list (lora_dict .keys ()),
311311 num_inputs = 32 ,
312- input_size = (1 , 256 ),
312+ input_size = (1 , 64 ),
313313 input_range = (0 , 1 ),
314314 input_type = torch .float16 ,
315315 device = 'cpu' )
@@ -372,7 +372,7 @@ def create_column_parallel_packed_layer():
372372 inputs , index_mapping , prompt_mapping = create_random_inputs (
373373 active_lora_ids = [0 ], # different from the above create_random_inputs
374374 num_inputs = 32 ,
375- input_size = (1 , 256 ),
375+ input_size = (1 , 64 ),
376376 input_range = (0 , 1 ),
377377 input_type = torch .float16 ,
378378 device = 'cpu' )
0 commit comments