Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use an explicit value mapping instead of IRReplaceSizes() #107

Merged
merged 13 commits into from
Jun 22, 2020
171 changes: 118 additions & 53 deletions test/cpp/jit/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,66 @@ void testGPU_FusionExprEvalComplex() {
checkIntValue(&eval_context, tv6->axis(2)->rawExtent(), 127);
}

// Evaluate expressions post lowering
//
// Builds a small two-input fusion, schedules it, runs GPULower over it, and
// then verifies that an EvaluationContext with the inputs' root extents bound
// still evaluates (a) the axis extents of tv2 and tv3 and (b) two expressions
// that were created *before* lowering (bid_x, tid_x) to the expected values.
void testGPU_FusionExprEvalPostLower() {
Fusion fusion;
FusionGuard fg(&fusion);

// Create a non-trivial IR
// Two inputs; the bindings further down address root-domain entries 0 and 1
// of each of them.
TensorView* tv0 = makeDummyTensor(2);
TensorView* tv1 = makeDummyTensor(2);

fusion.addInput(tv0);
fusion.addInput(tv1);

// tv2 = tv1 + 2.0 and tv3 = tv0 + tv2; tv3 is the only registered output.
TensorView* tv2 = add(tv1, new Float(2.0));
TensorView* tv3 = add(tv0, tv2);

fusion.addOutput(tv3);

// Split axis 0 of tv3 with argument 4. The checks at the end expect tv3 to
// have three axes afterwards, the first two with extents 2 and 4.
tv3->split(0, 4);

// NOTE(review): tv2 is also expected to have three axes below, presumably as
// a consequence of these computeAt calls -- confirm against computeAt.
tv0->computeAt(tv3, 1);
tv1->computeAt(tv3, 1);

// Parallelization: tv3 axis 0 -> BIDx; axis 1 of tv2 and tv3 -> Unroll;
// last axis of tv2 and tv3 -> TIDx.
tv3->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(1)->parallelize(ParallelType::Unroll);
tv3->axis(1)->parallelize(ParallelType::Unroll);
tv2->axis(-1)->parallelize(ParallelType::TIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);

// Expressions of the form "extent + 0", built before lowering on top of tv3's
// first and last axis extents. They are only evaluated after lowering, at
// the end of this test.
auto* bid_x = add(tv3->axis(0)->rawExtent(), new Int(0));
auto* tid_x = add(tv3->axis(-1)->rawExtent(), new Int(0));

// Lower
// The text written to `kernel` is not inspected anywhere in this test.
GPULower gpulw(&fusion);
std::stringstream kernel;
gpulw.printKernel(kernel);

// 1. Create an evaluation context
EvaluationContext eval_context(&fusion);

// 2. Bind values
// Both inputs get the same root extents: 6 for entry 0 and 128 for entry 1.
eval_context.bind(tv0->getRootDomain()[0]->extent(), 6);
eval_context.bind(tv0->getRootDomain()[1]->extent(), 128);
eval_context.bind(tv1->getRootDomain()[0]->extent(), 6);
eval_context.bind(tv1->getRootDomain()[1]->extent(), 128);

// 3. Evaluate and check result values
// Expected extents for both tv2 and tv3 are 2, 4, 128: the bound extent 6
// split by 4 is expected to yield an outer extent of 2 (i.e. rounded up) and
// an inner extent of 4, with the 128 axis unchanged.
TORCH_CHECK(tv2->domain()->nDims() == 3);
checkIntValue(&eval_context, tv2->axis(0)->rawExtent(), 2);
checkIntValue(&eval_context, tv2->axis(1)->rawExtent(), 4);
checkIntValue(&eval_context, tv2->axis(2)->rawExtent(), 128);

TORCH_CHECK(tv3->domain()->nDims() == 3);
checkIntValue(&eval_context, tv3->axis(0)->rawExtent(), 2);
checkIntValue(&eval_context, tv3->axis(1)->rawExtent(), 4);
checkIntValue(&eval_context, tv3->axis(2)->rawExtent(), 128);

// The pre-lowering expressions must still evaluate to the extents of tv3's
// first and last axes.
checkIntValue(&eval_context, bid_x, 2);
checkIntValue(&eval_context, tid_x, 128);
}

