|
271 | 271 | "out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)\n", |
272 | 272 | "grid_spec = pltpu.PrefetchScalarGridSpec(\n", |
273 | 273 | " num_scalar_prefetch=0,\n", |
274 | | - " # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n", |
| 274 | + " # MemorySpace.ANY will (usually) place the tensor in HBM.\n", |
275 | 275 | " in_specs=[\n", |
276 | | - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", |
| 276 | + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", |
277 | 277 | " ],\n", |
278 | | - " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", |
| 278 | + " out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", |
279 | 279 | " scratch_shapes=(\n", |
280 | 280 | " # We allocate DMA semaphores in scratch memory.\n", |
281 | 281 | " [pltpu.SemaphoreType.DMA] * 2\n", |
|
420 | 420 | "grid_spec = pltpu.PrefetchScalarGridSpec(\n", |
421 | 421 | " num_scalar_prefetch=0,\n", |
422 | 422 | " in_specs=[\n", |
423 | | - " # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n", |
424 | | - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", |
| 423 | + " # MemorySpace.ANY will (usually) place the tensor in HBM.\n", |
| 424 | + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", |
425 | 425 | " ],\n", |
426 | | - " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", |
| 426 | + " out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", |
427 | 427 | " scratch_shapes=(\n", |
428 | 428 | " # DMA semaphores are allocated in scratch memory.\n", |
429 | 429 | " # We allocate one semaphore for a local HBM-VMEM copy,\n", |
|
569 | 569 | "kernel = pl.pallas_call(\n", |
570 | 570 | " example_kernel,\n", |
571 | 571 | " ...,\n", |
572 | | - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", |
| 572 | + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", |
573 | 573 | ")\n", |
574 | 574 | "```" |
575 | 575 | ] |
|
809 | 809 | " num_scalar_prefetch=0,\n", |
810 | 810 | " in_specs=[\n", |
811 | 811 | " # Our input lives in VMEM\n", |
812 | | - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", |
| 812 | + " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", |
813 | 813 | " ],\n", |
814 | 814 | " out_specs=[\n", |
815 | 815 | " # Our output lives in VMEM\n", |
816 | | - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", |
| 816 | + " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", |
817 | 817 | " # Our double-buffer lives in HBM\n", |
818 | | - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", |
| 818 | + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", |
819 | 819 | " ],\n", |
820 | 820 | " grid=(num_devices,),\n", |
821 | 821 | " scratch_shapes=(\n", |
|
829 | 829 | " all_reduce_kernel,\n", |
830 | 830 | " out_shape=out_shape,\n", |
831 | 831 | " grid_spec=grid_spec,\n", |
832 | | - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", |
| 832 | + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", |
833 | 833 | ")\n", |
834 | 834 | "\n", |
835 | 835 | "pallas_result = jax.jit(\n", |
|
1146 | 1146 | "grid_spec = pltpu.PrefetchScalarGridSpec(\n", |
1147 | 1147 | " num_scalar_prefetch=0,\n", |
1148 | 1148 | " in_specs=[\n", |
1149 | | - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", |
| 1149 | + " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", |
1150 | 1150 | " ],\n", |
1151 | 1151 | " out_specs=[\n", |
1152 | | - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", |
1153 | | - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", |
| 1152 | + " pl.BlockSpec(memory_space=pltpu.MemorySpace.VMEM),\n", |
| 1153 | + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", |
1154 | 1154 | " ],\n", |
1155 | 1155 | " grid=(num_devices, 2),\n", |
1156 | 1156 | " scratch_shapes=(\n", |
|
1169 | 1169 | " reduce_scatter_kernel,\n", |
1170 | 1170 | " out_shape=out_shape,\n", |
1171 | 1171 | " grid_spec=grid_spec,\n", |
1172 | | - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", |
| 1172 | + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", |
1173 | 1173 | " )(input_arr)[0]\n", |
1174 | 1174 | "\n", |
1175 | 1175 | "\n", |
|
1307 | 1307 | "\n", |
1308 | 1308 | "In this next example we will modify our previous reduce-scatter example to utilize a nested inner pipeline. Note that the communication and computation costs of `reduce_scatter` both scale linearly with the size of the input, so we do not necessarily expect to see the operation become compute-bound with larger block sizes. This example serves purely to demonstrate how to use the pipeline emitter.\n", |
1309 | 1309 | "\n", |
1310 | | - "We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=TPUMemorySpace.Any`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.\n", |
| 1310 | + "We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=MemorySpace.ANY`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.\n", |
1311 | 1311 | "\n", |
1312 | 1312 | "In our previous kernel we had the following kernel body to copy data from HBM to the VMEM accumulator, increment, and then copy the results back to HBM:\n", |
1313 | 1313 | "\n", |
|
1408 | 1408 | "inner_block_spec = pl.BlockSpec(\n", |
1409 | 1409 | " index_map=lambda i, j: (i, j),\n", |
1410 | 1410 | " block_shape=inner_block_size,\n", |
1411 | | - " memory_space=pltpu.TPUMemorySpace.ANY,\n", |
| 1411 | + " memory_space=pltpu.MemorySpace.ANY,\n", |
1412 | 1412 | ")\n", |
1413 | 1413 | "\n", |
1414 | 1414 | "\n", |
|
1590 | 1590 | "grid_spec = pltpu.PrefetchScalarGridSpec(\n", |
1591 | 1591 | " num_scalar_prefetch=0,\n", |
1592 | 1592 | " in_specs=[\n", |
1593 | | - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", |
| 1593 | + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", |
1594 | 1594 | " ],\n", |
1595 | 1595 | " out_specs=[\n", |
1596 | | - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", |
1597 | | - " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", |
| 1596 | + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", |
| 1597 | + " pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY),\n", |
1598 | 1598 | " ],\n", |
1599 | 1599 | " grid=(num_devices, 2),\n", |
1600 | 1600 | " scratch_shapes=(\n", |
|
1612 | 1612 | " reduce_scatter_kernel,\n", |
1613 | 1613 | " out_shape=out_shape,\n", |
1614 | 1614 | " grid_spec=grid_spec,\n", |
1615 | | - " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", |
| 1615 | + " compiler_params=pltpu.CompilerParams(collective_id=0),\n", |
1616 | 1616 | " )(input_arr)[0]\n", |
1617 | 1617 | "\n", |
1618 | 1618 | "\n", |
|
0 commit comments