
Commit 0ce9633

superbobry authored and Google-ML-Automation committed
[pallas:mosaic] Removed the TPU prefix from TPUCompilerParams and TPUMemorySpace
All TPU-specific APIs are always used qualified, e.g. `pltpu.TPUCompilerParams`, so the prefix is redundant.

PiperOrigin-RevId: 765167675
1 parent 94037a8 commit 0ce9633

33 files changed: +209 -183 lines changed
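For user code the change is a pure rename: call sites keep their structure and only the qualified names lose the `TPU` prefix. Below is a small before/after sketch adapted from the quickstart example touched by this commit (the `compiler_params` argument is added here purely for illustration; per the CHANGELOG entry further down, the old spellings remain available as deprecated aliases for now):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def iota_kernel(o_ref):
  # Write the current grid index into the scalar-memory output.
  i = pl.program_id(0)
  o_ref[i] = i

def iota(size: int):
  return pl.pallas_call(
      iota_kernel,
      # Before this commit: memory_space=pltpu.TPUMemorySpace.SMEM
      out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM),
      out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
      grid=(size,),
      # Before this commit: compiler_params=pltpu.TPUCompilerParams(...)
      compiler_params=pltpu.CompilerParams(
          dimension_semantics=("arbitrary",)),
  )()

iota(8)  # runs on TPU only
```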

docs/pallas/CHANGELOG.md

Lines changed: 5 additions & 0 deletions

@@ -18,6 +18,11 @@ Remember to align the itemized text with the first line of an item within a list
 * {class}`jax.experimental.pallas.triton.TritonCompilerParams` has been
   renamed to {class}`jax.experimental.pallas.triton.CompilerParams`. The
   old name is deprecated and will be removed in a future release.
+* {class}`jax.experimental.pallas.tpu.TPUCompilerParams`
+  and {class}`jax.experimental.pallas.tpu.TPUMemorySpace` have been
+  renamed to {class}`jax.experimental.pallas.tpu.CompilerParams`
+  and {class}`jax.experimental.pallas.tpu.MemorySpace`. The
+  old names are deprecated and will be removed in a future release.

 ## Released with jax 0.6.1

docs/pallas/quickstart.ipynb

Lines changed: 2 additions & 2 deletions

@@ -280,7 +280,7 @@
 "metadata": {},
 "source": [
 "TPUs distinguish between vector and scalar memory spaces and in this case the\n",
-"output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n",
+"output must be placed in scalar memory (`MemorySpace.SMEM`) since `i` is\n",
 "a scalar. For more details read {ref}`tpu_and_its_memory_spaces`.\n",
 "To call the above kernel on TPU, run:"
 ]
@@ -297,7 +297,7 @@
 "\n",
 "def iota(size: int):\n",
 " return pl.pallas_call(iota_kernel,\n",
-" out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),\n",
+" out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM),\n",
 " out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n",
 " grid=(size,))()\n",
 "iota(8)"

docs/pallas/quickstart.md

Lines changed: 2 additions & 2 deletions

@@ -186,7 +186,7 @@ iota(8)
 ```

 TPUs distinguish between vector and scalar memory spaces and in this case the
-output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is
+output must be placed in scalar memory (`MemorySpace.SMEM`) since `i` is
 a scalar. For more details read {ref}`tpu_and_its_memory_spaces`.
 To call the above kernel on TPU, run:

@@ -196,7 +196,7 @@ from jax.experimental.pallas import tpu as pltpu

 def iota(size: int):
   return pl.pallas_call(iota_kernel,
-                        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
+                        out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.SMEM),
                         out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
                         grid=(size,))()
 iota(8)

docs/pallas/tpu/details.rst

Lines changed: 1 addition & 1 deletion

@@ -170,7 +170,7 @@ grid axes over cores. This is an opt-in procedure. To allow that,
 ..
   pallas_call(
     ...,
-    compiler_params=pltpu.TPUCompilerParams(
+    compiler_params=pltpu.CompilerParams(
       dimension_semantics=["parallel", "parallel", "arbitrary"]
     ),
   )

docs/pallas/tpu/distributed.ipynb

Lines changed: 21 additions & 21 deletions

@@ -271,11 +271,11 @@
 "out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)\n",
 "grid_spec = pltpu.PrefetchScalarGridSpec(\n",
 " num_scalar_prefetch=0,\n",
-" # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n",
+" # MemorySpace.ANY will (usually) place the tensor in HBM.\n",
 " in_specs=[\n",
-" pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+" pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
 " ],\n",
-" out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+" out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
 " scratch_shapes=(\n",
 " # We allocate DMA semaphores in scratch memory.\n",
 " [pltpu.SemaphoreType.DMA] * 2\n",
@@ -420,10 +420,10 @@
 "grid_spec = pltpu.PrefetchScalarGridSpec(\n",
 " num_scalar_prefetch=0,\n",
 " in_specs=[\n",
-" # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n",
-" pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+" # MemorySpace.ANY will (usually) place the tensor in HBM.\n",
+" pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
 " ],\n",
-" out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+" out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
 " scratch_shapes=(\n",
 " # DMA semaphores are allocated in scratch memory.\n",
 " # We allocated one semaphore for a local HBM-VMEM copy,\n",
@@ -569,7 +569,7 @@
 "kernel = pl.pallas_call(\n",
 " example_kernel,\n",
 " ...,\n",
-" compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
+" compiler_params=pltpu.CompilerParams(collective_id=0),\n",
 ")\n",
 "```"
 ]
@@ -809,13 +809,13 @@
 " num_scalar_prefetch=0,\n",
 " in_specs=[\n",
 " # Our input lives in VMEM\n",
-" pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n",
+" pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n",
 " ],\n",
 " out_specs=[\n",
 " # Our output lives in VMEM\n",
-" pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n",
+" pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n",
 " # Our double-buffer lives in HBM\n",
-" pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+" pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
 " ],\n",
 " grid=(num_devices,),\n",
 " scratch_shapes=(\n",
@@ -829,7 +829,7 @@
 " all_reduce_kernel,\n",
 " out_shape=out_shape,\n",
 " grid_spec=grid_spec,\n",
-" compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
+" compiler_params=pltpu.CompilerParams(collective_id=0),\n",
 ")\n",
 "\n",
 "pallas_result = jax.jit(\n",
@@ -1146,11 +1146,11 @@
 "grid_spec = pltpu.PrefetchScalarGridSpec(\n",
 " num_scalar_prefetch=0,\n",
 " in_specs=[\n",
-" pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n",
+" pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n",
 " ],\n",
 " out_specs=[\n",
-" pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n",
-" pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+" pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n",
+" pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
 " ],\n",
 " grid=(num_devices, 2),\n",
 " scratch_shapes=(\n",
@@ -1169,7 +1169,7 @@
 " reduce_scatter_kernel,\n",
 " out_shape=out_shape,\n",
 " grid_spec=grid_spec,\n",
-" compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
+" compiler_params=pltpu.CompilerParams(collective_id=0),\n",
 " )(input_arr)[0]\n",
 "\n",
 "\n",
@@ -1307,7 +1307,7 @@
 "\n",
 "In this next example we will modify our previous reduce-scatter example to utilize a nested inner pipeline. Note that the communication and computation costs of `reduce_scatter` both scale linearly with the size of the input, so we do not necessarily expect to see the operation become compute-bound with larger block sizes. This example is purely for demonstration purposes on how to use the pipeline emitter.\n",
 "\n",
-"We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=TPUMemorySpace.Any`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.\n",
+"We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=MemorySpace.ANY`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.\n",
 "\n",
 "In our previous kernel we had the following kernel body to copy data from HBM to the VMEM accumulator, increment, and then copy the results back to HBM:\n",
 "\n",
@@ -1408,7 +1408,7 @@
 "inner_block_spec = pl.BlockSpec(\n",
 " index_map=lambda i, j: (i, j),\n",
 " block_shape=inner_block_size,\n",
-" memory_space=pltpu.TPUMemorySpace.ANY,\n",
+" memory_space=pltpu.MemorySpace.ANY,\n",
 ")\n",
 "\n",
 "\n",
@@ -1590,11 +1590,11 @@
 "grid_spec = pltpu.PrefetchScalarGridSpec(\n",
 " num_scalar_prefetch=0,\n",
 " in_specs=[\n",
-" pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+" pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
 " ],\n",
 " out_specs=[\n",
-" pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
-" pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+" pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
+" pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n",
 " ],\n",
 " grid=(num_devices, 2),\n",
 " scratch_shapes=(\n",
@@ -1612,7 +1612,7 @@
 " reduce_scatter_kernel,\n",
 " out_shape=out_shape,\n",
 " grid_spec=grid_spec,\n",
-" compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
+" compiler_params=pltpu.CompilerParams(collective_id=0),\n",
 " )(input_arr)[0]\n",
 "\n",
 "\n",

docs/pallas/tpu/distributed.md

Lines changed: 21 additions & 21 deletions

@@ -233,11 +233,11 @@ def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):
 out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
 grid_spec = pltpu.PrefetchScalarGridSpec(
     num_scalar_prefetch=0,
-    # TPUMemorySpace.ANY will (usually) place the tensor in HBM.
+    # MemorySpace.ANY will (usually) place the tensor in HBM.
     in_specs=[
-        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
     ],
-    out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+    out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
     scratch_shapes=(
         # We allocate DMA semaphores in scratch memory.
         [pltpu.SemaphoreType.DMA] * 2
@@ -356,10 +356,10 @@ out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32)
 grid_spec = pltpu.PrefetchScalarGridSpec(
     num_scalar_prefetch=0,
     in_specs=[
-        # TPUMemorySpace.ANY will (usually) place the tensor in HBM.
-        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+        # MemorySpace.ANY will (usually) place the tensor in HBM.
+        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
     ],
-    out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+    out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
     scratch_shapes=(
         # DMA semaphores are allocated in scratch memory.
         # We allocated one semaphore for a local HBM-VMEM copy,
@@ -491,7 +491,7 @@ When using barrier semaphores, the `collective_id` compiler parameter must be pa
 kernel = pl.pallas_call(
     example_kernel,
     ...,
-    compiler_params=pltpu.TPUCompilerParams(collective_id=0),
+    compiler_params=pltpu.CompilerParams(collective_id=0),
 )
 ```

@@ -703,13 +703,13 @@ grid_spec = pltpu.PrefetchScalarGridSpec(
     num_scalar_prefetch=0,
     in_specs=[
         # Our input lives in VMEM
-        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
+        pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
     ],
     out_specs=[
         # Our output lives in VMEM
-        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
+        pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
         # Our double-buffer lives in HBM
-        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
     ],
     grid=(num_devices,),
     scratch_shapes=(
@@ -723,7 +723,7 @@ kernel = pl.pallas_call(
     all_reduce_kernel,
     out_shape=out_shape,
     grid_spec=grid_spec,
-    compiler_params=pltpu.TPUCompilerParams(collective_id=0),
+    compiler_params=pltpu.CompilerParams(collective_id=0),
 )

 pallas_result = jax.jit(
@@ -1019,11 +1019,11 @@ out_shape = (
 grid_spec = pltpu.PrefetchScalarGridSpec(
     num_scalar_prefetch=0,
     in_specs=[
-        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
+        pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
     ],
     out_specs=[
-        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
-        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+        pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),
+        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
     ],
     grid=(num_devices, 2),
     scratch_shapes=(
@@ -1042,7 +1042,7 @@ def pallas_reduce_scatter(input_arr):
       reduce_scatter_kernel,
       out_shape=out_shape,
       grid_spec=grid_spec,
-      compiler_params=pltpu.TPUCompilerParams(collective_id=0),
+      compiler_params=pltpu.CompilerParams(collective_id=0),
   )(input_arr)[0]


@@ -1148,7 +1148,7 @@ pl.pallas_call(

 In this next example we will modify our previous reduce-scatter example to utilize a nested inner pipeline. Note that the communication and computation costs of `reduce_scatter` both scale linearly with the size of the input, so we do not necessarily expect to see the operation become compute-bound with larger block sizes. This example is purely for demonstration purposes on how to use the pipeline emitter.

-We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=TPUMemorySpace.Any`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.
+We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=MemorySpace.ANY`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.

 In our previous kernel we had the following kernel body to copy data from HBM to the VMEM accumulator, increment, and then copy the results back to HBM:

@@ -1242,7 +1242,7 @@ inner_grid = (
 inner_block_spec = pl.BlockSpec(
     index_map=lambda i, j: (i, j),
     block_shape=inner_block_size,
-    memory_space=pltpu.TPUMemorySpace.ANY,
+    memory_space=pltpu.MemorySpace.ANY,
 )

@@ -1424,11 +1424,11 @@ out_shape = (
 grid_spec = pltpu.PrefetchScalarGridSpec(
     num_scalar_prefetch=0,
     in_specs=[
-        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
     ],
     out_specs=[
-        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
-        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
+        pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
     ],
     grid=(num_devices, 2),
     scratch_shapes=(
@@ -1446,7 +1446,7 @@ def pallas_reduce_scatter(input_arr):
     reduce_scatter_kernel,
     out_shape=out_shape,
     grid_spec=grid_spec,
-    compiler_params=pltpu.TPUCompilerParams(collective_id=0),
+    compiler_params=pltpu.CompilerParams(collective_id=0),
   )(input_arr)[0]
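The reduce-scatter hunks above keep the outer kernel's operands in HBM via `MemorySpace.ANY` and let `pltpu.emit_pipeline` stage VMEM-sized blocks for an inner kernel. As a rough, self-contained sketch of that pattern under the new names (the toy element-wise kernel, shapes, and block sizes are invented here, and `emit_pipeline`'s keyword signature is assumed from the Pallas TPU pipelining docs rather than taken from this diff):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

M, N = 1024, 1024    # full operand; the outer kernel leaves it in HBM
BM, BN = 256, 256    # VMEM-friendly inner block size

def inner_kernel(x_ref, o_ref):
  # Operates on one VMEM block; emit_pipeline copies blocks in and out.
  o_ref[...] = x_ref[...] + 1.0

def outer_kernel(x_hbm_ref, o_hbm_ref):
  block_spec = pl.BlockSpec((BM, BN), lambda i, j: (i, j))
  pipeline = pltpu.emit_pipeline(
      inner_kernel,
      grid=(M // BM, N // BN),
      in_specs=[block_spec],
      out_specs=[block_spec],
  )
  pipeline(x_hbm_ref, o_hbm_ref)

x = jnp.zeros((M, N), jnp.float32)
y = pl.pallas_call(
    outer_kernel,
    # MemorySpace.ANY (usually) keeps the operands in HBM; the nested
    # pipeline moves individual blocks through VMEM.
    in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)],
    out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),
    out_shape=jax.ShapeDtypeStruct((M, N), jnp.float32),
)(x)
```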

docs/pallas/tpu/matmul.ipynb

Lines changed: 4 additions & 4 deletions

@@ -210,7 +210,7 @@
 " pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],\n",
 " out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),\n",
 " grid=(m // bm, n // bn, k // bk),\n",
-" compiler_params=pltpu.TPUCompilerParams(\n",
+" compiler_params=pltpu.CompilerParams(\n",
 " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
 " )(x, y)"
 ]
@@ -466,7 +466,7 @@
 " grid=(m // bm, n // bn, k // bk),\n",
 " ),\n",
 " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
-" compiler_params=pltpu.TPUCompilerParams(\n",
+" compiler_params=pltpu.CompilerParams(\n",
 " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
 " )(x, y)"
 ]
@@ -741,7 +741,7 @@
 " grid=(m // bm, n // bn, k // bk),\n",
 " ),\n",
 " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
-" compiler_params=pltpu.TPUCompilerParams(\n",
+" compiler_params=pltpu.CompilerParams(\n",
 " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
 " )(x, y)"
 ]
@@ -929,7 +929,7 @@
 " grid=(m // bm, n // bn, k // bk),\n",
 " ),\n",
 " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
-" compiler_params=pltpu.TPUCompilerParams(\n",
+" compiler_params=pltpu.CompilerParams(\n",
 " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
 " )(x, y)"
 ]

docs/pallas/tpu/matmul.md

Lines changed: 4 additions & 4 deletions

@@ -167,7 +167,7 @@ def matmul(
                 pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],
      out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
      grid=(m // bm, n // bn, k // bk),
-     compiler_params=pltpu.TPUCompilerParams(
+     compiler_params=pltpu.CompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
 ```
@@ -321,7 +321,7 @@ def matmul(
          grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
-     compiler_params=pltpu.TPUCompilerParams(
+     compiler_params=pltpu.CompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
 ```
@@ -489,7 +489,7 @@ def matmul(
          grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
-     compiler_params=pltpu.TPUCompilerParams(
+     compiler_params=pltpu.CompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
 ```
@@ -613,7 +613,7 @@ def matmul(
          grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
-     compiler_params=pltpu.TPUCompilerParams(
+     compiler_params=pltpu.CompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
 ```
