From 0b00077239108e6db9de050135b790996bb0eea6 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Tue, 16 Mar 2021 16:40:33 -0700
Subject: [PATCH 1/7] Add a repro for issue #757

---
 test/cpp/jit/test_gpu.cpp | 37 +++++++++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp
index edff04b20112..448ca865e07f 100644
--- a/test/cpp/jit/test_gpu.cpp
+++ b/test/cpp/jit/test_gpu.cpp
@@ -13936,6 +13936,43 @@ TEST(NVFuserTest, FusionIssue728_CUDA) {
       "Only tv3 should be included");
 }
 
+TEST(NVFuserTest, FusionIssue757_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(2);
+  fusion.addInput(tv0);
+  auto tv1 = sum(tv0, {1});
+  auto tv2 = broadcast(tv1, {false, true});
+  auto tv3 = makeSymbolicTensor(2);
+  fusion.addInput(tv3);
+  auto tv4 = add(tv2, tv3);
+  fusion.addOutput(tv4);
+
+  tv1->computeAt(tv4, -1);
+
+  tv4->axis(-1)->parallelize(ParallelType::TIDx);
+  tv1->axis(-1)->parallelize(ParallelType::TIDx);
+
+  int numel_x = 650;
+  int numel_y = 102;
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({numel_x, numel_y}, options);
+  at::Tensor t3 = at::randn({numel_x, numel_y}, options);
+  std::vector inputs = {t0, t3};
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion(inputs);
+
+  auto t1 = t0.sum({1});
+  auto t2 = t1.unsqueeze(-1).expand({numel_x, numel_y});
+  auto t4 = t2 + t3;
+
+  testValidate(&fusion, outputs, inputs, {t4}, __LINE__, __FILE__);
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)

From f267cb82708721ce90b5ac68e24c4282bc37adb9 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Tue, 16 Mar 2021 16:41:01 -0700
Subject: [PATCH 2/7] Parallelize all IterDomains when inferred by computeAt
 relationships

---
 torch/csrc/jit/codegen/cuda/compute_at_map.cpp | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
index 0a1fbffc9cf5..55f88fc59071 100644
--- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
+++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
@@ -357,18 +357,15 @@ void ComputeAtMap::build(Fusion* fusion, GpuLower* gpu_lower) {
     TORCH_INTERNAL_ASSERT(
         concrete_id != nullptr, "Could not concretize an IterDomain set.");
 
-    // If parallel mode, parallelize the the concrete id
-    // TODO: Would be good to simply keep a parallelization map and make lookups
-    // to it through lowering.
-    if (mapping_mode_ == MappingMode::PARALLEL) {
-      auto parallel_map_it = parallel_type_map_.find(set);
-      if (parallel_map_it != parallel_type_map_.end()) {
-        concrete_id->parallelize(parallel_map_it->second);
-      }
-    }
-
     for (auto id : *set) {
       concrete_id_map_[id] = concrete_id;
+      if (mapping_mode_ == MappingMode::PARALLEL) {
+        auto parallel_map_it = parallel_type_map_.find(set);
+        // Parallelize all IterDomains to simplify lowering and codegen
+        if (parallel_map_it != parallel_type_map_.end()) {
+          id->parallelize(parallel_map_it->second);
+        }
+      }
     }
   }
 

From ee0bc8eb285772055bc76c5c6327d52f879393ed Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Wed, 17 Mar 2021 13:23:14 -0700
Subject: [PATCH 3/7] Revert "Parallelize all IterDomains when inferred by
 computeAt relationships"

This reverts commit ab83e3e6367ab186498b2d0ab81ca09dcb52f434.
---
 torch/csrc/jit/codegen/cuda/compute_at_map.cpp | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
index 55f88fc59071..0a1fbffc9cf5 100644
--- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
+++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
@@ -357,15 +357,18 @@ void ComputeAtMap::build(Fusion* fusion, GpuLower* gpu_lower) {
     TORCH_INTERNAL_ASSERT(
         concrete_id != nullptr, "Could not concretize an IterDomain set.");
 
+    // If parallel mode, parallelize the the concrete id
+    // TODO: Would be good to simply keep a parallelization map and make lookups
+    // to it through lowering.
+    if (mapping_mode_ == MappingMode::PARALLEL) {
+      auto parallel_map_it = parallel_type_map_.find(set);
+      if (parallel_map_it != parallel_type_map_.end()) {
+        concrete_id->parallelize(parallel_map_it->second);
+      }
+    }
+
     for (auto id : *set) {
       concrete_id_map_[id] = concrete_id;
-      if (mapping_mode_ == MappingMode::PARALLEL) {
-        auto parallel_map_it = parallel_type_map_.find(set);
-        // Parallelize all IterDomains to simplify lowering and codegen
-        if (parallel_map_it != parallel_type_map_.end()) {
-          id->parallelize(parallel_map_it->second);
-        }
-      }
     }
   }
 

From 2caed183f1639f2ce2373b318f66938a5e70a463 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Wed, 17 Mar 2021 16:19:24 -0700
Subject: [PATCH 4/7] Add parallelism information to kir::BroadcastOp

There is no easy way to know which parallel types are used for
kir::TensorView after the lowering as the ComputeAt parallel map is not
maintained. Adds that information to kir::BroadcastOp as it is needed
for codegen.
---
 torch/csrc/jit/codegen/cuda/codegen.cpp       |  4 +---
 torch/csrc/jit/codegen/cuda/kernel_ir.cpp     |  8 ++++++--
 torch/csrc/jit/codegen/cuda/kernel_ir.h       | 11 +++++++++-
 torch/csrc/jit/codegen/cuda/lower2device.cpp  | 13 +++++-----
 torch/csrc/jit/codegen/cuda/lower2device.h    |  1 +
 torch/csrc/jit/codegen/cuda/lower_index.cpp   |  2 +-
 .../codegen/cuda/lower_thread_predicate.cpp   | 20 ++++++++++++-------
 .../jit/codegen/cuda/lower_thread_predicate.h |  3 +--
 torch/csrc/jit/codegen/cuda/lower_utils.cpp   | 10 ++++++++--
 9 files changed, 49 insertions(+), 23 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp
index 241fdec2edab..5df2143ffb61 100644
--- a/torch/csrc/jit/codegen/cuda/codegen.cpp
+++ b/torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -528,10 +528,8 @@ class CudaKernelGenerator : private kir::IrVisitor {
 
   void visit(const kir::BroadcastOp* node) final {
     TORCH_INTERNAL_ASSERT(node->out()->isA());
-    const auto tensor_index = node->out()->as();
 
-    const ParallelTypeBitmap domains = ir_utils::getParallelBroadcastDomains(
-        tensor_index->view()->fuserTv(), kernel_->predicateMap());
+    const ParallelTypeBitmap& domains = node->parallelTypes();
 
     const bool thread_x = domains.get(ParallelType::TIDx);
     const bool thread_y = domains.get(ParallelType::TIDy);
diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
index ea2b463db399..9058c6263251 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
+++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
@@ -342,8 +342,12 @@ std::unordered_map ReductionOp::
   return parallel_domains;
 }
 
-BroadcastOp::BroadcastOp(Passkey passkey, Val* out, Val* in)
-    : Expr(passkey), out_(out), in_(in) {
+BroadcastOp::BroadcastOp(
+    Passkey passkey,
+    Val* out,
+    Val* in,
+    const ParallelTypeBitmap& parallel_types)
+    : Expr(passkey), out_(out), in_(in), parallel_types_(parallel_types) {
   TORCH_CHECK(in->isA() || in->isA());
   TORCH_CHECK(out->isA() || out->isA());
   addOutput(out);
diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h
index 100e508383ea..801d8e88ef1f 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_ir.h
+++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h
@@ -1022,7 +1022,11 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val {
 
 class TORCH_CUDA_CU_API BroadcastOp final : public Expr {
  public:
-  BroadcastOp(Passkey passkey, Val* out, Val* in);
+  explicit BroadcastOp(
+      Passkey passkey,
+      Val* out,
+      Val* in,
+      const ParallelTypeBitmap& paralell_types);
 
   void accept(IrVisitor* visitor) const override {
     visitor->visit(this);
@@ -1040,9 +1044,14 @@ class TORCH_CUDA_CU_API BroadcastOp final : public Expr {
     return in_;
   }
 
+  const ParallelTypeBitmap& parallelTypes() const {
+    return parallel_types_;
+  }
+
 private:
   Val* const out_ = nullptr;
   Val* const in_ = nullptr;
+  ParallelTypeBitmap parallel_types_;
 };
 
 //! Allocate is a lower level Node that describes a buffer of memory that
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp
index 1065a2169553..b9e66607fff6 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -3,6 +3,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -133,7 +134,7 @@ void GpuLower::lower() {
   validateParallelize(fusion_);
 
   // Compute thread predicates
-  ThreadPredicateMap preds(fusion_);
+  thread_preds_.build(fusion_);
 
   // Set the kernel inputs & outputs
   for (auto input : fusion_->inputs()) {
@@ -162,7 +163,7 @@ void GpuLower::lower() {
   const auto raw_sync_exprs = insertRawThreadSynchronization(alloced_exprs);
 
   const auto unrolled_loops =
-      UnrollPass::runPass(fusion_, raw_sync_exprs, preds);
+      UnrollPass::runPass(fusion_, raw_sync_exprs, thread_preds_);
 
   // Reuse memory locations if:
   // TensorView is dynamic shared memory
@@ -174,10 +175,10 @@ void GpuLower::lower() {
   const auto war_sync_exprs = insertWarThreadSynchronization(reuse_mem_exprs);
 
   const auto indexed_loops =
-      IndexLowering::getIndexedExprs(war_sync_exprs, preds);
+      IndexLowering::getIndexedExprs(war_sync_exprs, thread_preds_);
 
   // We now have the lowered expressions, finalize the kernel IR
-  kernel_->finalize(indexed_loops, preds);
+  kernel_->finalize(indexed_loops, thread_preds_);
 }
 
 kir::Kernel* GpuLower::kernel() const {
@@ -346,8 +347,8 @@ class GpuLower::KernelIrMapper : private OptInConstDispatch {
   }
 
   void handle(const BroadcastOp* node) final {
+    ParallelTypeBitmap parallel_types = ir_utils::getParallelBroadcastDomains(
+        node->out()->as(), gpu_lower_->thread_preds_);
     const auto lowered_node = ir_builder_.create(
-        lowerValue(node->out()), lowerValue(node->in()));
+        lowerValue(node->out()), lowerValue(node->in()), parallel_types);
     TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
   }
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.h b/torch/csrc/jit/codegen/cuda/lower2device.h
index c1b730bdd6fa..4cf93f4202de 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.h
+++ b/torch/csrc/jit/codegen/cuda/lower2device.h
@@ -79,6 +79,7 @@ class TORCH_CUDA_CU_API GpuLower {
   ComputeAtMap ca_index_map_;
   ComputeAtMap ca_parallel_map_;
   TrivialReductionInfo trivial_reduction_info_;
+  ThreadPredicateMap thread_preds_;
 
   Fusion* fusion_ = nullptr;
 };
diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp
index 1fc3318c3ad5..70113ca8dcc6 100644
--- a/torch/csrc/jit/codegen/cuda/lower_index.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp
@@ -418,7 +418,7 @@ void IndexLowering::visit(const kir::BroadcastOp* bop) {
   TORCH_INTERNAL_ASSERT(ir_utils::isTVOp(bop));
   const auto out = lowerDstIndex(bop->out());
   const auto in = lowerSrcIndex(bop->in(), bop->out());
-  pushBack(ir_builder_.create(out, in));
+  pushBack(ir_builder_.create(out, in, bop->parallelTypes()));
 }
 
 void IndexLowering::visit(const kir::Allocate* allocate) {
diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
index 3e53d641bab3..1823dad57b30 100644
--- a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
@@ -163,12 +163,18 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {
     ParallelTypeBitmap id_ptypes;
 
     for (auto id : tv_inp->domain()->domain()) {
-      if (id->isThread()) {
-        id_ptypes.set(id->getParallelType(), true);
+      const auto gpu_lower = GpuLower::current();
+      TORCH_INTERNAL_ASSERT(
+          gpu_lower != nullptr, "GpuLower is required but not found");
+
+      auto par_concrete_id = gpu_lower->caParallelMap().getConcreteMappedID(id);
+      auto par_type = par_concrete_id->getParallelType();
+      if (isParallelTypeThread(par_type)) {
+        id_ptypes.set(par_type, true);
         if (id->isReduction())
-          id_reductions.set(id->getParallelType(), true);
+          id_reductions.set(par_type, true);
         if (id->isBroadcast())
-          id_bcasts.set(id->getParallelType(), true);
+          id_bcasts.set(par_type, true);
       }
     }
 
@@ -220,16 +226,16 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {
   }
 }
 
-ThreadPredicateMap::ThreadPredicateMap(Fusion* fusion) : fusion_(fusion) {
+void ThreadPredicateMap::build(Fusion* fusion) {
   FUSER_PERF_SCOPE("ThreadPredicateMap");
 
   // Initialize mapping for input tensors
-  for (auto inp : fusion_->inputs()) {
+  for (auto inp : fusion->inputs()) {
     if (auto tv = dynamic_cast(inp)) {
       insert(tv, ParallelTypeBitmap(), SourceMap());
     }
   }
 
-  for (auto expr : fusion_->exprs()) {
+  for (auto expr : fusion->exprs()) {
     updateBitSet(expr);
   }
 }
diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h
index c56f87afc16d..18135e059c2c 100644
--- a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h
+++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h
@@ -43,7 +43,7 @@ class TORCH_CUDA_CU_API ThreadPredicateMap {
 
   using const_iterator = MapType::const_iterator;
 
-  explicit ThreadPredicateMap(Fusion* fusion);
+  void build(Fusion* fusion);
 
   // TODO(kir): these methods are only used by getParallelBroadcastDomains() ?
   const_iterator find(const TensorView* tv) const;
@@ -68,7 +68,6 @@ class TORCH_CUDA_CU_API ThreadPredicateMap {
   void insert(const TensorView* tv, const PredAndSource& pred_and_src);
 
  private:
-  Fusion* fusion_ = nullptr;
   MapType thread_predicates_;
 };
 
diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
index 65f3d44c9df0..7ad339b7b04c 100644
--- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
@@ -146,6 +146,10 @@ ParallelTypeBitmap getParallelBroadcastDomains(
     return ParallelTypeBitmap();
   }
 
+  const auto gpu_lower = GpuLower::current();
+  TORCH_INTERNAL_ASSERT(
+      gpu_lower != nullptr, "GpuLower is required but not found");
+
   ParallelTypeBitmap parallel_broadcast;
 
   const auto& iter_domains = tv->domain()->domain();
@@ -160,8 +164,10 @@ ParallelTypeBitmap getParallelBroadcastDomains(
     if (!id->isBroadcast()) {
       continue;
     }
-    if (id->isBlockDim() || (!output_smem && id->isThreadDim())) {
-      parallel_broadcast.set(id->getParallelType(), true);
+    auto concrete_id = gpu_lower->caParallelMap().getConcreteMappedID(id);
+    if (concrete_id->isBlockDim() ||
+        (!output_smem && concrete_id->isThreadDim())) {
+      parallel_broadcast.set(concrete_id->getParallelType(), true);
     }
   }
 

From fc06ec3c0a1f1fdfef59886989b3cf9eab2f0307 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Fri, 19 Mar 2021 09:48:35 -0700
Subject: [PATCH 5/7] Revert "Add parallelism information to kir::BroadcastOp"

This reverts commit 2caed183f1639f2ce2373b318f66938a5e70a463.
---
 torch/csrc/jit/codegen/cuda/codegen.cpp       |  4 +++-
 torch/csrc/jit/codegen/cuda/kernel_ir.cpp     |  8 ++------
 torch/csrc/jit/codegen/cuda/kernel_ir.h       | 11 +---------
 torch/csrc/jit/codegen/cuda/lower2device.cpp  | 13 +++++-------
 torch/csrc/jit/codegen/cuda/lower2device.h    |  1 -
 torch/csrc/jit/codegen/cuda/lower_index.cpp   |  2 +-
 .../codegen/cuda/lower_thread_predicate.cpp   | 20 +++++++------------
 .../jit/codegen/cuda/lower_thread_predicate.h |  3 ++-
 torch/csrc/jit/codegen/cuda/lower_utils.cpp   | 10 ++--------
 9 files changed, 23 insertions(+), 49 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/codegen.cpp b/torch/csrc/jit/codegen/cuda/codegen.cpp
index 5df2143ffb61..241fdec2edab 100644
--- a/torch/csrc/jit/codegen/cuda/codegen.cpp
+++ b/torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -528,8 +528,10 @@ class CudaKernelGenerator : private kir::IrVisitor {
 
   void visit(const kir::BroadcastOp* node) final {
     TORCH_INTERNAL_ASSERT(node->out()->isA());
+    const auto tensor_index = node->out()->as();
 
-    const ParallelTypeBitmap& domains = node->parallelTypes();
+    const ParallelTypeBitmap domains = ir_utils::getParallelBroadcastDomains(
+        tensor_index->view()->fuserTv(), kernel_->predicateMap());
 
     const bool thread_x = domains.get(ParallelType::TIDx);
     const bool thread_y = domains.get(ParallelType::TIDy);
diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
index 9058c6263251..ea2b463db399 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
+++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
@@ -342,12 +342,8 @@ std::unordered_map ReductionOp::
   return parallel_domains;
 }
 
-BroadcastOp::BroadcastOp(
-    Passkey passkey,
-    Val* out,
-    Val* in,
-    const ParallelTypeBitmap& parallel_types)
-    : Expr(passkey), out_(out), in_(in), parallel_types_(parallel_types) {
+BroadcastOp::BroadcastOp(Passkey passkey, Val* out, Val* in)
+    : Expr(passkey), out_(out), in_(in) {
   TORCH_CHECK(in->isA() || in->isA());
   TORCH_CHECK(out->isA() || out->isA());
   addOutput(out);
diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h
index 801d8e88ef1f..100e508383ea 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_ir.h
+++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h
@@ -1022,11 +1022,7 @@ class TORCH_CUDA_CU_API TensorIndex final : public Val {
 
 class TORCH_CUDA_CU_API BroadcastOp final : public Expr {
  public:
-  explicit BroadcastOp(
-      Passkey passkey,
-      Val* out,
-      Val* in,
-      const ParallelTypeBitmap& paralell_types);
+  BroadcastOp(Passkey passkey, Val* out, Val* in);
 
   void accept(IrVisitor* visitor) const override {
     visitor->visit(this);
@@ -1044,14 +1040,9 @@ class TORCH_CUDA_CU_API BroadcastOp final : public Expr {
     return in_;
   }
 
-  const ParallelTypeBitmap& parallelTypes() const {
-    return parallel_types_;
-  }
-
 private:
   Val* const out_ = nullptr;
   Val* const in_ = nullptr;
-  ParallelTypeBitmap parallel_types_;
 };
 
 //! Allocate is a lower level Node that describes a buffer of memory that
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp
index b9e66607fff6..1065a2169553 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -3,7 +3,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -134,7 +133,7 @@ void GpuLower::lower() {
   validateParallelize(fusion_);
 
   // Compute thread predicates
-  thread_preds_.build(fusion_);
+  ThreadPredicateMap preds(fusion_);
 
   // Set the kernel inputs & outputs
   for (auto input : fusion_->inputs()) {
@@ -163,7 +162,7 @@ void GpuLower::lower() {
   const auto raw_sync_exprs = insertRawThreadSynchronization(alloced_exprs);
 
   const auto unrolled_loops =
-      UnrollPass::runPass(fusion_, raw_sync_exprs, thread_preds_);
+      UnrollPass::runPass(fusion_, raw_sync_exprs, preds);
 
   // Reuse memory locations if:
   // TensorView is dynamic shared memory
@@ -175,10 +174,10 @@ void GpuLower::lower() {
   const auto war_sync_exprs = insertWarThreadSynchronization(reuse_mem_exprs);
 
   const auto indexed_loops =
-      IndexLowering::getIndexedExprs(war_sync_exprs, thread_preds_);
+      IndexLowering::getIndexedExprs(war_sync_exprs, preds);
 
   // We now have the lowered expressions, finalize the kernel IR
-  kernel_->finalize(indexed_loops, thread_preds_);
+  kernel_->finalize(indexed_loops, preds);
 }
 
 kir::Kernel* GpuLower::kernel() const {
@@ -347,10 +346,8 @@ class GpuLower::KernelIrMapper : private OptInConstDispatch {
   }
 
   void handle(const BroadcastOp* node) final {
-    ParallelTypeBitmap parallel_types = ir_utils::getParallelBroadcastDomains(
-        node->out()->as(), gpu_lower_->thread_preds_);
     const auto lowered_node = ir_builder_.create(
-        lowerValue(node->out()), lowerValue(node->in()), parallel_types);
+        lowerValue(node->out()), lowerValue(node->in()));
     TORCH_CHECK(gpu_lower_->kir_expr_map_.insert({node, lowered_node}).second);
   }
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.h b/torch/csrc/jit/codegen/cuda/lower2device.h
index 4cf93f4202de..c1b730bdd6fa 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.h
+++ b/torch/csrc/jit/codegen/cuda/lower2device.h
@@ -79,7 +79,6 @@ class TORCH_CUDA_CU_API GpuLower {
   ComputeAtMap ca_index_map_;
   ComputeAtMap ca_parallel_map_;
   TrivialReductionInfo trivial_reduction_info_;
-  ThreadPredicateMap thread_preds_;
 
   Fusion* fusion_ = nullptr;
 };
diff --git a/torch/csrc/jit/codegen/cuda/lower_index.cpp b/torch/csrc/jit/codegen/cuda/lower_index.cpp
index 70113ca8dcc6..1fc3318c3ad5 100644
--- a/torch/csrc/jit/codegen/cuda/lower_index.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_index.cpp
@@ -418,7 +418,7 @@ void IndexLowering::visit(const kir::BroadcastOp* bop) {
   TORCH_INTERNAL_ASSERT(ir_utils::isTVOp(bop));
   const auto out = lowerDstIndex(bop->out());
   const auto in = lowerSrcIndex(bop->in(), bop->out());
-  pushBack(ir_builder_.create(out, in, bop->parallelTypes()));
+  pushBack(ir_builder_.create(out, in));
 }
 
 void IndexLowering::visit(const kir::Allocate* allocate) {
diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
index 1823dad57b30..3e53d641bab3 100644
--- a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
@@ -163,18 +163,12 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {
     ParallelTypeBitmap id_ptypes;
 
     for (auto id : tv_inp->domain()->domain()) {
-      const auto gpu_lower = GpuLower::current();
-      TORCH_INTERNAL_ASSERT(
-          gpu_lower != nullptr, "GpuLower is required but not found");
-
-      auto par_concrete_id = gpu_lower->caParallelMap().getConcreteMappedID(id);
-      auto par_type = par_concrete_id->getParallelType();
-      if (isParallelTypeThread(par_type)) {
-        id_ptypes.set(par_type, true);
+      if (id->isThread()) {
+        id_ptypes.set(id->getParallelType(), true);
         if (id->isReduction())
-          id_reductions.set(par_type, true);
+          id_reductions.set(id->getParallelType(), true);
         if (id->isBroadcast())
-          id_bcasts.set(par_type, true);
+          id_bcasts.set(id->getParallelType(), true);
       }
     }
 
@@ -226,16 +220,16 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {
   }
 }
 
-void ThreadPredicateMap::build(Fusion* fusion) {
+ThreadPredicateMap::ThreadPredicateMap(Fusion* fusion) : fusion_(fusion) {
   FUSER_PERF_SCOPE("ThreadPredicateMap");
 
   // Initialize mapping for input tensors
-  for (auto inp : fusion->inputs()) {
+  for (auto inp : fusion_->inputs()) {
     if (auto tv = dynamic_cast(inp)) {
       insert(tv, ParallelTypeBitmap(), SourceMap());
     }
   }
 
-  for (auto expr : fusion->exprs()) {
+  for (auto expr : fusion_->exprs()) {
     updateBitSet(expr);
   }
 }
diff --git a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h
index 18135e059c2c..c56f87afc16d 100644
--- a/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h
+++ b/torch/csrc/jit/codegen/cuda/lower_thread_predicate.h
@@ -43,7 +43,7 @@ class TORCH_CUDA_CU_API ThreadPredicateMap {
 
   using const_iterator = MapType::const_iterator;
 
-  void build(Fusion* fusion);
+  explicit ThreadPredicateMap(Fusion* fusion);
 
   // TODO(kir): these methods are only used by getParallelBroadcastDomains() ?
   const_iterator find(const TensorView* tv) const;
@@ -68,6 +68,7 @@ class TORCH_CUDA_CU_API ThreadPredicateMap {
   void insert(const TensorView* tv, const PredAndSource& pred_and_src);
 
  private:
+  Fusion* fusion_ = nullptr;
   MapType thread_predicates_;
 };
 
diff --git a/torch/csrc/jit/codegen/cuda/lower_utils.cpp b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
index 7ad339b7b04c..65f3d44c9df0 100644
--- a/torch/csrc/jit/codegen/cuda/lower_utils.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_utils.cpp
@@ -146,10 +146,6 @@ ParallelTypeBitmap getParallelBroadcastDomains(
     return ParallelTypeBitmap();
   }
 
-  const auto gpu_lower = GpuLower::current();
-  TORCH_INTERNAL_ASSERT(
-      gpu_lower != nullptr, "GpuLower is required but not found");
-
   ParallelTypeBitmap parallel_broadcast;
 
   const auto& iter_domains = tv->domain()->domain();
@@ -164,10 +160,8 @@ ParallelTypeBitmap getParallelBroadcastDomains(
     if (!id->isBroadcast()) {
       continue;
     }
-    auto concrete_id = gpu_lower->caParallelMap().getConcreteMappedID(id);
-    if (concrete_id->isBlockDim() ||
-        (!output_smem && concrete_id->isThreadDim())) {
-      parallel_broadcast.set(concrete_id->getParallelType(), true);
+    if (id->isBlockDim() || (!output_smem && id->isThreadDim())) {
+      parallel_broadcast.set(id->getParallelType(), true);
     }
   }
 

From 3c7c1470f01abfd4f41ae09ec419a060a73b9a81 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Tue, 16 Mar 2021 16:41:01 -0700
Subject: [PATCH 6/7] Parallelize all IterDomains when inferred by computeAt
 relationships

---
 torch/csrc/jit/codegen/cuda/compute_at_map.cpp | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
index 0a1fbffc9cf5..55f88fc59071 100644
--- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
+++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp
@@ -357,18 +357,15 @@ void ComputeAtMap::build(Fusion* fusion, GpuLower* gpu_lower) {
     TORCH_INTERNAL_ASSERT(
         concrete_id != nullptr, "Could not concretize an IterDomain set.");
 
-    // If parallel mode, parallelize the the concrete id
-    // TODO: Would be good to simply keep a parallelization map and make lookups
-    // to it through lowering.
-    if (mapping_mode_ == MappingMode::PARALLEL) {
-      auto parallel_map_it = parallel_type_map_.find(set);
-      if (parallel_map_it != parallel_type_map_.end()) {
-        concrete_id->parallelize(parallel_map_it->second);
-      }
-    }
-
     for (auto id : *set) {
       concrete_id_map_[id] = concrete_id;
+      if (mapping_mode_ == MappingMode::PARALLEL) {
+        auto parallel_map_it = parallel_type_map_.find(set);
+        // Parallelize all IterDomains to simplify lowering and codegen
+        if (parallel_map_it != parallel_type_map_.end()) {
+          id->parallelize(parallel_map_it->second);
+        }
+      }
     }
   }
 

From c0fa0c9ac4242c75a2fb061991c4076e3e258019 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Fri, 19 Mar 2021 10:00:10 -0700
Subject: [PATCH 7/7] Do not substitute kir::IterDomain::extent_ with parallel
 dimensions

---
 torch/csrc/jit/codegen/cuda/kernel_ir.cpp | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
index ea2b463db399..b893f81d3040 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
+++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
@@ -97,14 +97,12 @@ IterDomain::IterDomain(
   setName(iter_domain->name());
 }
 
+//! Note that the parallel dimension, if available, may be different
+//! from the actual extent of this IterDomain as the parallel
+//! dimension is determined by the largest extent of IterDomains
+//! sharing the same loop.
 Val* IterDomain::extent() const {
   TORCH_INTERNAL_ASSERT(extent_ != nullptr);
-  if (isThread()) {
-    if (extent_->isScalar() && extent_->isConst()) {
-      return extent_;
-    }
-    return NamedScalar::getParallelDim(parallelType());
-  }
   return extent_;
 }
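
A minimal illustrative sketch of what PATCH 7 implies for callers, not code from the patches themselves: kir::IterDomain::extent() now always returns the raw extent_, even for thread-parallelized domains, so lowering code that needs the launch-time parallel dimension is assumed to query it explicitly. The helper name launchExtentOf below is hypothetical; it only relies on the kir:: calls that appear in the diffs (isThread(), parallelType(), extent(), NamedScalar::getParallelDim()).

// Hypothetical helper, for illustration only; assumes the nvFuser kernel IR
// headers are available.
kir::Val* launchExtentOf(const kir::IterDomain* id) {
  if (id->isThread()) {
    // Launch-time size of the bound parallel dimension (e.g. blockDim.x).
    // This may exceed id->extent() when IterDomains with different extents
    // are mapped to the same parallelized loop, which is why extent() no
    // longer substitutes it implicitly.
    return kir::NamedScalar::getParallelDim(id->parallelType());
  }
  // Non-parallel domains keep returning their recorded extent.
  return id->extent();
}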