Commit 9369e12

ok, fixed the torchax.view.item() issue.
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
1 parent e45b220 commit 9369e12

File tree

tests/lora/test_layers.py
tpu_inference/lora/torch_punica_tpu.py

2 files changed (+10, -8 lines)

tests/lora/test_layers.py

Lines changed: 7 additions & 7 deletions
@@ -231,15 +231,15 @@ def create_column_parallel_packed_layer():
     if repeats == 2:
         # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading.
         linear = MergedColumnParallelLinear(
-            256,  # input_size
-            [256] * repeats,  # output_size
+            64,  # input_size
+            [64] * repeats,  # output_size
             bias=False,
             params_dtype=torch.float16)
         linear.weight.data = torch.rand_like(linear.weight.data)

         base_linear = MergedColumnParallelLinear(
-            256,  # input_size
-            [256] * repeats,  # output_size
+            64,  # input_size
+            [64] * repeats,  # output_size
             bias=False,
             params_dtype=torch.float16)
         base_linear.weight.data = linear.weight.data

@@ -303,13 +303,13 @@ def create_column_parallel_packed_layer():
         repeats=repeats,
     )

-    # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to a request which has several token of shape=[num_tokens, 256].
+    # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to a request which has several token of shape=[num_tokens, 64].
     # index_mapping: list[int]
     # prompt_mapping: list[int]
     inputs, index_mapping, prompt_mapping = create_random_inputs(
         active_lora_ids=list(lora_dict.keys()),
         num_inputs=32,
-        input_size=(1, 256),
+        input_size=(1, 64),
         input_range=(0, 1),
         input_type=torch.float16,
         device='cpu')

@@ -372,7 +372,7 @@ def create_column_parallel_packed_layer():
     inputs, index_mapping, prompt_mapping = create_random_inputs(
         active_lora_ids=[0],  # different from the above create_random_inputs
         num_inputs=32,
-        input_size=(1, 256),
+        input_size=(1, 64),
         input_range=(0, 1),
         input_type=torch.float16,
         device='cpu')
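For reviewers, a minimal sketch of the shape contract the resized fixture relies on, using plain torch stand-ins (the names below are illustrative, not part of the test suite): each (1, 64) input must match the layer's input_size, and the merged column-parallel output concatenates the packed output_sizes.

import torch

repeats = 2
input_size = 64                # the new test value
output_sizes = [64] * repeats

# Stand-in for the merged weight of MergedColumnParallelLinear:
# the output dimension is the sum of the packed output sizes.
weight = torch.rand(sum(output_sizes), input_size, dtype=torch.float16)

x = torch.rand(1, input_size, dtype=torch.float16)  # one input, as in create_random_inputs
y = x @ weight.T
assert y.shape == (1, sum(output_sizes))  # (1, 128) when repeats == 2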

tpu_inference/lora/torch_punica_tpu.py

Lines changed: 3 additions & 1 deletion
@@ -8,6 +8,8 @@
 import torch.nn.functional as F
 import torchax
 from vllm.lora.punica_wrapper.utils import convert_mapping
+from torchax.interop import jax_view, torch_view
+

 if TYPE_CHECKING:
     # avoid circuit import

@@ -283,7 +285,7 @@ def _update_prefill_metadata(self,
         self.batch_size = 1
         self._lora_indices_per_batch[:self.
                                      batch_size] = token_lora_tensor[:self.
-                                                                     batch_size]
+                                                                     batch_size].torch()

     def _pad_prompt_mapping(
             self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
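For context on the new import: jax_view and torch_view are torchax's interop helpers for crossing between torch-land and JAX-land values, and the appended .torch() call materializes the sliced torchax view into a torch tensor before it is written into the metadata buffer, which, per the commit message, is apparently what made .item() fail. A minimal, hedged sketch of the interop helpers, assuming torchax's global mode; this is illustrative only, not the wrapper's actual call site.

import jax.numpy as jnp
import torch
import torchax
from torchax.interop import jax_view, torch_view

torchax.enable_globally()  # route torch ops through JAX

t = torch.ones(4)             # a jax-backed tensor under torchax
j = jax_view(t)               # expose the underlying jax.Array to JAX code
t2 = torch_view(jnp.ones(4))  # wrap a jax.Array for torch code
print(type(j), type(t2))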
