Skip to content

Commit

Permalink
Improve the comments at the beginning of index_compute.h (#1946)
Browse files Browse the repository at this point in the history
I just started to learn indexing, and the comment at the beginning of index_compute.h does not look good...
  • Loading branch information
zasdfgbnm authored Sep 5, 2022
1 parent f7bc341 commit a3ecb33
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions torch/csrc/jit/codegen/cuda/index_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,40 +17,40 @@
* indices (based on input indices) that match the root dimension.
*
* For example with GLOBAL tensor:
* TV[I, J]
* TV[Io, Ii{4}, J] = TV.split(I, factor=4)
* TV[I, K]
* TV[Io, Ii{4}, K] = TV.split(I, factor=4)
* ALLOC: NONE
* INDEX: indexCompute {i, j, k} -> {i * 4 + j, k}
* FLATTENED_INDEX: {i * 4 + j, k} -> {i * 4 + j * J + k}
* FLATTENED_INDEX: {i * 4 + j, k} -> {(i * 4 + j) * K + k}
* PREDICATE: {i * 4 + j, k} -> i * 4 + j < I
*
*
* For example with SHARED tensor:
*
* global_TV[I, J]
* global_TV[Io, Ii{4}, J] = global_TV.split(I, factor=4)
* global_TV[I, K]
* global_TV[Io, Ii{4}, K] = global_TV.split(I, factor=4)
* smem_TV.compute_at(global_TV, 1)
* global_TV.parallelize(1, threadIDx.x)
*
* ALLOC: alloc(smem_TV, 4 x J)
* ALLOC: alloc(smem_TV, 4 x K)
* INDEX: indexCompute(smem_TV, {threadIdx.x, k}) -> {threadIdx.x, k}
* FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {threadIdx.x * 4 + j * J + k}
* FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {(threadIdx.x * 4 + j) * K + k}
* PREDICATE: {threadIdx.x * 4 + j, k} -> threadIdx.x * 4 + j < I // Same as if
* global
*
*
* For example with LOCAL tensor:
* global_TV[I, J, K]
* global_TV[Io, Ii{4}, J] = global_TV.split(I, factor=4)
* reg_TV.compute_at(global_TV, 1)
* global_TV[I, K, L]
* global_TV[Io, Ii{4}, K, L] = global_TV.split(I, factor=4)
* reg_TV.compute_at(global_TV, 2)
* global_TV.parallelize(1, threadIDx.x)
* global_TV{i, j, k, l} -> { i * 4 + j, k, l }
* global_TV{ i * 4 + j, k, l } -> { i * 4 + j * J * K + k * K + l}
* global_TV{ i * 4 + j, k, l } -> { (i * 4 + j) * K * L + k * L + l}
*
* ALLOC: alloc(reg_TV, J x K)
* ALLOC: alloc(reg_TV, K x L)
* INDEX: {k, l} -> {k, l}
* FLATTENED_INDEX: {k, l} -> {k * J + l}
* PREDICATE: i * 4 + j < I && k < J && l < K -> // Same as if global
* FLATTENED_INDEX: {k, l} -> {k * L + l}
* PREDICATE: i * 4 + j < I && k < K && l < L -> // Same as if global
*
* These indices can then be flattened later based on strides.
*/
Expand Down

0 comments on commit a3ecb33

Please sign in to comment.