void testGPU_FusionSimpleArith() {
std::stringstream ss1, ss2;

Expand Down Expand Up @@ -729,51 +789,56 @@ void testGPU_FusionParser() {
prog.device_ = 0;
fuser::cuda::parseJitIR(g, &prog);

std::stringstream ref;
ref << "__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3){\n"
<< " float T2[4];\n"
<< " if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n"
<< " for(size_t i64 = 0; i64 < 4; ++i64 ) {\n"
<< " T2[ i64 ]\n"
<< " = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i64 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]\n"
<< " * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i64 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];\n"
<< " }\n"
<< " } else { \n"
<< " for(size_t i64 = 0; i64 < 4; ++i64 ) {\n"
<< " if ( ( ( ( ( ( blockIdx.x * 4 ) + i64 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n"
<< " T2[ i64 ]\n"
<< " = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i64 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]\n"
<< " * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i64 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];\n"
<< " }\n"
<< " }\n"
<< " }\n"
<< " if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n"
<< " for(size_t i65 = 0; i65 < 4; ++i65 ) {\n"
<< " T3[ ( ( ( ( ( blockIdx.x * 4 ) + i65 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]\n"
<< " = T2[ i65 ]\n"
<< " * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i65 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];\n"
<< " }\n"
<< " } else { \n"
<< " for(size_t i65 = 0; i65 < 4; ++i65 ) {\n"
<< " if ( ( ( ( ( ( blockIdx.x * 4 ) + i65 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n"
<< " T3[ ( ( ( ( ( blockIdx.x * 4 ) + i65 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]\n"
<< " = T2[ i65 ]\n"
<< " * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i65 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];\n"
<< " }\n"
<< " }\n"
<< " }\n"
<< "}\n";
// CONSIDER:
// 1. this can be moved to a dedicated "golden" file
// 2. use a fuzzy compare (for example, ignore insignificant whitespace)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could work. It looks a bit too fuzzy to me, though ("foo bar" would compare equal to "foobar"), so it would be nicer if it collapsed whitespace between identifiers to a single space instead of removing it completely. In this case it may be good enough.

const std::string expected_kernel = R"(
__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3){
float T2[4];
if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
for(size_t i40 = 0; i40 < 4; ++i40 ) {
T2[ i40 ]
= T0[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]
* T1[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];
}
} else {
for(size_t i40 = 0; i40 < 4; ++i40 ) {
if ( ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
T2[ i40 ]
= T0[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]
* T1[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];
}
}
}
if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
for(size_t i41 = 0; i41 < 4; ++i41 ) {
T3[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]
= T2[ i41 ]
* T0[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
}
} else {
for(size_t i41 = 0; i41 < 4; ++i41 ) {
if ( ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
T3[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]
= T2[ i41 ]
* T0[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
}
}
}
}
)";

GPULower gpulw(&fusion);
std::stringstream cdg;
gpulw.printKernel(cdg);
if (ref.str().size() != cdg.str().size() ||
ref.str().compare(cdg.str()) != 0) {
std::stringstream actual_kernel;
actual_kernel << "\n";
gpulw.printKernel(actual_kernel);
if (expected_kernel.size() != actual_kernel.str().size() ||
expected_kernel.compare(actual_kernel.str()) != 0) {
std::cerr
<< " Codegen mismatch, codegen possibly changed, or is incorrect. "
<< " \n ========= REF ========= \n"
<< ref.str() << "\n========= RESULT ========== \n"
<< cdg.str() << "\n=================" << std::endl;
<< " \n ========= EXPECTED ========= \n"
<< expected_kernel << "\n========= ACTUAL ========== \n"
<< actual_kernel.str() << "\n=================" << std::endl;
TORCH_CHECK(false);
}
}
Expand Down Expand Up @@ -1200,10 +1265,10 @@ void testGPU_FusionAdvancedComputeAt() {
&prog, {t0}, {kernel_tv5, kernel_tv6});

GPULower gpulw(&fusion);
std::stringstream cdg;
gpulw.printKernel(cdg);
std::stringstream actual_kernel;
gpulw.printKernel(actual_kernel);

TORCH_CHECK(at::allclose(kernel_tv5, t5), cdg.str());
TORCH_CHECK(at::allclose(kernel_tv5, t5), actual_kernel.str());
TORCH_CHECK(at::allclose(kernel_tv6, t6));
}

Expand Down Expand Up @@ -1267,10 +1332,10 @@ void testGPU_FusionAdvancedComputeAt() {
torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {kernel_tv3});

GPULower gpulw(&fusion);
std::stringstream cdg;
gpulw.printKernel(cdg);
std::stringstream actual_kernel;
gpulw.printKernel(actual_kernel);

TORCH_CHECK(at::allclose(kernel_tv3, t3), cdg.str());
TORCH_CHECK(at::allclose(kernel_tv3, t3), actual_kernel.str());
}

// Case 4
Expand Down Expand Up @@ -1347,10 +1412,10 @@ void testGPU_FusionAdvancedComputeAt() {
&prog, {t0, t1, t2, t3}, {kernel_tv6});

