From 87c5e5baa45000be34b09524dfe249941c58d94d Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Fri, 29 Jul 2022 16:53:10 -0700
Subject: [PATCH 1/3] Add UnaryOpType::Print which can be helpful for
 debugging silent wrong results

---
 torch/csrc/jit/codegen/cuda/arith.cpp         |  1 +
 torch/csrc/jit/codegen/cuda/arith.h           |  3 ++
 .../csrc/jit/codegen/cuda/runtime/helpers.cu  | 37 ++++++++++++++++++
 torch/csrc/jit/codegen/cuda/test/test_gpu.cpp | 38 +++++++++++++++++++
 torch/csrc/jit/codegen/cuda/type.cpp          |  3 ++
 torch/csrc/jit/codegen/cuda/type.h            |  3 ++
 6 files changed, 85 insertions(+)

diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp
index ca1782f08d66a..790894d01b817 100644
--- a/torch/csrc/jit/codegen/cuda/arith.cpp
+++ b/torch/csrc/jit/codegen/cuda/arith.cpp
@@ -466,6 +466,7 @@ NVFUSER_DEFINE_UNARY_OP(relu, Relu)
 NVFUSER_DEFINE_UNARY_OP(round, Round)
 NVFUSER_DEFINE_UNARY_OP(silu, Silu)
 NVFUSER_DEFINE_UNARY_OP(trunc, Trunc)
+NVFUSER_DEFINE_UNARY_OP(print, Print)
 #undef NVFUSER_DEFINE_UNARY_OP
 
 Val* randlike(Val* v) {
diff --git a/torch/csrc/jit/codegen/cuda/arith.h b/torch/csrc/jit/codegen/cuda/arith.h
index c13b610a3bc39..c6d0011c03049 100644
--- a/torch/csrc/jit/codegen/cuda/arith.h
+++ b/torch/csrc/jit/codegen/cuda/arith.h
@@ -245,6 +245,9 @@ TORCH_CUDA_CU_API TensorView* isposinf(TensorView*);
 // isreal
 TORCH_CUDA_CU_API Val* isreal(Val*);
 TORCH_CUDA_CU_API TensorView* isreal(TensorView*);
+// print
+TORCH_CUDA_CU_API Val* print(Val*);
+TORCH_CUDA_CU_API TensorView* print(TensorView*);
 
 // Broadcasts inp based on bool vector. Size of broadcast bool vector should be
 // the number of dims desired in the broadcasted tensor. This vector should be
diff --git a/torch/csrc/jit/codegen/cuda/runtime/helpers.cu b/torch/csrc/jit/codegen/cuda/runtime/helpers.cu
index b23af5b6d93dd..5dc6de8c3ae3d 100644
--- a/torch/csrc/jit/codegen/cuda/runtime/helpers.cu
+++ b/torch/csrc/jit/codegen/cuda/runtime/helpers.cu
@@ -528,3 +528,40 @@ __device__ inline int64_t readCycleCounter() {
   __threadfence();
   return clock64();
 }
+
+float print_impl(const char *name, float value) {
+  printf("%s = %f\n", name, value);
+  return value;
+}
+
+double print_impl(const char *name, double value) {
+  printf("%s = %lf\n", name, value);
+  return value;
+}
+
+int print_impl(const char *name, int value) {
+  printf("%s = %d\n", name, value);
+  return value;
+}
+
+int64_t print_impl(const char *name, int64_t value) {
+  printf("%s = %ld\n", name, value);
+  return value;
+}
+
+bool print_impl(const char *name, bool value) {
+  printf("%s = %s\n", name, value ? "true" : "false");
+  return value;
+}
+
+__half print_impl(const char *name, __half value) {
+  printf("%s = %f\n", name, __half2float(value));
+  return value;
+}
+
+__bfloat print_impl(const char *name, __bfloat value) {
+  printf("%s = %f\n", name, __bfloat2float(value));
+  return value;
+}
+
+#define print(...) print_impl(#__VA_ARGS__, (__VA_ARGS__))
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
index 0d05adf4ba55b..1b235c1263628 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -25170,6 +25170,44 @@ TEST_F(NVFuserTest, FusionIdGraphTrivialReduction_CUDA) {
   }
 }
 
+TEST_F(NVFuserTest, FusionPrint_CUDA) {
+  auto dtypes = {
+      at::kFloat,
+      at::kDouble,
+      at::kHalf,
+      at::kBFloat16,
+      at::kInt,
+      at::kLong,
+      at::kBool};
+  for (auto dtype : dtypes) {
+    auto fusion = std::make_unique<Fusion>();
+    FusionGuard fg(fusion.get());
+
+    auto tv0 = makeSymbolicTensor(1, aten_to_data_type(dtype));
+    fusion->addInput(tv0);
+    auto tv1 = print(tv0);
+    auto tv2 = sin(tv1);
+    fusion->addOutput(tv2);
+
+    // There is no way to check what is printed to the console, but we can
+    // at least validate that, when print is present, compilation and
+    // computation are not broken.
+    auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
+    at::Tensor t0 = at::arange(2, options).to(dtype);
+
+    FusionExecutorCache executor_cache(std::move(fusion));
+    auto cg_outputs = executor_cache.runFusionWithInputs({t0});
+
+    testValidate(
+        executor_cache.fusion(),
+        cg_outputs,
+        {t0},
+        {t0.sin()},
+        __LINE__,
+        __FILE__);
+  }
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)
diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp
index a3f0170f704c5..70619f28d98e9 100644
--- a/torch/csrc/jit/codegen/cuda/type.cpp
+++ b/torch/csrc/jit/codegen/cuda/type.cpp
@@ -376,6 +376,7 @@ bool needFloatSuffix(UnaryOpType t) {
     case UnaryOpType::IsNegInf:
     case UnaryOpType::IsPosInf:
     case UnaryOpType::IsReal:
+    case UnaryOpType::Print:
       return false;
     default:
       return true;
@@ -432,6 +433,8 @@ static const char* unary_op_type2string(UnaryOpType t) {
     case UnaryOpType::Neg:
       return "neg";
     case UnaryOpType::Not:
       return "not";
+    case UnaryOpType::Print:
+      return "print";
     case UnaryOpType::RandLike:
       return "randLike";
     case UnaryOpType::Reciprocal:
diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h
index 9d8b5cb99a6a8..f55ae836e7d24 100644
--- a/torch/csrc/jit/codegen/cuda/type.h
+++ b/torch/csrc/jit/codegen/cuda/type.h
@@ -179,6 +179,9 @@ enum class UnaryOpType {
   Tanh,
   Trunc,
 
+  // Tools to help debugging
+  Print,
+
   // Might be a bitwise operator or boolean operator.
   Not,
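The print op introduced above is an identity: it forwards its input unchanged while echoing every value that flows through it, which makes it easy to bisect a fusion that silently produces wrong numbers. A rough sketch of the intended workflow (a hypothetical fusion; makeSymbolicTensor, exp, and sum are existing nvfuser builders, and the choice of intermediate to wrap is arbitrary):

    // Wrap the suspect intermediate in print() to dump each element from the
    // device while leaving the computation itself unchanged.
    Fusion fusion;
    FusionGuard fg(&fusion);

    auto tv0 = makeSymbolicTensor(2); // input under suspicion
    fusion.addInput(tv0);
    auto tv1 = exp(tv0);
    auto tv2 = print(tv1); // identity, plus one printf per element
    auto tv3 = sum(tv2, {1});
    fusion.addOutput(tv3);
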
From 26a19c462154ce24f5ab9413d45ab5453cc0d689 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Fri, 29 Jul 2022 17:15:29 -0700
Subject: [PATCH 2/3] print thread

---
 .../csrc/jit/codegen/cuda/runtime/helpers.cu  | 46 +++++++++++++------
 1 file changed, 32 insertions(+), 14 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/runtime/helpers.cu b/torch/csrc/jit/codegen/cuda/runtime/helpers.cu
index 5dc6de8c3ae3d..55602b153ac97 100644
--- a/torch/csrc/jit/codegen/cuda/runtime/helpers.cu
+++ b/torch/csrc/jit/codegen/cuda/runtime/helpers.cu
@@ -529,38 +529,56 @@ __device__ inline int64_t readCycleCounter() {
   return clock64();
 }
 
-float print_impl(const char *name, float value) {
-  printf("%s = %f\n", name, value);
+__device__ void printThread() {
+  printf("threadIdx.x = %d, threadIdx.y = %d, threadIdx.z = %d, "
+         "blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d\n",
+         (int)threadIdx.x,
+         (int)threadIdx.y,
+         (int)threadIdx.z,
+         (int)blockIdx.x,
+         (int)blockIdx.y,
+         (int)blockIdx.z);
+}
+
+__device__ float print_impl(const char *name, float value) {
+  printf("%s = %f @ ", name, value);
+  printThread();
   return value;
 }
 
-double print_impl(const char *name, double value) {
-  printf("%s = %lf\n", name, value);
+__device__ double print_impl(const char *name, double value) {
+  printf("%s = %lf @ ", name, value);
+  printThread();
   return value;
 }
 
-int print_impl(const char *name, int value) {
-  printf("%s = %d\n", name, value);
+__device__ int print_impl(const char *name, int value) {
+  printf("%s = %d @ ", name, value);
+  printThread();
   return value;
 }
 
-int64_t print_impl(const char *name, int64_t value) {
-  printf("%s = %ld\n", name, value);
+__device__ int64_t print_impl(const char *name, int64_t value) {
+  printf("%s = %ld @ ", name, value);
+  printThread();
   return value;
 }
 
-bool print_impl(const char *name, bool value) {
-  printf("%s = %s\n", name, value ? "true" : "false");
+__device__ bool print_impl(const char *name, bool value) {
+  printf("%s = %s @ ", name, value ? "true" : "false");
+  printThread();
   return value;
 }
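One caveat with the version above: CUDA device-side printf only keeps the output of a single call together as one record, so the two calls in print_impl and printThread can interleave with output from other threads when many threads print concurrently. That is presumably why the follow-up below folds everything into a single printf per value. A minimal standalone illustration of the difference (hypothetical kernels, compiled with nvcc):

    #include <cstdio>

    // Two printf calls per record: lines from concurrent threads may be
    // split apart in the output.
    __global__ void twoCalls() {
      printf("value = %d @ ", (int)threadIdx.x);
      printf("thread %d\n", (int)threadIdx.x);
    }

    // One printf call per record: each line comes out intact.
    __global__ void oneCall() {
      printf("value = %d @ thread %d\n", (int)threadIdx.x, (int)threadIdx.x);
    }

    int main() {
      twoCalls<<<1, 32>>>();
      oneCall<<<1, 32>>>();
      cudaDeviceSynchronize(); // flush the device printf buffer
      return 0;
    }
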
"true" : "false"); + printThread(); return value; } -__half print_impl(const char *name, __half value) { - printf("%s = %f\n", name, __half2float(value)); +__device__ __half print_impl(const char *name, __half value) { + printf("%s = %f @ ", name, __half2float(value)); + printThread(); return value; } -__bfloat print_impl(const char *name, __bfloat value) { - printf("%s = %f\n", name, __bfloat2float(value)); +__device__ __bfloat print_impl(const char *name, __bfloat value) { + printf("%s = %f @ ", name, __bfloat2float(value)); + printThread(); return value; } From 35633f9273280352bef4ce568459306a285b7204 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Fri, 29 Jul 2022 17:22:53 -0700 Subject: [PATCH 3/3] print thread --- .../csrc/jit/codegen/cuda/runtime/helpers.cu | 109 +++++++++++++----- 1 file changed, 77 insertions(+), 32 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/runtime/helpers.cu b/torch/csrc/jit/codegen/cuda/runtime/helpers.cu index 55602b153ac97..d198844933a32 100644 --- a/torch/csrc/jit/codegen/cuda/runtime/helpers.cu +++ b/torch/csrc/jit/codegen/cuda/runtime/helpers.cu @@ -529,56 +529,101 @@ __device__ inline int64_t readCycleCounter() { return clock64(); } -__device__ void printThread() { - printf("threadIdx.x = %d, threadIdx.y = %d, threadIdx.z = %d, " - "blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d\n", - (int)threadIdx.x, - (int)threadIdx.y, - (int)threadIdx.z, - (int)blockIdx.x, - (int)blockIdx.y, - (int)blockIdx.z); -} - -__device__ float print_impl(const char *name, float value) { - printf("%s = %f @ ", name, value); - printThread(); +__device__ float print_impl(const char* name, float value) { + printf( + "%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", + name, + value, + (int)threadIdx.x, + (int)threadIdx.y, + (int)threadIdx.z, + (int)blockIdx.x, + (int)blockIdx.y, + (int)blockIdx.z); return value; } -__device__ double print_impl(const char *name, double value) { - printf("%s = %lf @ ", name, value); - printThread(); +__device__ double print_impl(const char* name, double value) { + printf( + "%s = %lf @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", + name, + value, + (int)threadIdx.x, + (int)threadIdx.y, + (int)threadIdx.z, + (int)blockIdx.x, + (int)blockIdx.y, + (int)blockIdx.z); return value; } -__device__ int print_impl(const char *name, int value) { - printf("%s = %d @ ", name, value); - printThread(); +__device__ int print_impl(const char* name, int value) { + printf( + "%s = %d @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", + name, + value, + (int)threadIdx.x, + (int)threadIdx.y, + (int)threadIdx.z, + (int)blockIdx.x, + (int)blockIdx.y, + (int)blockIdx.z); return value; } -__device__ int64_t print_impl(const char *name, int64_t value) { - printf("%s = %ld @ ", name, value); - printThread(); +__device__ int64_t print_impl(const char* name, int64_t value) { + printf( + "%s = %ld @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", + name, + value, + (int)threadIdx.x, + (int)threadIdx.y, + (int)threadIdx.z, + (int)blockIdx.x, + (int)blockIdx.y, + (int)blockIdx.z); return value; } -__device__ bool print_impl(const char *name, bool value) { - printf("%s = %s @ ", name, value ? "true" : "false"); - printThread(); +__device__ bool print_impl(const char* name, bool value) { + printf( + "%s = %s @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", + name, + value ? 
"true" : "false", + (int)threadIdx.x, + (int)threadIdx.y, + (int)threadIdx.z, + (int)blockIdx.x, + (int)blockIdx.y, + (int)blockIdx.z); return value; } -__device__ __half print_impl(const char *name, __half value) { - printf("%s = %f @ ", name, __half2float(value)); - printThread(); +__device__ __half print_impl(const char* name, __half value) { + printf( + "%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", + name, + __half2float(value), + (int)threadIdx.x, + (int)threadIdx.y, + (int)threadIdx.z, + (int)blockIdx.x, + (int)blockIdx.y, + (int)blockIdx.z); return value; } -__device__ __bfloat print_impl(const char *name, __bfloat value) { - printf("%s = %f @ ", name, __bfloat2float(value)); - printThread(); +__device__ __bfloat print_impl(const char* name, __bfloat value) { + printf( + "%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", + name, + __bfloat2float(value), + (int)threadIdx.x, + (int)threadIdx.y, + (int)threadIdx.z, + (int)blockIdx.x, + (int)blockIdx.y, + (int)blockIdx.z); return value; }