From a3ecb339442131f87842eb56955e4f17c544e99f Mon Sep 17 00:00:00 2001
From: "Gao, Xiang"
Date: Mon, 5 Sep 2022 05:10:17 -0700
Subject: [PATCH] Improve the comments at the beginning of index_compute.h
 (#1946)

I just started to learn indexing, and the comment at the beginning of
index_compute.h does not look good...
---
 torch/csrc/jit/codegen/cuda/index_compute.h | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h
index 5d8703c2e970e..2ecbaa8352a63 100644
--- a/torch/csrc/jit/codegen/cuda/index_compute.h
+++ b/torch/csrc/jit/codegen/cuda/index_compute.h
@@ -17,40 +17,40 @@
  * indices (based on input indices) that match the root dimension.
  *
  * For example with GLOBAL tensor:
- * TV[I, J]
- * TV[Io, Ii{4}, J] = TV.split(I, factor=4)
+ * TV[I, K]
+ * TV[Io, Ii{4}, K] = TV.split(I, factor=4)
  * ALLOC: NONE
  * INDEX: indexCompute {i, j, k} -> {i * 4 + j, k}
- * FLATTENED_INDEX: {i * 4 + j, k} -> {i * 4 + j * J + k}
+ * FLATTENED_INDEX: {i * 4 + j, k} -> {(i * 4 + j) * K + k}
  * PREDICATE: {i * 4 + j, k} -> i * 4 + j < I
  *
  *
  * For example with SHARED tensor:
  *
- * global_TV[I, J]
- * global_TV[Io, Ii{4}, J] = global_TV.split(I, factor=4)
+ * global_TV[I, K]
+ * global_TV[Io, Ii{4}, K] = global_TV.split(I, factor=4)
  * smem_TV.compute_at(global_TV, 1)
  * global_TV.parallelize(1, threadIDx.x)
  *
- * ALLOC: alloc(smem_TV, 4 x J)
+ * ALLOC: alloc(smem_TV, 4 x K)
  * INDEX: indexCompute(smem_TV, {threadIdx.x, k}) -> {threadIdx.x, k}
- * FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {threadIdx.x * 4 + j * J + k}
+ * FLATTENED_INDEX: {threadIdx.x * 4 + j, k} -> {(threadIdx.x * 4 + j) * K + k}
  * PREDICATE: {threadIdx.x * 4 + j, k} -> threadIdx.x * 4 + j < I // Same as if
  * global
  *
  *
  * For example with LOCAL tensor:
- * global_TV[I, J, K]
- * global_TV[Io, Ii{4}, J] = global_TV.split(I, factor=4)
- * reg_TV.compute_at(global_TV, 1)
+ * global_TV[I, K, L]
+ * global_TV[Io, Ii{4}, K, L] = global_TV.split(I, factor=4)
+ * reg_TV.compute_at(global_TV, 2)
  * global_TV.parallelize(1, threadIDx.x)
  * global_TV{i, j, k, l} -> { i * 4 + j, k, l }
- * global_TV{ i * 4 + j, k, l } -> { i * 4 + j * J * K + k * K + l}
+ * global_TV{ i * 4 + j, k, l } -> { (i * 4 + j) * K * L + k * L + l}
  *
- * ALLOC: alloc(reg_TV, J x K)
+ * ALLOC: alloc(reg_TV, K x L)
  * INDEX: {k, l} -> {k, l}
- * FLATTENED_INDEX: {k, l} -> {k * J + l}
- * PREDICATE: i * 4 + j < I && k < J && l < K -> // Same as if global
+ * FLATTENED_INDEX: {k, l} -> {k * L + l}
+ * PREDICATE: i * 4 + j < I && k < K && l < L -> // Same as if global
  *
  * These indices can then be flattened later based on strides.
  */