Commit c17bace

add multi-chip test case
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
1 parent 5ce5c35

File tree: 2 files changed, +13 −5 lines

.buildkite/pipeline_jax.yml
tests/lora/test_layers.py


.buildkite/pipeline_jax.yml
Lines changed: 4 additions & 2 deletions

@@ -120,7 +120,8 @@ steps:
       - |
         .buildkite/scripts/run_in_docker.sh \
           bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \
-            python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_bgmv.py'
+            python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_bgmv.py && \
+            python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py'
 
   - label: "E2E MLPerf tests for JAX + vLLM models on multiple chips"
     key: test_11

@@ -157,7 +158,8 @@ steps:
     commands:
       - |
         .buildkite/scripts/run_in_docker.sh \
-          bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py'
+          bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \
+            python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py'
 
 
 # -----------------------------------------------------------------
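
Net effect (reconstructed from the two hunks above): the single-chip LoRA step now chains three suites in one shell, stopping at the first failing pytest run:

    .buildkite/scripts/run_in_docker.sh \
      bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora.py && \
        python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_bgmv.py && \
        python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py'

The multi-chip step gains test_layers.py the same way. One shell subtlety: the MODEL_IMPL_TYPE/TPU_BACKEND_TYPE prefix applies only to the first command in the '&&' chain; the later pytest invocations do not inherit those variables unless they are exported.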

tests/lora/test_layers.py
Lines changed: 9 additions & 3 deletions

@@ -221,10 +221,13 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
     )
 
     axis_names = ("data", "model")
+    devices = jax.devices()
     mesh_shape = (
-        1, 1
+        1, len(devices)
+        # 1, 1
     )  # TODO(xiowei): support multi-chip: mesh_shape = (1, len(jax.devices()))
-    mesh = jax.make_mesh(mesh_shape, axis_names, devices=jax.devices())
+    print(f'xw32 mesh_shape: {mesh_shape}')
+    mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)
 
     def create_column_parallel_packed_layer():
         # We first create a base linear layer, then a lora layer to wrap it.
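
For context, a minimal sketch of what the new (1, len(devices)) mesh shape does (illustrative only, assuming a working JAX install; the variable names are hypothetical): it lays every available chip along the "model" axis, so a tensor sharded on that axis is split across all chips.

    import jax
    import jax.numpy as jnp
    from jax.sharding import NamedSharding, PartitionSpec as P

    devices = jax.devices()
    # One "data" slice; all chips laid out along the "model" axis.
    mesh = jax.make_mesh((1, len(devices)), ("data", "model"), devices=devices)

    # Shard a weight's out_features dimension across the "model" axis.
    w = jnp.ones((16, 8 * len(devices)))
    w_sharded = jax.device_put(w, NamedSharding(mesh, P(None, "model")))
    print(w_sharded.sharding)  # each chip holds a (16, 8) slice

On a single-chip (or CPU) run this degenerates to the old (1, 1) mesh, which is why the test still passes there.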
@@ -281,7 +284,10 @@ def create_column_parallel_packed_layer():
     with torchax.default_env():
         # lora_linear.weight has type torchax.tensor.Tensor
         # BaseLinearLayerWithLoRA.weight property guarantees this.
-        assert torch.equal(linear.weight, lora_linear.weight.to('cpu'))
+        # If len(devices) != 1, `reorder_concatenated_tensor_for_sharding` may
+        # reorder the out_features dimension of the weight matrix, so the
+        # check below would fail.
+        if len(devices) == 1:
+            assert torch.equal(linear.weight.data, lora_linear.weight.to('cpu'))
 
     max_num_batched_tokens = 8192
     max_batches = 256
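
Why the equality check is now gated on len(devices) == 1: for a packed layer (e.g. fused QKV), sharding-aware packing typically interleaves the concatenated sub-weights so each chip receives a contiguous slice of every sub-tensor, which permutes rows relative to the original concatenation. The sketch below is a hypothetical illustration of that effect; reorder_for_sharding is not the repo's actual reorder_concatenated_tensor_for_sharding.

    import numpy as np

    def reorder_for_sharding(packed, split_sizes, num_shards):
        # Hypothetical: split the packed out_features dim back into its
        # sub-tensors, then regroup so shard i holds the i-th slice of
        # every sub-tensor contiguously.
        subs = np.split(packed, np.cumsum(split_sizes)[:-1], axis=0)
        shards = []
        for i in range(num_shards):
            for s in subs:
                chunk = s.shape[0] // num_shards
                shards.append(s[i * chunk:(i + 1) * chunk])
        return np.concatenate(shards, axis=0)

    # Packed [q; k; v] with out_features 4 + 2 + 2, sharded over 2 chips:
    q, k, v = np.full((4, 3), 1), np.full((2, 3), 2), np.full((2, 3), 3)
    packed = np.concatenate([q, k, v])
    reordered = reorder_for_sharding(packed, [4, 2, 2], num_shards=2)
    assert not np.array_equal(packed, reordered)       # row order changed
    assert np.array_equal(np.sort(packed, axis=None),  # same elements
                          np.sort(reordered, axis=None))

With more than one device the wrapped layer's weight is such a reordered copy, so torch.equal against the unsharded base weight would fail even though the contents match.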
