Add UnaryOpType::Print which can be helpful for debugging #1878

Merged 3 commits on Aug 1, 2022. Changes from all commits are shown below.
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/arith.cpp
@@ -466,6 +466,7 @@ NVFUSER_DEFINE_UNARY_OP(relu, Relu)
NVFUSER_DEFINE_UNARY_OP(round, Round)
NVFUSER_DEFINE_UNARY_OP(silu, Silu)
NVFUSER_DEFINE_UNARY_OP(trunc, Trunc)
NVFUSER_DEFINE_UNARY_OP(print, Print)
#undef NVFUSER_DEFINE_UNARY_OP

Val* randlike(Val* v) {
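The NVFUSER_DEFINE_UNARY_OP macro itself is outside this hunk; as a rough sketch, the new NVFUSER_DEFINE_UNARY_OP(print, Print) line presumably expands to something like the following, assuming the macro forwards to nvFuser's generic unaryOp() overloads the way the neighboring entries do (hypothetical reconstruction, not part of the diff):

// Hypothetical expansion of NVFUSER_DEFINE_UNARY_OP(print, Print),
// assuming the macro forwards to the generic unaryOp() helper.
Val* print(Val* v) {
  return unaryOp(UnaryOpType::Print, v);
}
TensorView* print(TensorView* tv) {
  return unaryOp(UnaryOpType::Print, tv);
}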
3 changes: 3 additions & 0 deletions torch/csrc/jit/codegen/cuda/arith.h
@@ -245,6 +245,9 @@ TORCH_CUDA_CU_API TensorView* isposinf(TensorView*);
// isreal
TORCH_CUDA_CU_API Val* isreal(Val*);
TORCH_CUDA_CU_API TensorView* isreal(TensorView*);
// print
TORCH_CUDA_CU_API Val* print(Val*);
TORCH_CUDA_CU_API TensorView* print(TensorView*);

// Broadcasts inp based on bool vector. Size of broadcast bool vector should be
// the number of dims desired in the broadcasted tensor. This vector should be
100 changes: 100 additions & 0 deletions torch/csrc/jit/codegen/cuda/runtime/helpers.cu
@@ -528,3 +528,103 @@ __device__ inline int64_t readCycleCounter() {
__threadfence();
return clock64();
}

__device__ float print_impl(const char* name, float value) {
printf(
"%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
name,
value,
(int)threadIdx.x,
(int)threadIdx.y,
(int)threadIdx.z,
(int)blockIdx.x,
(int)blockIdx.y,
(int)blockIdx.z);
return value;
}

__device__ double print_impl(const char* name, double value) {
printf(
"%s = %lf @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
name,
value,
(int)threadIdx.x,
(int)threadIdx.y,
(int)threadIdx.z,
(int)blockIdx.x,
(int)blockIdx.y,
(int)blockIdx.z);
return value;
}

__device__ int print_impl(const char* name, int value) {
printf(
"%s = %d @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
name,
value,
(int)threadIdx.x,
(int)threadIdx.y,
(int)threadIdx.z,
(int)blockIdx.x,
(int)blockIdx.y,
(int)blockIdx.z);
return value;
}

__device__ int64_t print_impl(const char* name, int64_t value) {
printf(
"%s = %ld @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
name,
value,
(int)threadIdx.x,
(int)threadIdx.y,
(int)threadIdx.z,
(int)blockIdx.x,
(int)blockIdx.y,
(int)blockIdx.z);
return value;
}

__device__ bool print_impl(const char* name, bool value) {
printf(
"%s = %s @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
name,
value ? "true" : "false",
(int)threadIdx.x,
(int)threadIdx.y,
(int)threadIdx.z,
(int)blockIdx.x,
(int)blockIdx.y,
(int)blockIdx.z);
return value;
}

__device__ __half print_impl(const char* name, __half value) {
printf(
"%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
name,
__half2float(value),
(int)threadIdx.x,
(int)threadIdx.y,
(int)threadIdx.z,
(int)blockIdx.x,
(int)blockIdx.y,
(int)blockIdx.z);
return value;
}

__device__ __bfloat print_impl(const char* name, __bfloat value) {
printf(
"%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n",
name,
__bfloat2float(value),
(int)threadIdx.x,
(int)threadIdx.y,
(int)threadIdx.z,
(int)blockIdx.x,
(int)blockIdx.y,
(int)blockIdx.z);
return value;
}

#define print(...) print_impl(#__VA_ARGS__, (__VA_ARGS__))
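Because the macro stringizes its argument via #__VA_ARGS__, a call like print(T3[0]) in generated kernel code becomes print_impl("T3[0]", (T3[0])), which is where the T3[0] = ... labels in the review comment quoted below come from; the value is returned unchanged, so the call can sit in the middle of a larger expression. A self-contained sketch of the same trick (hypothetical demo names, not part of the PR):

// demo.cu: standalone illustration of the stringizing pass-through macro.
#include <cstdio>

__device__ float print_impl_demo(const char* name, float value) {
  // Label the value with the exact source expression that produced it.
  printf(
      "%s = %f @ threadIdx=(%d,%d,%d)\n",
      name,
      value,
      (int)threadIdx.x,
      (int)threadIdx.y,
      (int)threadIdx.z);
  return value;  // pass-through: usable inside an expression
}

#define print_demo(...) print_impl_demo(#__VA_ARGS__, (__VA_ARGS__))

__global__ void kernel(float* out) {
  // Expands to: print_impl_demo("out[threadIdx.x] * 2.0f", (out[threadIdx.x] * 2.0f))
  out[threadIdx.x] = print_demo(out[threadIdx.x] * 2.0f);
}

int main() {
  float host[2] = {1.0f, 2.0f};
  float* dev = nullptr;
  cudaMalloc(&dev, sizeof(host));
  cudaMemcpy(dev, host, sizeof(host), cudaMemcpyHostToDevice);
  kernel<<<1, 2>>>(dev);
  cudaDeviceSynchronize();  // flush device-side printf before exit
  cudaFree(dev);
  return 0;
}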
38 changes: 38 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -25170,6 +25170,44 @@ TEST_F(NVFuserTest, FusionIdGraphTrivialReduction_CUDA) {
}
}

TEST_F(NVFuserTest, FusionPrint_CUDA) {
auto dtypes = {
at::kFloat,
at::kDouble,
at::kHalf,
at::kBFloat16,
at::kInt,
at::kLong,
at::kBool};
for (auto dtype : dtypes) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

auto tv0 = makeSymbolicTensor(1, aten_to_data_type(dtype));
fusion->addInput(tv0);
auto tv1 = print(tv0);
auto tv2 = sin(tv1);
fusion->addOutput(tv2);

// There is no way to check if anything is printed to the console, but we
// can validate that when a print exists, compilation and computation are not
// broken.
auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
at::Tensor t0 = at::arange(2, options).to(dtype);

FusionExecutorCache executor_cache(std::move(fusion));
auto cg_outputs = executor_cache.runFusionWithInputs({t0});
@zasdfgbnm (Collaborator, Author) commented on Jul 29, 2022:
Running the test will print the following; I hope this doesn't matter.

T3[0] = 0.000000 @ threadIdx=(0,0,0), blockIdx=(0,0,0)
T3[0] = 1.000000 @ threadIdx=(1,0,0), blockIdx=(0,0,0)
T3[0] = 0.000000 @ threadIdx=(0,0,0), blockIdx=(0,0,0)
T3[0] = 1.000000 @ threadIdx=(1,0,0), blockIdx=(0,0,0)
T4[0] = 0.000000 @ threadIdx=(0,0,0), blockIdx=(0,0,0)
T4[0] = 1.000000 @ threadIdx=(1,0,0), blockIdx=(0,0,0)
T4[0] = 0.000000 @ threadIdx=(0,0,0), blockIdx=(0,0,0)
T4[0] = 1.000000 @ threadIdx=(1,0,0), blockIdx=(0,0,0)
T4[0] = 0 @ threadIdx=(0,0,0), blockIdx=(0,0,0)
T4[0] = 1 @ threadIdx=(1,0,0), blockIdx=(0,0,0)
T4[0] = 0 @ threadIdx=(0,0,0), blockIdx=(0,0,0)
T4[0] = 1 @ threadIdx=(1,0,0), blockIdx=(0,0,0)
T4[0] = false @ threadIdx=(0,0,0), blockIdx=(0,0,0)
T4[0] = true @ threadIdx=(1,0,0), blockIdx=(0,0,0)


testValidate(
executor_cache.fusion(),
cg_outputs,
{t0},
{t0.sin()},
__LINE__,
__FILE__);
}
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)
3 changes: 3 additions & 0 deletions torch/csrc/jit/codegen/cuda/type.cpp
@@ -376,6 +376,7 @@ bool needFloatSuffix(UnaryOpType t) {
case UnaryOpType::IsNegInf:
case UnaryOpType::IsPosInf:
case UnaryOpType::IsReal:
case UnaryOpType::Print:
return false;
default:
return true;
@@ -432,6 +433,8 @@ static const char* unary_op_type2string(UnaryOpType t) {
return "neg";
case UnaryOpType::Not:
return "not";
case UnaryOpType::Print:
return "print";
case UnaryOpType::RandLike:
return "randLike";
case UnaryOpType::Reciprocal:
3 changes: 3 additions & 0 deletions torch/csrc/jit/codegen/cuda/type.h
@@ -179,6 +179,9 @@ enum class UnaryOpType {
Tanh,
Trunc,

// Tools to help with debugging
Print,

// Might be a bitwise operator or boolean operator.
Not,