GPULower gpulw(&fusion);
std::stringstream cdg;
gpulw.printKernel(cdg);
std::stringstream actual_kernel;
gpulw.printKernel(actual_kernel);

TORCH_CHECK(at::allclose(kernel_tv6, t6), cdg.str());
TORCH_CHECK(at::allclose(kernel_tv6, t6), actual_kernel.str());
}
}

Expand Down Expand Up @@ -1447,10 +1512,10 @@ void testGPU_FusionScalarInputs() {
{kernel_tv4});

GPULower gpulw(&fusion);
std::stringstream cdg;
gpulw.printKernel(cdg);
std::stringstream actual_kernel;
gpulw.printKernel(actual_kernel);

TORCH_CHECK(at::allclose(kernel_tv4, t4), cdg.str());
TORCH_CHECK(at::allclose(kernel_tv4, t4), actual_kernel.str());
}

void testGPU_FusionLoopUnroll() {
Expand Down
1 change: 1 addition & 0 deletions test/cpp/jit/tests.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ namespace jit {
_(GPU_FusionExprEvalBindings) \
_(GPU_FusionExprEvalBasic) \
_(GPU_FusionExprEvalComplex) \
_(GPU_FusionExprEvalPostLower) \
_(GPU_FusionSimpleTypePromote) \
_(GPU_FusionMutator) \
_(GPU_FusionRegister) \
Expand Down
44 changes: 18 additions & 26 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,27 +321,19 @@ template void Val::dispatch(OptInDispatch*, Val*);
template void Expr::dispatch(OptInDispatch, Expr*);
template void Expr::dispatch(OptInDispatch*, Expr*);

template void Statement::constDispatch(
OptOutConstDispatch,
const Statement* const);
template void Statement::constDispatch(
OptOutConstDispatch*,
const Statement* const);
template void Val::constDispatch(OptOutConstDispatch, const Val* const);
template void Val::constDispatch(OptOutConstDispatch*, const Val* const);
template void Expr::constDispatch(OptOutConstDispatch, const Expr* const);
template void Expr::constDispatch(OptOutConstDispatch*, const Expr* const);
template void Statement::constDispatch(OptOutConstDispatch, const Statement*);
template void Statement::constDispatch(OptOutConstDispatch*, const Statement*);
template void Val::constDispatch(OptOutConstDispatch, const Val*);
template void Val::constDispatch(OptOutConstDispatch*, const Val*);
template void Expr::constDispatch(OptOutConstDispatch, const Expr*);
template void Expr::constDispatch(OptOutConstDispatch*, const Expr*);

template void Statement::constDispatch(
OptInConstDispatch,
const Statement* const);
template void Statement::constDispatch(
OptInConstDispatch*,
const Statement* const);
template void Val::constDispatch(OptInConstDispatch, const Val* const);
template void Val::constDispatch(OptInConstDispatch*, const Val* const);
template void Expr::constDispatch(OptInConstDispatch, const Expr* const);
template void Expr::constDispatch(OptInConstDispatch*, const Expr* const);
template void Statement::constDispatch(OptInConstDispatch, const Statement*);
template void Statement::constDispatch(OptInConstDispatch*, const Statement*);
template void Val::constDispatch(OptInConstDispatch, const Val*);
template void Val::constDispatch(OptInConstDispatch*, const Val*);
template void Expr::constDispatch(OptInConstDispatch, const Expr*);
template void Expr::constDispatch(OptInConstDispatch*, const Expr*);

template Statement* Statement::mutatorDispatch(OptOutMutator, Statement*);
template Statement* Statement::mutatorDispatch(OptOutMutator*, Statement*);
Expand Down Expand Up @@ -377,23 +369,23 @@ void OptInDispatch::handle(Val* v) {
Val::dispatch(this, v);
}

void OptOutConstDispatch::handle(const Statement* const s) {
void OptOutConstDispatch::handle(const Statement* s) {
Statement::constDispatch(this, s);
}
void OptOutConstDispatch::handle(const Expr* const e) {
void OptOutConstDispatch::handle(const Expr* e) {
Expr::constDispatch(this, e);
}
void OptOutConstDispatch::handle(const Val* const v) {
void OptOutConstDispatch::handle(const Val* v) {
Val::constDispatch(this, v);
}

void OptInConstDispatch::handle(const Statement* const s) {
void OptInConstDispatch::handle(const Statement* s) {
Statement::constDispatch(this, s);
}
void OptInConstDispatch::handle(const Expr* const e) {
void OptInConstDispatch::handle(const Expr* e) {
Expr::constDispatch(this, e);
}
void OptInConstDispatch::handle(const Val* const v) {
void OptInConstDispatch::handle(const Val* v) {
Val::constDispatch(this, v);
}

Expand Down
Loading