From 56c8f12bcc4c575883baf706567e513d4c5a134b Mon Sep 17 00:00:00 2001 From: Lemo Date: Wed, 17 Jun 2020 21:45:59 -0700 Subject: [PATCH 01/12] Minor cleanup --- torch/csrc/jit/codegen/cuda/dispatch.cpp | 36 +++++----- torch/csrc/jit/codegen/cuda/dispatch.h | 88 ++++++++++++------------ 2 files changed, 62 insertions(+), 62 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index a3a55220e534ea..24f81030cfb2d2 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -323,25 +323,25 @@ template void Expr::dispatch(OptInDispatch*, Expr*); template void Statement::constDispatch( OptOutConstDispatch, - const Statement* const); + const Statement*); template void Statement::constDispatch( OptOutConstDispatch*, - const Statement* const); -template void Val::constDispatch(OptOutConstDispatch, const Val* const); -template void Val::constDispatch(OptOutConstDispatch*, const Val* const); -template void Expr::constDispatch(OptOutConstDispatch, const Expr* const); -template void Expr::constDispatch(OptOutConstDispatch*, const Expr* const); + const Statement*); +template void Val::constDispatch(OptOutConstDispatch, const Val*); +template void Val::constDispatch(OptOutConstDispatch*, const Val*); +template void Expr::constDispatch(OptOutConstDispatch, const Expr*); +template void Expr::constDispatch(OptOutConstDispatch*, const Expr*); template void Statement::constDispatch( OptInConstDispatch, - const Statement* const); + const Statement*); template void Statement::constDispatch( OptInConstDispatch*, - const Statement* const); -template void Val::constDispatch(OptInConstDispatch, const Val* const); -template void Val::constDispatch(OptInConstDispatch*, const Val* const); -template void Expr::constDispatch(OptInConstDispatch, const Expr* const); -template void Expr::constDispatch(OptInConstDispatch*, const Expr* const); + const Statement*); +template void Val::constDispatch(OptInConstDispatch, const Val*); +template void Val::constDispatch(OptInConstDispatch*, const Val*); +template void Expr::constDispatch(OptInConstDispatch, const Expr*); +template void Expr::constDispatch(OptInConstDispatch*, const Expr*); template Statement* Statement::mutatorDispatch(OptOutMutator, Statement*); template Statement* Statement::mutatorDispatch(OptOutMutator*, Statement*); @@ -377,23 +377,23 @@ void OptInDispatch::handle(Val* v) { Val::dispatch(this, v); } -void OptOutConstDispatch::handle(const Statement* const s) { +void OptOutConstDispatch::handle(const Statement* s) { Statement::constDispatch(this, s); } -void OptOutConstDispatch::handle(const Expr* const e) { +void OptOutConstDispatch::handle(const Expr* e) { Expr::constDispatch(this, e); } -void OptOutConstDispatch::handle(const Val* const v) { +void OptOutConstDispatch::handle(const Val* v) { Val::constDispatch(this, v); } -void OptInConstDispatch::handle(const Statement* const s) { +void OptInConstDispatch::handle(const Statement* s) { Statement::constDispatch(this, s); } -void OptInConstDispatch::handle(const Expr* const e) { +void OptInConstDispatch::handle(const Expr* e) { Expr::constDispatch(this, e); } -void OptInConstDispatch::handle(const Val* const v) { +void OptInConstDispatch::handle(const Val* v) { Val::constDispatch(this, v); } diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index 5a3aa89d82755a..f2ad5f7128dd8a 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ 
b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -93,32 +93,32 @@ struct TORCH_CUDA_API OptOutConstDispatch { OptOutConstDispatch& operator=(OptOutConstDispatch&& other) = default; // Hierarchal dispatch functions for handle - virtual void handle(const Statement* const); - virtual void handle(const Expr* const); - virtual void handle(const Val* const); + virtual void handle(const Statement*); + virtual void handle(const Expr*); + virtual void handle(const Val*); // Vals - virtual void handle(const IterDomain* const) {} - virtual void handle(const TensorDomain* const) {} - virtual void handle(const TensorView* const) {} - virtual void handle(const TensorIndex* const) {} - virtual void handle(const Bool* const) {} - virtual void handle(const Float* const) {} - virtual void handle(const Half* const) {} - virtual void handle(const Int* const) {} - virtual void handle(const NamedScalar* const) {} + virtual void handle(const IterDomain*) {} + virtual void handle(const TensorDomain*) {} + virtual void handle(const TensorView*) {} + virtual void handle(const TensorIndex*) {} + virtual void handle(const Bool*) {} + virtual void handle(const Float*) {} + virtual void handle(const Half*) {} + virtual void handle(const Int*) {} + virtual void handle(const NamedScalar*) {} // Exprs - virtual void handle(const Split* const) {} - virtual void handle(const Merge* const) {} - virtual void handle(const UnaryOp* const) {} - virtual void handle(const BinaryOp* const) {} - virtual void handle(const TernaryOp* const) {} - virtual void handle(const ReductionOp* const) {} - virtual void handle(const BroadcastOp* const) {} - virtual void handle(const ForLoop* const) {} - virtual void handle(const IfThenElse* const) {} - virtual void handle(const Allocate* const) {} + virtual void handle(const Split*) {} + virtual void handle(const Merge*) {} + virtual void handle(const UnaryOp*) {} + virtual void handle(const BinaryOp*) {} + virtual void handle(const TernaryOp*) {} + virtual void handle(const ReductionOp*) {} + virtual void handle(const BroadcastOp*) {} + virtual void handle(const ForLoop*) {} + virtual void handle(const IfThenElse*) {} + virtual void handle(const Allocate*) {} }; struct TORCH_CUDA_API OptOutDispatch { @@ -171,68 +171,68 @@ struct TORCH_CUDA_API OptInConstDispatch { OptInConstDispatch& operator=(OptInConstDispatch&& other) = default; // Hierarchal dispatch functions for handle - virtual void handle(const Statement* const); - virtual void handle(const Expr* const); - virtual void handle(const Val* const); + virtual void handle(const Statement*); + virtual void handle(const Expr*); + virtual void handle(const Val*); // Vals - virtual void handle(const IterDomain* const) { + virtual void handle(const IterDomain*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for IterDomain."); } - virtual void handle(const TensorDomain* const) { + virtual void handle(const TensorDomain*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorDomain."); } - virtual void handle(const TensorView* const) { + virtual void handle(const TensorView*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorView."); } - virtual void handle(const TensorIndex* const) { + virtual void handle(const TensorIndex*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TensorIndex."); } - virtual void handle(const Bool* const) { + virtual void handle(const Bool*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Bool."); } - virtual void handle(const Float* const) { + virtual void handle(const Float*) { 
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Float."); } - virtual void handle(const Half* const) { + virtual void handle(const Half*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Half."); } - virtual void handle(const Int* const) { + virtual void handle(const Int*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Int."); } - virtual void handle(const NamedScalar* const) { + virtual void handle(const NamedScalar*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for NamedScalar."); } // Exprs - virtual void handle(const Split* const) { + virtual void handle(const Split*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Split."); } - virtual void handle(const Merge* const) { + virtual void handle(const Merge*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Merge."); } - virtual void handle(const UnaryOp* const) { + virtual void handle(const UnaryOp*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for UnaryOp."); } - virtual void handle(const BinaryOp* const) { + virtual void handle(const BinaryOp*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BinaryOp."); } - virtual void handle(const TernaryOp* const) { + virtual void handle(const TernaryOp*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for TernaryOp."); } - virtual void handle(const ReductionOp* const) { + virtual void handle(const ReductionOp*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ReductionOp."); } - virtual void handle(const BroadcastOp* const) { + virtual void handle(const BroadcastOp*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for BroadcastOp."); } - virtual void handle(const ForLoop* const) { + virtual void handle(const ForLoop*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for ForLoop."); } - virtual void handle(const Allocate* const) { + virtual void handle(const Allocate*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Allocate."); } - virtual void handle(const IfThenElse* const) { + virtual void handle(const IfThenElse*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for IfThenElse."); } }; From 33c7218e81ec386c4a795509858a5ab4b279bf12 Mon Sep 17 00:00:00 2001 From: Lemo Date: Wed, 17 Jun 2020 21:46:16 -0700 Subject: [PATCH 02/12] Adding new test case --- test/cpp/jit/test_gpu.cpp | 72 +++++++++++++++++++++++++++++++++++++++ test/cpp/jit/tests.h | 1 + 2 files changed, 73 insertions(+) diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index a3410f5d863864..df679375edbd47 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -286,6 +286,78 @@ void testGPU_FusionExprEvalComplex() { checkIntValue(&eval_context, tv6->axis(2)->rawExtent(), 127); } +// Evaluate expressions post lowering +void testGPU_FusionExprEvalPostLower() { + Fusion fusion; + FusionGuard fg(&fusion); + + // Create a non-trivial IR + TensorView* tv0 = makeDummyTensor(2); + TensorView* tv1 = makeDummyTensor(2); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + TensorView* tv2 = add(tv1, new Float(2.0)); + TensorView* tv3 = add(tv0, tv2); + + fusion.addOutput(tv3); + + tv3->split(0, 4); + + tv0->computeAt(tv3, 1); + tv1->computeAt(tv3, 1); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::Unroll); + tv3->axis(1)->parallelize(ParallelType::Unroll); + tv2->axis(-1)->parallelize(ParallelType::TIDx); + tv3->axis(-1)->parallelize(ParallelType::TIDx); + + auto* bid_x = add(tv3->axis(0)->rawExtent(), new Int(0)); + auto* tid_x = add(tv3->axis(-1)->rawExtent(), 
new Int(0)); + + fusion.printValuesMap(); + + // This appears to be causing issue; + GPULower gpulw(&fusion); + std::stringstream cdg; + gpulw.printKernel(cdg); + std::cout << cdg.str() << std::endl; + + fusion.printValuesMap(); + + // 1. Create an evaluation context + EvaluationContext eval_context(&fusion); + // 2. Bind values + // + // IMPORTANT: + // a. The bindings are only as stable as the Vals are in the fusion graph + // b. You must use the original (rootDomain) extents + // (ex. `tv0->getRootDomain()[0]->extent()` + // instead of `tv0->axis(0)->extent()`) + eval_context.bind(tv0->getRootDomain()[0]->extent(), 6); + eval_context.bind(tv0->getRootDomain()[1]->extent(), 128); + eval_context.bind(tv1->getRootDomain()[0]->extent(), 6); + eval_context.bind(tv1->getRootDomain()[1]->extent(), 128); + + // 3. Evaluate and check result values + TORCH_CHECK(tv2->domain()->nDims() == 3); + checkIntValue(&eval_context, tv2->axis(0)->rawExtent(), 2); + checkIntValue(&eval_context, tv2->axis(1)->rawExtent(), 4); + checkIntValue(&eval_context, tv2->axis(2)->rawExtent(), 128); + + TORCH_CHECK(tv3->domain()->nDims() == 3); + checkIntValue(&eval_context, tv3->axis(0)->rawExtent(), 2); + checkIntValue(&eval_context, tv3->axis(1)->rawExtent(), 4); + checkIntValue(&eval_context, tv3->axis(2)->rawExtent(), 128); + + const auto bid_x_val = ExpressionEvaluator::evaluate(bid_x, &eval_context); + std::cout << "bid x value " << bid_x_val.value() << std::endl; + const auto tid_x_val = ExpressionEvaluator::evaluate(bid_x, &eval_context); + std::cout << "tid x value " << tid_x_val.value() << std::endl; +} + void testGPU_FusionSimpleArith() { std::stringstream ss1, ss2; diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index 67aebe42915a2f..98aaa3fd677158 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -110,6 +110,7 @@ namespace jit { _(GPU_FusionExprEvalBindings) \ _(GPU_FusionExprEvalBasic) \ _(GPU_FusionExprEvalComplex) \ + _(GPU_FusionExprEvalPostLower) \ _(GPU_FusionSimpleTypePromote) \ _(GPU_FusionMutator) \ _(GPU_FusionRegister) \ From 90ec35316f54b32ab77d16cdbea169dbc3dbee34 Mon Sep 17 00:00:00 2001 From: Lemo Date: Wed, 17 Jun 2020 21:48:44 -0700 Subject: [PATCH 03/12] Minor cleanup --- torch/csrc/jit/codegen/cuda/lower_loops.cpp | 4 ++-- torch/csrc/jit/codegen/cuda/lower_utils.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index 5124b18b2f2811..2c0c6c5843ae7b 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -61,12 +61,12 @@ Expr* LoopNestGenerator::pushAlloc(TensorView* tv) { // Multiply all the dimensions we're going to use for the allocation together // to get the total size - Val* size; + Val* size = nullptr; if (alloc_dims.size() == 0) { size = new Int(1); } else { size = alloc_dims[0]; - for (decltype(alloc_dims.size()) i{1}; i < alloc_dims.size(); i++) { + for (size_t i = 1; i < alloc_dims.size(); i++) { size = mul(size, alloc_dims[i]); } } diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp index 208916ec0788fe..f6bb7640607e10 100644 --- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp @@ -411,7 +411,7 @@ std::vector iterDomains(std::vector loops) { return ids; } -bool isTV(const Val* const val) { +bool isTV(const Val* val) { return val->getValType().value() == ValType::TensorView; } @@ -462,7 +462,7 
@@ ForLoop* asForLoop(Statement* stmt) { return static_cast(expr); } -const TensorView* asConstTV(const Val* const val) { +const TensorView* asConstTV(const Val* val) { TORCH_INTERNAL_ASSERT(isTV(val)); return static_cast(val); } From e73b9c2a6820a6efc04d2c69842f63793e3a1290 Mon Sep 17 00:00:00 2001 From: Lemo Date: Wed, 17 Jun 2020 21:49:26 -0700 Subject: [PATCH 04/12] Adding Fusion::values_map_ --- torch/csrc/jit/codegen/cuda/fusion.cpp | 14 ++++ torch/csrc/jit/codegen/cuda/fusion.h | 21 +++++ torch/csrc/jit/codegen/cuda/ir_iostream.cpp | 81 ++++++++++++------- torch/csrc/jit/codegen/cuda/ir_iostream.h | 75 +++++++++-------- .../jit/codegen/cuda/lower_validation.cpp | 53 +++++++----- 5 files changed, 157 insertions(+), 87 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp index cde188d8f8df91..6c786adb5af9a6 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion.cpp @@ -218,6 +218,20 @@ void Fusion::print() { std::cout << "}\n"; } +void Fusion::printValuesMap() { + IRPrinter ir_printer(std::cout); + ir_printer.follow_val_map = false; + std::cout << "\nValues map\n"; + std::cout << "--------------------\n"; + for (const auto& kv : values_map_) { + ir_printer.handle(kv.first); + std::cout << " -> "; + ir_printer.handle(kv.second); + std::cout << "\n"; + } + std::cout << "--------------------\n\n"; +} + void Fusion::printMath() { FusionGuard fg(this); for (auto expr : exprs(true)) diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h index cb3471fd9e74a7..f55c9541709b56 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.h +++ b/torch/csrc/jit/codegen/cuda/fusion.h @@ -163,8 +163,12 @@ struct TORCH_CUDA_API Fusion : public IRInputOutput { // Print this fusion to cout. void print(); + // Print value mapping + void printValuesMap(); + // Print Arith exprs used in outputs void printMath(); + // Print transformations used in fusion (can be very verbose) void printTransforms(); @@ -208,6 +212,20 @@ struct TORCH_CUDA_API Fusion : public IRInputOutput { bool hasGridReduction(); size_t gridReductionTempBufferSize(); + void setValuesMap(std::unordered_map values_map) { + values_map_ = std::move(values_map); + } + + Val* loweredVal(Val* value) const { + auto it = values_map_.find(value); + return it != values_map_.end() ? it->second : value; + } + + const Val* loweredVal(const Val* value) const { + auto it = values_map_.find(const_cast(value)); + return it != values_map_.end() ? it->second : value; + } + private: // Sets of all Vals/Exprs registered with this fusion std::unordered_set val_set_; @@ -233,6 +251,9 @@ struct TORCH_CUDA_API Fusion : public IRInputOutput { // Dependency tracking for Vals. Where did it come from? Where is it used? std::unordered_map origin_; std::unordered_map> uses_; + + // Map a subset of values to the lowered equivalent (ex. sizes) + std::unordered_map values_map_; }; } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index 1e4d0b0bd64857..b255911c682d5b 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -10,7 +10,7 @@ namespace fuser { namespace { // Make sure we can inline something, before we attempt to. 
-void check_inlineable(const IRInputOutput* const irio) { +void check_inlineable(const IRInputOutput* irio) { for (auto inp : irio->inputs()) TORCH_CHECK( inp->isScalar(), @@ -24,6 +24,23 @@ void check_inlineable(const IRInputOutput* const irio) { } } // namespace +void IRPrinter::handle(const Statement* s) { + OptInConstDispatch::handle(s); +} + +void IRPrinter::handle(const Val* v) { + if (follow_val_map) { + // Follow a single maping (permutation chains are not expected) + v = FusionGuard::getCurFusion()->loweredVal(v); + TORCH_INTERNAL_ASSERT(v == FusionGuard::getCurFusion()->loweredVal(v)); + } + OptInConstDispatch::handle(v); +} + +void IRPrinter::handle(const Expr* e) { + OptInConstDispatch::handle(e); +} + void IRPrinter::printHeader(Fusion* fusion, const std::string& kernel_name_) { os << "__global__ void " << kernel_name_ << "("; @@ -85,7 +102,7 @@ void IRPrinter::handle(Fusion* fusion) { } } -void IRPrinter::handle(const TensorDomain* const td) { +void IRPrinter::handle(const TensorDomain* td) { os << "[ "; for (std::vector::size_type i = 0; i < td->nDims(); i++) { handle(td->axis(i)); @@ -95,7 +112,7 @@ void IRPrinter::handle(const TensorDomain* const td) { os << " ]"; } -void IRPrinter::handle(const TensorView* const tv) { +void IRPrinter::handle(const TensorView* tv) { os << "T" << tv->name(); handle(tv->domain()); @@ -106,7 +123,7 @@ void IRPrinter::handle(const TensorView* const tv) { } } -void IRPrinter::handle(const IterDomain* const id) { +void IRPrinter::handle(const IterDomain* id) { if (id->isReduction()) os << "r"; else if (id->isBroadcast()) @@ -138,7 +155,7 @@ void IRPrinter::handle(const IterDomain* const id) { os << "rf"; } -void IRPrinter::handle(const TensorIndex* const ti) { +void IRPrinter::handle(const TensorIndex* ti) { os << "T" << ti->view()->name() << "[ "; bool first = true; @@ -151,7 +168,7 @@ void IRPrinter::handle(const TensorIndex* const ti) { os << " ]"; } -void IRPrinter::handle(const Bool* const b) { +void IRPrinter::handle(const Bool* b) { if (print_inline_ && FusionGuard::getCurFusion()->origin(b) != nullptr) { os << "( "; handle(FusionGuard::getCurFusion()->origin(b)); @@ -166,7 +183,7 @@ void IRPrinter::handle(const Bool* const b) { } } -void IRPrinter::handle(const Float* const f) { +void IRPrinter::handle(const Float* f) { if (print_inline_ && FusionGuard::getCurFusion()->origin(f) != nullptr) { os << "( "; handle(FusionGuard::getCurFusion()->origin(f)); @@ -184,7 +201,7 @@ void IRPrinter::handle(const Float* const f) { } } -void IRPrinter::handle(const Half* const h) { +void IRPrinter::handle(const Half* h) { if (print_inline_ && FusionGuard::getCurFusion()->origin(h) != nullptr) { os << "( "; handle(FusionGuard::getCurFusion()->origin(h)); @@ -199,12 +216,19 @@ void IRPrinter::handle(const Half* const h) { } } -void IRPrinter::handle(const Int* const i) { - if (print_inline_ && FusionGuard::getCurFusion()->origin(i) != nullptr) { - os << "( "; - handle(FusionGuard::getCurFusion()->origin(i)); - os << " )"; - return; +void IRPrinter::handle(const Int* i) { + // Make sure we didn't bypass the value mapping + // (for example calling IRPrinter::handle() with a Int*) + TORCH_CHECK( + !follow_val_map || i == FusionGuard::getCurFusion()->loweredVal(i)); + + if (print_inline_) { + if (auto def = FusionGuard::getCurFusion()->origin(i)) { + os << "( "; + handle(def); + os << " )"; + return; + } } if (i->isSymbolic()) { @@ -214,27 +238,27 @@ void IRPrinter::handle(const Int* const i) { } } -void IRPrinter::handle(const NamedScalar* const i) { +void 
IRPrinter::handle(const NamedScalar* i) { os << i->name(); } namespace { -bool isTV(const Val* const val) { +bool isTV(const Val* val) { return ( val->getValType().value() == ValType::TensorView || val->getValType().value() == ValType::TensorIndex); } // Check if we're a TensorView op that we can generate code for. -bool isTVOp(const Expr* const expr) { +bool isTVOp(const Expr* expr) { if (expr->nOutputs() == 1 && isTV(expr->output(0))) return true; return false; } } // namespace -void IRPrinter::handle(const UnaryOp* const uop) { +void IRPrinter::handle(const UnaryOp* uop) { bool istvop = isTVOp(uop); if (!print_inline_) { indent(); @@ -280,7 +304,7 @@ void IRPrinter::handle(const UnaryOp* const uop) { os << ";\n"; } -void IRPrinter::handle(const BinaryOp* const bop) { +void IRPrinter::handle(const BinaryOp* bop) { bool istvop = isTVOp(bop); if (!print_inline_) { indent(); @@ -325,7 +349,7 @@ void IRPrinter::handle(const BinaryOp* const bop) { os << ";\n"; } -void IRPrinter::handle(const TernaryOp* const top) { +void IRPrinter::handle(const TernaryOp* top) { bool istvop = isTVOp(top); if (!print_inline_) { indent(); @@ -366,7 +390,7 @@ void IRPrinter::handle(const TernaryOp* const top) { os << ";\n"; } -void IRPrinter::handle(const ReductionOp* const rop) { +void IRPrinter::handle(const ReductionOp* rop) { // Check if we've lowered yet. bool lowered = rop->out()->getValType() == ValType::TensorIndex; @@ -448,7 +472,7 @@ void IRPrinter::handle(const ReductionOp* const rop) { } } -void IRPrinter::handle(const BroadcastOp* const bop) { +void IRPrinter::handle(const BroadcastOp* bop) { indent(); handle(bop->out()); os << "\n"; @@ -460,7 +484,7 @@ void IRPrinter::handle(const BroadcastOp* const bop) { os << ";\n"; } -void IRPrinter::handle(const ForLoop* const fl) { +void IRPrinter::handle(const ForLoop* fl) { if (fl->iter_domain()->isThread()) { for (auto& expr : fl->constBody().exprs()) handle(expr); @@ -488,7 +512,7 @@ void IRPrinter::handle(const ForLoop* const fl) { os << "}\n"; } -void IRPrinter::handle(const IfThenElse* const ite) { +void IRPrinter::handle(const IfThenElse* ite) { indent(); // IF @@ -516,7 +540,7 @@ void IRPrinter::handle(const IfThenElse* const ite) { os << "}\n"; } -void IRPrinter::handle(const Allocate* const a) { +void IRPrinter::handle(const Allocate* a) { indent(); os << a->buf_type(); if (a->buffer()->getValType() == ValType::TensorView) { @@ -537,7 +561,7 @@ void IRPrinter::handle(const Allocate* const a) { } } -void IRPrinter::handle(const Split* const s) { +void IRPrinter::handle(const Split* s) { os << "Split: "; handle(s->in()); os << " by factor " << s->factor() << " -> "; @@ -547,7 +571,7 @@ void IRPrinter::handle(const Split* const s) { os << "\n"; } -void IRPrinter::handle(const Merge* const m) { +void IRPrinter::handle(const Merge* m) { os << "Merge: "; handle(m->outer()); os << " and "; @@ -601,7 +625,6 @@ void IRPrinter::printKernel( const std::vector& exprs, const std::string& kernel_name) { Fusion* fusion = FusionGuard::getCurFusion(); - printReductionOps(fusion); printHeader(fusion, kernel_name); for (auto* expr : exprs) { @@ -610,7 +633,7 @@ void IRPrinter::printKernel( os << "}\n"; } -std::ostream& operator<<(std::ostream& os, const Statement* const stmt) { +std::ostream& operator<<(std::ostream& os, const Statement* stmt) { IRPrinter p(os); p.handle(stmt); return os; diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index 5bfce2a6b6b727..ebb8e037e4be4b 100644 --- 
a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -56,6 +56,9 @@ struct TORCH_CUDA_API IRPrinter : public OptInConstDispatch { // Track the indentation size for pretty printing int indent_size = 0; + // Handle value mapping + bool follow_val_map = true; + // Indent the generated code void indent() { for (int i = 0; i < indent_size; i++) @@ -70,54 +73,48 @@ struct TORCH_CUDA_API IRPrinter : public OptInConstDispatch { IRPrinter(std::ostream& _os) : os(_os) {} - virtual void handle(Fusion* const f); + virtual void handle(Fusion* f); // handle calls some non const fusion ops, // eventhough fusion should remain unchanged. // Need to look into this. - virtual void handle(const Fusion* const f) { + virtual void handle(const Fusion* f) { handle(const_cast(f)); } + virtual void handle(Fusion& f) { handle(&f); } - virtual void handle(const Statement* const s) { - OptInConstDispatch::handle(s); - }; - - virtual void handle(const Val* const v) { - OptInConstDispatch::handle(v); - }; - virtual void handle(const Expr* const e) { - OptInConstDispatch::handle(e); - }; - - virtual void handle(const TensorDomain* const) override; - virtual void handle(const TensorView* const) override; - virtual void handle(const IterDomain* const) override; - virtual void handle(const TensorIndex* const) override; - - virtual void handle(const Bool* const) override; - virtual void handle(const Float* const) override; - virtual void handle(const Half* const) override; - virtual void handle(const Int* const) override; - virtual void handle(const NamedScalar* const) override; - - virtual void handle(const UnaryOp* const) override; - virtual void handle(const BinaryOp* const) override; - virtual void handle(const TernaryOp* const) override; - virtual void handle(const ReductionOp* const) override; - virtual void handle(const BroadcastOp* const) override; - - virtual void handle(const ForLoop* const) override; - virtual void handle(const IfThenElse* const) override; - virtual void handle(const Allocate* const) override; - - virtual void handle(const Split* const) override; - virtual void handle(const Merge* const) override; - - void print_inline(const Statement* const stmt) { + void handle(const Statement* s) override; + void handle(const Val* v) override; + void handle(const Expr* e) override; + + void handle(const TensorDomain*) override; + void handle(const TensorView*) override; + void handle(const IterDomain*) override; + void handle(const TensorIndex*) override; + + void handle(const Bool*) override; + void handle(const Float*) override; + void handle(const Half*) override; + void handle(const Int*) override; + void handle(const NamedScalar*) override; + + void handle(const UnaryOp*) override; + void handle(const BinaryOp*) override; + void handle(const TernaryOp*) override; + void handle(const ReductionOp*) override; + void handle(const BroadcastOp*) override; + + void handle(const ForLoop*) override; + void handle(const IfThenElse*) override; + void handle(const Allocate*) override; + + void handle(const Split*) override; + void handle(const Merge*) override; + + void print_inline(const Statement* stmt) { bool prev = print_inline_; print_inline_ = true; handle(stmt); @@ -133,7 +130,7 @@ struct TORCH_CUDA_API IRPrinter : public OptInConstDispatch { TORCH_CUDA_API std::ostream& operator<<( std::ostream& os, - const Statement* const stmt); + const Statement* stmt); TORCH_CUDA_API std::ostream& operator<<(std::ostream& os, Fusion* f); TORCH_CUDA_API std::ostream& 
operator<<(std::ostream& os, Fusion& f); diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 070c1dbe2d61c1..c47d3d731619d0 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -63,24 +63,20 @@ void IRFixComputeAt(Fusion* fusion) { // tensors to reference the runtime structure containing sizes. void IRReplaceSizes() { Fusion* fusion = FusionGuard::getCurFusion(); + // Sizes of inputs/outputs -> T.size[...] std::unordered_map size_map; // Grab inputs and outputs - std::vector orig_inp_out; - std::vector all_tvs; - - for (auto* val : fusion->inputs()) - if (ir_utils::isTV(val)) - orig_inp_out.push_back(ir_utils::asTV(val)); - - for (auto* val : fusion->outputs()) - if (ir_utils::isTV(val)) - orig_inp_out.push_back(ir_utils::asTV(val)); - - for (auto* val : fusion->deterministic_vals()) { + std::vector inputs_and_outputs; + for (auto val : fusion->inputs()) { if (ir_utils::isTV(val)) { - all_tvs.push_back(ir_utils::asTV(val)); + inputs_and_outputs.push_back(val->as()); + } + } + for (auto val : fusion->outputs()) { + if (ir_utils::isTV(val)) { + inputs_and_outputs.push_back(val->as()); } } @@ -96,12 +92,12 @@ void IRReplaceSizes() { // TensorView wouldn't change, so users pointers will remain valid. The other // option which seems less elegant but would also work is build up the domain // on the new tensor, and then simply replace it into the original one. - for (TensorView* tv : orig_inp_out) { + for (auto tv : inputs_and_outputs) { // Replace the domain with one based on Ti.size[j] std::vector new_domain_iters; - const std::vector& root_td = tv->getRootDomain(); + auto root_td = tv->getRootDomain(); - for (decltype(root_td.size()) i{0}; i < root_td.size(); i++) { + for (size_t i = 0; i < root_td.size(); i++) { // Output sizes could have reduction axes, which isn't what gets output. if (root_td[i]->isReduction()) continue; @@ -118,9 +114,27 @@ void IRReplaceSizes() { } } - // If we already lowered all inputs/outputs we can just return. - if (size_map.size() == 0) - return; +#if 1 + // Adjust memory types to make sure they are valid + for (auto val : fusion->deterministic_vals()) { + if (ir_utils::isTV(val)) { + auto tv = val->as(); + if (fusion->hasInput(tv) || fusion->hasOutput(tv)) { + tv->setMemoryType(MemoryType::Global); + } else if (tv->getMemoryType() == MemoryType::Global) { + tv->setMemoryType(MemoryType::Local); + } + } + } + + fusion->setValuesMap(size_map); +#else + std::vector all_tvs; + for (auto* val : fusion->deterministic_vals()) { + if (ir_utils::isTV(val)) { + all_tvs.push_back(ir_utils::asTV(val)); + } + } // Set domains to be based on symbolic sizes (i.e. 
Ti.size[...]) for (TensorView* tv : all_tvs) { @@ -172,6 +186,7 @@ void IRReplaceSizes() { tv->setMemoryType(MemoryType::Local); } } +#endif } void PrepareForLowering(Fusion* fusion) { From 739a47ee0bdf74900a1b83abdce3cd5e475b53db Mon Sep 17 00:00:00 2001 From: Lemo Date: Thu, 18 Jun 2020 15:44:38 -0700 Subject: [PATCH 05/12] Update expected output for GPU_FusionParser_CUDA --- test/cpp/jit/test_gpu.cpp | 117 ++++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 56 deletions(-) diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index 790870471dc486..3472726657422b 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -321,9 +321,9 @@ void testGPU_FusionExprEvalPostLower() { // This appears to be causing issue; GPULower gpulw(&fusion); - std::stringstream cdg; - gpulw.printKernel(cdg); - std::cout << cdg.str() << std::endl; + std::stringstream actual_kernel; + gpulw.printKernel(actual_kernel); + std::cout << actual_kernel.str() << std::endl; fusion.printValuesMap(); @@ -801,51 +801,56 @@ void testGPU_FusionParser() { prog.device_ = 0; fuser::cuda::parseJitIR(g, &prog); - std::stringstream ref; - ref << "__global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3){\n" - << " float T2[4];\n" - << " if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n" - << " for(size_t i64 = 0; i64 < 4; ++i64 ) {\n" - << " T2[ i64 ]\n" - << " = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i64 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]\n" - << " * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i64 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];\n" - << " }\n" - << " } else { \n" - << " for(size_t i64 = 0; i64 < 4; ++i64 ) {\n" - << " if ( ( ( ( ( ( blockIdx.x * 4 ) + i64 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n" - << " T2[ i64 ]\n" - << " = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i64 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]\n" - << " * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i64 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];\n" - << " }\n" - << " }\n" - << " }\n" - << " if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n" - << " for(size_t i65 = 0; i65 < 4; ++i65 ) {\n" - << " T3[ ( ( ( ( ( blockIdx.x * 4 ) + i65 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]\n" - << " = T2[ i65 ]\n" - << " * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i65 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];\n" - << " }\n" - << " } else { \n" - << " for(size_t i65 = 0; i65 < 4; ++i65 ) {\n" - << " if ( ( ( ( ( ( blockIdx.x * 4 ) + i65 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n" - << " T3[ ( ( ( ( ( blockIdx.x * 4 ) + i65 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]\n" - << " = T2[ i65 ]\n" - << " * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i65 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];\n" - << " }\n" - << " }\n" - << " }\n" - << "}\n"; + // CONSIDER: + // 1. this can be moved to a dedicated "golden" file + // 2. 
use a fuzzy compare (ignore non-significant whitespaces for example) + const std::string expected_kernel = R"( +__global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3){ + float T2[4]; + if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { + for(size_t i40 = 0; i40 < 4; ++i40 ) { + T2[ i40 ] + = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ] + * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ]; + } + } else { + for(size_t i40 = 0; i40 < 4; ++i40 ) { + if ( ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { + T2[ i40 ] + = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ] + * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ]; + } + } + } + if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { + for(size_t i41 = 0; i41 < 4; ++i41 ) { + T3[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ] + = T2[ i41 ] + * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]; + } + } else { + for(size_t i41 = 0; i41 < 4; ++i41 ) { + if ( ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { + T3[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ] + = T2[ i41 ] + * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]; + } + } + } +} +)"; GPULower gpulw(&fusion); - std::stringstream cdg; - gpulw.printKernel(cdg); - if (ref.str().size() != cdg.str().size() || - ref.str().compare(cdg.str()) != 0) { + std::stringstream actual_kernel; + actual_kernel << "\n"; + gpulw.printKernel(actual_kernel); + if (expected_kernel.size() != actual_kernel.str().size() || + expected_kernel.compare(actual_kernel.str()) != 0) { std::cerr << " Codegen mismatch, codegen possibly changed, or is incorrect. 
" - << " \n ========= REF ========= \n" - << ref.str() << "\n========= RESULT ========== \n" - << cdg.str() << "\n=================" << std::endl; + << " \n ========= EXPECTED ========= \n" + << expected_kernel << "\n========= ACTUAL ========== \n" + << actual_kernel.str() << "\n=================" << std::endl; TORCH_CHECK(false); } } @@ -1272,10 +1277,10 @@ void testGPU_FusionAdvancedComputeAt() { &prog, {t0}, {kernel_tv5, kernel_tv6}); GPULower gpulw(&fusion); - std::stringstream cdg; - gpulw.printKernel(cdg); + std::stringstream actual_kernel; + gpulw.printKernel(actual_kernel); - TORCH_CHECK(at::allclose(kernel_tv5, t5), cdg.str()); + TORCH_CHECK(at::allclose(kernel_tv5, t5), actual_kernel.str()); TORCH_CHECK(at::allclose(kernel_tv6, t6)); } @@ -1339,10 +1344,10 @@ void testGPU_FusionAdvancedComputeAt() { torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {kernel_tv3}); GPULower gpulw(&fusion); - std::stringstream cdg; - gpulw.printKernel(cdg); + std::stringstream actual_kernel; + gpulw.printKernel(actual_kernel); - TORCH_CHECK(at::allclose(kernel_tv3, t3), cdg.str()); + TORCH_CHECK(at::allclose(kernel_tv3, t3), actual_kernel.str()); } // Case 4 @@ -1419,10 +1424,10 @@ void testGPU_FusionAdvancedComputeAt() { &prog, {t0, t1, t2, t3}, {kernel_tv6}); GPULower gpulw(&fusion); - std::stringstream cdg; - gpulw.printKernel(cdg); + std::stringstream actual_kernel; + gpulw.printKernel(actual_kernel); - TORCH_CHECK(at::allclose(kernel_tv6, t6), cdg.str()); + TORCH_CHECK(at::allclose(kernel_tv6, t6), actual_kernel.str()); } } @@ -1519,10 +1524,10 @@ void testGPU_FusionScalarInputs() { {kernel_tv4}); GPULower gpulw(&fusion); - std::stringstream cdg; - gpulw.printKernel(cdg); + std::stringstream actual_kernel; + gpulw.printKernel(actual_kernel); - TORCH_CHECK(at::allclose(kernel_tv4, t4), cdg.str()); + TORCH_CHECK(at::allclose(kernel_tv4, t4), actual_kernel.str()); } void testGPU_FusionLoopUnroll() { From 0eec316cda083a0734c6ad92afe1fd92ab06a891 Mon Sep 17 00:00:00 2001 From: Lemo Date: Fri, 19 Jun 2020 09:57:19 -0700 Subject: [PATCH 06/12] Cleanup --- torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp | 22 +++++++++---------- torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index 4424e62aaabde4..8ff828f759c59f 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp @@ -29,7 +29,7 @@ Expr* Statement::asExpr() { return static_cast(this); } -void Statement::print() { +void Statement::print() const { IRPrinter ir_printer(std::cout); ir_printer.handle(this); std::cout << std::endl; @@ -55,33 +55,33 @@ struct ConstCheck : OptOutConstDispatch { private: bool is_const_ = true; - void handle(const Bool* const b) override { + void handle(const Bool* b) override { is_const_ = is_const_ && b->isConst(); } - void handle(const Float* const f) override { + void handle(const Float* f) override { is_const_ = is_const_ && f->isConst(); } - void handle(const Half* const h) override { + void handle(const Half* h) override { is_const_ = is_const_ && h->isConst(); } - void handle(const Int* const i) override { + void handle(const Int* i) override { is_const_ = is_const_ && i->isConst(); } - void handle(const NamedScalar* const ns) override { + void handle(const NamedScalar* ns) override { is_const_ = is_const_ && false; } - void handle(const Expr* const expr) override { + void handle(const Expr* expr) 
override { for (auto inp : expr->inputs()) { OptOutConstDispatch::handle(inp); } } - void handle(const Val* const val) override { + void handle(const Val* val) override { const Expr* orig = FusionGuard::getCurFusion()->origin(val); if (orig != nullptr) handle(orig); @@ -90,7 +90,7 @@ struct ConstCheck : OptOutConstDispatch { } public: - static bool isConst(const Val* const val) { + static bool isConst(const Val* val) { ConstCheck cc; cc.handle(val); return cc.is_const_; @@ -183,14 +183,14 @@ void Scope::clear() { this->exprs_ = std::vector(); } -bool IRInputOutput::hasInput(const Val* const input) const { +bool IRInputOutput::hasInput(const Val* input) const { for (auto val : inputs_) if (val == input) return true; return false; } -bool IRInputOutput::hasOutput(const Val* const output) const { +bool IRInputOutput::hasOutput(const Val* output) const { for (auto val : outputs_) if (val == output) return true; diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h index f19f30a9e5e7d4..f38261df19949e 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h @@ -142,7 +142,7 @@ struct TORCH_CUDA_API Statement { return this == other; } - void print(); + void print() const; protected: StmtNameType name_ = UNINITIALIZED_STMTNAMETYPE; From 13f2a36bf1f70af0191bc6fcaa6f1ec8c3549c05 Mon Sep 17 00:00:00 2001 From: Lemo Date: Fri, 19 Jun 2020 10:07:14 -0700 Subject: [PATCH 07/12] Switch IRReplaceSizes() off --- torch/csrc/jit/codegen/cuda/index_compute.cpp | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index a361c61f986705..a9f185130f458d 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -64,15 +64,13 @@ IndexCompute::IndexCompute( return; } - bool exclude_reduction = td->nDims() > indices.size(); + const bool exclude_reduction = td->nDims() > indices.size(); TORCH_INTERNAL_ASSERT( td->noReductions().size() == indices.size() || td->nDims() == indices.size(), "For IndexCompute the number of axes should match the number of dimensions in the TensorDomain."); - TORCH_INTERNAL_ASSERT(!td->hasRFactor(), "Not implemented yet."); - { size_t i = 0; for (auto id : td->domain()) { @@ -82,7 +80,7 @@ IndexCompute::IndexCompute( } } - std::vector domain_vals(td->domain().begin(), td->domain().end()); + const std::vector domain_vals(td->domain().begin(), td->domain().end()); // Run the split/merge operations backwards. This will modify the index_map_ // so it can be used to index the root TensorDomain. Each entry in the root @@ -92,15 +90,10 @@ IndexCompute::IndexCompute( // map at the rfactor IterDomains. 
traverseFrom(indices[0]->fusion(), domain_vals, false); - std::vector inds; for (auto id : td->rootDomain()) { if (exclude_reduction && id->isReduction()) continue; - auto it = index_map_.find(id); - TORCH_INTERNAL_ASSERT( - it != index_map_.end(), - "Error during index compute, missed computing a value."); - indices_.push_back(it->second); + indices_.push_back(index_map_.at(id)); } } From ed9215ca7875a19f100c9dcea4bbecb618b5a016 Mon Sep 17 00:00:00 2001 From: Lemo Date: Fri, 19 Jun 2020 10:25:00 -0700 Subject: [PATCH 08/12] Cleanup ir_validation.cpp --- .../jit/codegen/cuda/ir_interface_nodes.h | 4 +- torch/csrc/jit/codegen/cuda/lower_unroll.cpp | 3 +- .../jit/codegen/cuda/lower_validation.cpp | 101 +++--------------- .../csrc/jit/codegen/cuda/lower_validation.h | 11 +- 4 files changed, 24 insertions(+), 95 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index dd5d8b7180a8ad..e535788be9e04e 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -288,8 +288,8 @@ struct TORCH_CUDA_API TensorView : public Val { friend TORCH_CUDA_API TransformReplay; friend TORCH_CUDA_API OptOutMutator; friend TORCH_CUDA_API LoopNestGenerator; - friend void IRFixComputeAt(Fusion*); - friend void IRReplaceSizes(); + friend void IrFixComputeAt(Fusion*); + friend void IrAdjustMemoryTypes(Fusion* fusion); protected: // Make an exact copy of this tensor (similar to clone()), however, also grabs diff --git a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp index 2744ce56ef8c18..91c0d2894621cb 100644 --- a/torch/csrc/jit/codegen/cuda/lower_unroll.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_unroll.cpp @@ -184,8 +184,7 @@ void UnrollPass::handle(ForLoop* fl) { } // for expr } else { // if(!within_unroll) // modify in place, so grab a copy of exprs first. - std::vector exprs( - fl->body().exprs().begin(), fl->body().exprs().end()); + const std::vector exprs = fl->body().exprs(); for (auto expr : exprs) { if (!ir_utils::isTVOp(expr)) diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 737d3ad26b2d13..62085989aff3c8 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -7,11 +7,8 @@ namespace torch { namespace jit { namespace fuser { -namespace { - // Some pre-compilation checks -void IRValidate(Fusion* fusion) { - FusionGuard fg(fusion); +static void IrValidate(Fusion* fusion) { fusion->validateInputs(); for (Val* val : fusion->vals()) { if (ir_utils::isTV(val)) { @@ -27,16 +24,12 @@ void IRValidate(Fusion* fusion) { "."); } } - } // if ir_utils::isTV - } // for(Val* val : fusion->vals()) -} // validate - -} // namespace + } + } +} // Remove circular computeAt references -void IRFixComputeAt(Fusion* fusion) { - FusionGuard fg(fusion); - +void IrFixComputeAt(Fusion* fusion) { std::vector exprs = fusion->exprs(true); std::set visited; for (auto it = exprs.rbegin(); it != exprs.rend(); it++) { @@ -55,15 +48,7 @@ void IRFixComputeAt(Fusion* fusion) { } } -// TensorViews are all based on symbolic sizes. When we first initialize them we -// don't know if they're inputs or outputs which would mean that they have -// runtime shapes. Intermediate tensors (those not going to global memory) do -// not have this information. 
Since we need to have the correct information in -// the kernel being fetched for shapes, we want to replace input and output -// tensors to reference the runtime structure containing sizes. -void IRReplaceSizes() { - Fusion* fusion = FusionGuard::getCurFusion(); - +void IrBuildSizesMap(Fusion* fusion) { // Sizes of inputs/outputs -> T.size[...] std::unordered_map size_map; @@ -115,8 +100,10 @@ void IRReplaceSizes() { } } -#if 1 - // Adjust memory types to make sure they are valid + fusion->setValuesMap(size_map); +} + +void IrAdjustMemoryTypes(Fusion* fusion) { for (auto val : fusion->deterministic_vals()) { if (ir_utils::isTV(val)) { auto tv = val->as(); @@ -127,75 +114,15 @@ void IRReplaceSizes() { } } } - - fusion->setValuesMap(size_map); -#else - std::vector all_tvs; - for (auto* val : fusion->deterministic_vals()) { - if (ir_utils::isTV(val)) { - all_tvs.push_back(ir_utils::asTV(val)); - } - } - - // Set domains to be based on symbolic sizes (i.e. Ti.size[...]) - for (TensorView* tv : all_tvs) { - std::vector new_domain_iters; - const std::vector& root_td = tv->getRootDomain(); - - for (decltype(root_td.size()) i{0}; i < root_td.size(); i++) { - Val* new_size = root_td[i]->extent(); - if (size_map.find(new_size) != size_map.end()) - new_size = size_map[new_size]; - - new_domain_iters.push_back(new IterDomain( - root_td[i]->start(), - new_size, - root_td[i]->parallel_method(), - root_td[i]->isReduction(), - root_td[i]->isRFactorProduct(), - root_td[i]->isBroadcast())); - } - - TensorDomain* old_domain = tv->domain(); - TensorDomain* new_domain = new TensorDomain(new_domain_iters); - - // We should just be able to replace sizes in place, but mutator is setup to - // do that as it set up to replace vals in Exprs, but - // IterDomain/TensorDomain are vals. - - new_domain = TransformReplay::fullSelfReplay(new_domain, old_domain); - - TORCH_INTERNAL_ASSERT( - old_domain->nDims() == new_domain->nDims(), - "Tried to set symbolic sizes through the kernel, but hit a snag, Replayed domain should be the same size as the target domain, but got ", - new_domain->nDims(), - " and ", - old_domain->nDims()); - // Parallelize all iter domains - for (decltype(new_domain->nDims()) i{0}; i < new_domain->nDims(); i++) - new_domain->axis(i)->parallelize(old_domain->axis(i)->parallel_method()); - - tv->setDomain(new_domain); - } - - // Adjust memory types to make sure they are valid - for (TensorView* tv : all_tvs) { - if (fusion->hasInput(tv) || fusion->hasOutput(tv)) { - tv->setMemoryType(MemoryType::Global); - } else { - if (tv->getMemoryType() == MemoryType::Global) - tv->setMemoryType(MemoryType::Local); - } - } -#endif } void PrepareForLowering(Fusion* fusion) { FusionGuard fg(fusion); - IRFixComputeAt(fusion); - IRValidate(fusion); - IRReplaceSizes(); + IrFixComputeAt(fusion); + IrValidate(fusion); + IrBuildSizesMap(fusion); + IrAdjustMemoryTypes(fusion); } } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.h b/torch/csrc/jit/codegen/cuda/lower_validation.h index 350aef1cc5b05c..83a39d3cb94351 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.h +++ b/torch/csrc/jit/codegen/cuda/lower_validation.h @@ -26,15 +26,18 @@ void TORCH_CUDA_API PrepareForLowering(Fusion* fusion); // Compute at can have some circular references. Before we can call any tv // with tv->getComputeAtAxis(i) we need to break those circular dependencies. -void IRFixComputeAt(Fusion* fusion); +void IrFixComputeAt(Fusion* fusion); -// TensorViews are all based on symbolic sizes. 
When we first initialize them -// we don't know if they're inputs or outputs which would mean that they have +// TensorViews are all based on symbolic sizes. When we first initialize them we +// don't know if they're inputs or outputs which would mean that they have // runtime shapes. Intermediate tensors (those not going to global memory) do // not have this information. Since we need to have the correct information in // the kernel being fetched for shapes, we want to replace input and output // tensors to reference the runtime structure containing sizes. -void IRReplaceSizes(); +void IrBuildSizesMap(Fusion* fusion); + +// Adjust memory types to make sure they are valid +void IrAdjustMemoryTypes(Fusion* fusion); } // namespace fuser } // namespace jit From bcb02157a2a2f017ca7b092faef4e380e516ead2 Mon Sep 17 00:00:00 2001 From: Lemo Date: Fri, 19 Jun 2020 10:31:24 -0700 Subject: [PATCH 09/12] Cleanup the new test case --- test/cpp/jit/test_gpu.cpp | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index 3472726657422b..11028bd1deef28 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -317,25 +317,15 @@ void testGPU_FusionExprEvalPostLower() { auto* bid_x = add(tv3->axis(0)->rawExtent(), new Int(0)); auto* tid_x = add(tv3->axis(-1)->rawExtent(), new Int(0)); - fusion.printValuesMap(); - - // This appears to be causing issue; + // Lower GPULower gpulw(&fusion); - std::stringstream actual_kernel; - gpulw.printKernel(actual_kernel); - std::cout << actual_kernel.str() << std::endl; - - fusion.printValuesMap(); + std::stringstream kernel; + gpulw.printKernel(kernel); // 1. Create an evaluation context EvaluationContext eval_context(&fusion); + // 2. Bind values - // - // IMPORTANT: - // a. The bindings are only as stable as the Vals are in the fusion graph - // b. You must use the original (rootDomain) extents - // (ex. 
`tv0->getRootDomain()[0]->extent()` - // instead of `tv0->axis(0)->extent()`) eval_context.bind(tv0->getRootDomain()[0]->extent(), 6); eval_context.bind(tv0->getRootDomain()[1]->extent(), 128); eval_context.bind(tv1->getRootDomain()[0]->extent(), 6); @@ -352,10 +342,8 @@ void testGPU_FusionExprEvalPostLower() { checkIntValue(&eval_context, tv3->axis(1)->rawExtent(), 4); checkIntValue(&eval_context, tv3->axis(2)->rawExtent(), 128); - const auto bid_x_val = ExpressionEvaluator::evaluate(bid_x, &eval_context); - std::cout << "bid x value " << bid_x_val.value() << std::endl; - const auto tid_x_val = ExpressionEvaluator::evaluate(bid_x, &eval_context); - std::cout << "tid x value " << tid_x_val.value() << std::endl; + checkIntValue(&eval_context, bid_x, 2); + checkIntValue(&eval_context, tid_x, 128); } void testGPU_FusionSimpleArith() { From d0d4ededa31e465102c36b244a772d61282d6c9e Mon Sep 17 00:00:00 2001 From: Lemo Date: Fri, 19 Jun 2020 10:40:14 -0700 Subject: [PATCH 10/12] Update comments --- torch/csrc/jit/codegen/cuda/lower_validation.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.h b/torch/csrc/jit/codegen/cuda/lower_validation.h index 83a39d3cb94351..650f23134c9a22 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.h +++ b/torch/csrc/jit/codegen/cuda/lower_validation.h @@ -9,17 +9,18 @@ namespace jit { namespace fuser { /* - * Currently this does 3 things: + * Currently this does the following: * * (1) Run a validation pass on the IR making sure there are no mistakes or * unsupported scheduling. * - * (2) Replace symbolic sizes for global memory with named scalars + * (2) Creates a mapping for symbolic sizes to named scalars * i.e. T0[i0] -> T0[T0.size[0]] * * (3) Change computeAt structure to make sure computeAt structure follows the * expression structure. 
- * + * + * (4) Adjust TensorView memory types to make sure they are valid */ void TORCH_CUDA_API PrepareForLowering(Fusion* fusion); From 4a6402e5a5edc244578430bffd75d3af3bbfd93f Mon Sep 17 00:00:00 2001 From: Lemo Date: Fri, 19 Jun 2020 10:41:29 -0700 Subject: [PATCH 11/12] Fix formatting --- torch/csrc/jit/codegen/cuda/dispatch.cpp | 16 ++++------------ torch/csrc/jit/codegen/cuda/fusion.h | 2 +- torch/csrc/jit/codegen/cuda/ir_iostream.h | 2 +- torch/csrc/jit/codegen/cuda/lower_validation.h | 2 +- 4 files changed, 7 insertions(+), 15 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index 24f81030cfb2d2..2ec01fc34f8c9f 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -321,23 +321,15 @@ template void Val::dispatch(OptInDispatch*, Val*); template void Expr::dispatch(OptInDispatch, Expr*); template void Expr::dispatch(OptInDispatch*, Expr*); -template void Statement::constDispatch( - OptOutConstDispatch, - const Statement*); -template void Statement::constDispatch( - OptOutConstDispatch*, - const Statement*); +template void Statement::constDispatch(OptOutConstDispatch, const Statement*); +template void Statement::constDispatch(OptOutConstDispatch*, const Statement*); template void Val::constDispatch(OptOutConstDispatch, const Val*); template void Val::constDispatch(OptOutConstDispatch*, const Val*); template void Expr::constDispatch(OptOutConstDispatch, const Expr*); template void Expr::constDispatch(OptOutConstDispatch*, const Expr*); -template void Statement::constDispatch( - OptInConstDispatch, - const Statement*); -template void Statement::constDispatch( - OptInConstDispatch*, - const Statement*); +template void Statement::constDispatch(OptInConstDispatch, const Statement*); +template void Statement::constDispatch(OptInConstDispatch*, const Statement*); template void Val::constDispatch(OptInConstDispatch, const Val*); template void Val::constDispatch(OptInConstDispatch*, const Val*); template void Expr::constDispatch(OptInConstDispatch, const Expr*); diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h index 36977e76f979f2..b46e20b622b96a 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.h +++ b/torch/csrc/jit/codegen/cuda/fusion.h @@ -168,7 +168,7 @@ struct TORCH_CUDA_API Fusion : public IRInputOutput { // Print Arith exprs used in outputs void printMath(); - + // Print transformations used in fusion (can be very verbose) void printTransforms(); // Lower the fusion and print a kernel diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index ebb8e037e4be4b..a407c3cab94845 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -81,7 +81,7 @@ struct TORCH_CUDA_API IRPrinter : public OptInConstDispatch { virtual void handle(const Fusion* f) { handle(const_cast(f)); } - + virtual void handle(Fusion& f) { handle(&f); } diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.h b/torch/csrc/jit/codegen/cuda/lower_validation.h index 650f23134c9a22..6990012a51cbc7 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.h +++ b/torch/csrc/jit/codegen/cuda/lower_validation.h @@ -19,7 +19,7 @@ namespace fuser { * * (3) Change computeAt structure to make sure computeAt structure follows the * expression structure. 
- * + * * (4) Adjust TensorView memory types to make sure they are valid */ From cd17cd0ae1b7dcc151b03dfd0354436ae75f7d8c Mon Sep 17 00:00:00 2001 From: Lemo Date: Mon, 22 Jun 2020 10:21:05 -0700 Subject: [PATCH 12/12] Incorporate feedback --- torch/csrc/jit/codegen/cuda/index_compute.cpp | 6 +++++- torch/csrc/jit/codegen/cuda/lower_validation.cpp | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index a9f185130f458d..0a78acdf5fcf15 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -93,7 +93,11 @@ IndexCompute::IndexCompute( for (auto id : td->rootDomain()) { if (exclude_reduction && id->isReduction()) continue; - indices_.push_back(index_map_.at(id)); + auto it = index_map_.find(id); + TORCH_INTERNAL_ASSERT( + it != index_map_.end(), + "Error during index compute, missed computing a value."); + indices_.push_back(it->second); } } diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index 62085989aff3c8..31a2fc27a95e14 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -77,10 +77,10 @@ void IrBuildSizesMap(Fusion* fusion) { // TensorView wouldn't change, so users pointers will remain valid. The other // option which seems less elegant but would also work is build up the domain // on the new tensor, and then simply replace it into the original one. - for (auto tv : inputs_and_outputs) { + for (TensorView* tv : inputs_and_outputs) { // Replace the domain with one based on Ti.size[j] std::vector new_domain_iters; - auto root_td = tv->getRootDomain(); + const std::vector& root_td = tv->getRootDomain(); size_t dim = 0; for (auto id : root_td) {
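
A short usage note on the post-lowering expression evaluation flow introduced above (the new test in PATCH 02/09 together with the Fusion::setValuesMap / Fusion::loweredVal side table from PATCH 04 and the IrBuildSizesMap pass from PATCH 08): lowering no longer rewrites input/output TensorDomains in place, it records a mapping from the original extents to T.size[...] named scalars, so the original Vals remain valid binding keys after GPULower runs. The sketch below only condenses what testGPU_FusionExprEvalPostLower already exercises; it assumes fusion, tv0, tv1 and tv3 are the objects built in that test and is illustrative, not part of the patch series itself.

    // Lower the fusion; PrepareForLowering() builds the sizes map as part of this.
    GPULower gpulw(&fusion);
    std::stringstream kernel;
    gpulw.printKernel(kernel);

    // The original root-domain extents are still valid keys after lowering,
    // so concrete runtime sizes can be bound directly to them.
    EvaluationContext eval_context(&fusion);
    eval_context.bind(tv0->getRootDomain()[0]->extent(), 6);
    eval_context.bind(tv0->getRootDomain()[1]->extent(), 128);
    eval_context.bind(tv1->getRootDomain()[0]->extent(), 6);
    eval_context.bind(tv1->getRootDomain()[1]->extent(), 128);

    // Scheduled extents (e.g. the BIDx/TIDx axes of the output) can now be
    // evaluated without compiling or running the kernel.
    const auto grid_x = ExpressionEvaluator::evaluate(tv3->axis(0)->rawExtent(), &eval_context);
    const auto block_x = ExpressionEvaluator::evaluate(tv3->axis(-1)->rawExtent(), &eval_context);
    // For the 6x128 inputs bound above: grid_x.value() == 2, block_x.value() == 128.

    // Independently of evaluation, the recorded mapping can be queried directly;
    // loweredVal() returns the mapped value, or its argument when no mapping exists.
    Val* lowered_extent = fusion.loweredVal(tv0->getRootDomain()[0]->extent());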