From 83bad72c62d36f8e55d4e56c3369890b1ab03a9e Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 23 May 2020 14:36:13 -0400 Subject: [PATCH 1/8] Rip out reorder IR nodes. --- torch/csrc/jit/codegen/cuda/dispatch.cpp | 8 - torch/csrc/jit/codegen/cuda/dispatch.h | 13 - torch/csrc/jit/codegen/cuda/fusion.cpp | 2 +- torch/csrc/jit/codegen/cuda/index_compute.cpp | 35 +- torch/csrc/jit/codegen/cuda/index_compute.h | 1 - .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 32 -- torch/csrc/jit/codegen/cuda/ir_iostream.cpp | 8 - torch/csrc/jit/codegen/cuda/ir_iostream.h | 2 - torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 27 -- torch/csrc/jit/codegen/cuda/ir_printer.h | 1 - torch/csrc/jit/codegen/cuda/mutator.cpp | 11 - .../csrc/jit/codegen/cuda/transform_iter.cpp | 458 ------------------ torch/csrc/jit/codegen/cuda/transform_iter.h | 4 +- .../jit/codegen/cuda/transform_rfactor.cpp | 13 +- torch/csrc/jit/codegen/cuda/type.cpp | 3 +- torch/csrc/jit/codegen/cuda/type.h | 3 +- 16 files changed, 9 insertions(+), 612 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/dispatch.cpp b/torch/csrc/jit/codegen/cuda/dispatch.cpp index 408be45c2bf4a..0c2b1560a084c 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.cpp +++ b/torch/csrc/jit/codegen/cuda/dispatch.cpp @@ -90,9 +90,6 @@ void Expr::dispatch(T handler, Expr* expr) { case ExprType::Merge: ptr(handler)->handle(static_cast(expr)); return; - case ExprType::Reorder: - ptr(handler)->handle(static_cast(expr)); - return; case ExprType::UnaryOp: ptr(handler)->handle(static_cast(expr)); return; @@ -183,9 +180,6 @@ void Expr::constDispatch(T handler, const Expr* expr) { case ExprType::Merge: ptr(handler)->handle(static_cast(expr)); return; - case ExprType::Reorder: - ptr(handler)->handle(static_cast(expr)); - return; case ExprType::UnaryOp: ptr(handler)->handle(static_cast(expr)); return; @@ -276,8 +270,6 @@ Statement* Expr::mutatorDispatch(T mutator, Expr* expr) { return ptr(mutator)->mutate(static_cast(expr)); case ExprType::Merge: return ptr(mutator)->mutate(static_cast(expr)); - case ExprType::Reorder: - return ptr(mutator)->mutate(static_cast(expr)); case ExprType::UnaryOp: return ptr(mutator)->mutate(static_cast(expr)); case ExprType::BinaryOp: diff --git a/torch/csrc/jit/codegen/cuda/dispatch.h b/torch/csrc/jit/codegen/cuda/dispatch.h index b73220e111cdf..e13f38cdc1ddb 100644 --- a/torch/csrc/jit/codegen/cuda/dispatch.h +++ b/torch/csrc/jit/codegen/cuda/dispatch.h @@ -69,7 +69,6 @@ struct NamedScalar; // Exprs struct Split; struct Merge; -struct Reorder; struct UnaryOp; struct BinaryOp; struct TernaryOp; @@ -112,7 +111,6 @@ struct TORCH_CUDA_API OptOutConstDispatch { // Exprs virtual void handle(const Split* const) {} virtual void handle(const Merge* const) {} - virtual void handle(const Reorder* const) {} virtual void handle(const UnaryOp* const) {} virtual void handle(const BinaryOp* const) {} virtual void handle(const TernaryOp* const) {} @@ -152,7 +150,6 @@ struct TORCH_CUDA_API OptOutDispatch { // Exprs virtual void handle(Split*) {} virtual void handle(Merge*) {} - virtual void handle(Reorder*) {} virtual void handle(UnaryOp*) {} virtual void handle(BinaryOp*) {} virtual void handle(TernaryOp*) {} @@ -214,9 +211,6 @@ struct TORCH_CUDA_API OptInConstDispatch { virtual void handle(const Merge* const) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Merge."); } - virtual void handle(const Reorder* const) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Reorder."); - } virtual void handle(const UnaryOp* const) { 
TORCH_INTERNAL_ASSERT(false, "Handle not overriden for UnaryOp."); } @@ -294,9 +288,6 @@ struct TORCH_CUDA_API OptInDispatch { virtual void handle(Merge*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Merge."); } - virtual void handle(Reorder*) { - TORCH_INTERNAL_ASSERT(false, "Handle not overriden for Reorder."); - } virtual void handle(UnaryOp*) { TORCH_INTERNAL_ASSERT(false, "Handle not overriden for UnaryOp."); } @@ -376,7 +367,6 @@ struct TORCH_CUDA_API OptOutMutator { // Exprs virtual Statement* mutate(Split*); virtual Statement* mutate(Merge*); - virtual Statement* mutate(Reorder*); virtual Statement* mutate(UnaryOp*); virtual Statement* mutate(BinaryOp*); virtual Statement* mutate(TernaryOp*); @@ -445,9 +435,6 @@ struct TORCH_CUDA_API OptInMutator { virtual Statement* mutate(Merge*) { TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Merge."); } - virtual Statement* mutate(Reorder*) { - TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for Reorder."); - } virtual Statement* mutate(UnaryOp*) { TORCH_INTERNAL_ASSERT(false, "Mutate not overriden for UnaryOp."); } diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp index 983c93f5c583e..750c8e32e07d9 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion.cpp @@ -142,7 +142,7 @@ void Fusion::addOutput(Val* const output) { if (tv->getRootDomain() ->hasBroadcast()) // Go to the root as we can merge bcast and // non-bcast dims, making a non-bcast dim. - TORCH_CHECK( + TORCH_CHECK( // Should we warn instead? false, output, " cannot be registered as an output as it has a broadcast axis."); diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index ade7abd171b63..631b9432b2817 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -30,37 +30,6 @@ TensorDomain* IndexCompute::replayBackward(Merge* merge, TensorDomain*) { return merge->in(); } -TensorDomain* IndexCompute::replayBackward(Reorder* reorder, TensorDomain*) { - // new2old[new_pos] = old_pos Generate new old2new map - const std::vector& new2old = reorder->new2old(); - - std::vector reordered_indices; - - // Reverse the map so we can simply push back into reordered_indices - // old2new[old_pos] = new_pos - std::vector old2new(new2old.size(), -1); - - for (decltype(new2old.size()) i = 0; i < new2old.size(); i++) { - int new_pos = i; - int old_pos = new2old[i]; - TORCH_INTERNAL_ASSERT( - new_pos >= 0 && (unsigned int)new_pos < indices.size() && - old_pos >= 0 && (unsigned int)old_pos < indices.size(), - "Hit an invalid reorder transformation during IndexCompute," - " at least one move position is not within bounds."); - old2new[old_pos] = new_pos; - } - for (auto new_pos : old2new) { - // int new_pos = old2new[i]; - // int old_pos = i; - // reordered_indices[old_pos] = indices[new_pos]; - reordered_indices.push_back(indices[new_pos]); - } - - indices = reordered_indices; - return reorder->in(); -} - TensorDomain* IndexCompute::runBackward(std::vector history) { TensorDomain* running_td = nullptr; for (auto it = history.rbegin(); it != history.rend(); it++) @@ -81,7 +50,7 @@ IndexCompute::IndexCompute(TensorDomain* td, std::vector _indices) // If we need to ignore the reduction dimensions because a tensor is // being consumed, not produced, then insert dummy dimensions in the - // indices for bookkeeping while replaying split/merge/reorder operations. 
+ // indices for bookkeeping while replaying split/merge operations. if (exclude_reduction) for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) if (td->axis(i)->isReduction()) @@ -91,7 +60,7 @@ IndexCompute::IndexCompute(TensorDomain* td, std::vector _indices) indices.size() == td->nDims(), "Attempted to modify indices for IndexCompute, but didn't work."); - // Run the split/merge/reorder operations backwards. This will + // Run the split/merge operations backwards. This will // Modify std::vector indices so it can be used to index // the root TensorDomain which should now match the physical axes. TensorDomain* root = TransformIter::getRoot(td); diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h index aa47d63523998..ecbeb2eb25ee8 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.h +++ b/torch/csrc/jit/codegen/cuda/index_compute.h @@ -59,7 +59,6 @@ struct IndexCompute : public TransformIter { private: TensorDomain* replayBackward(Split*, TensorDomain*) override; TensorDomain* replayBackward(Merge*, TensorDomain*) override; - TensorDomain* replayBackward(Reorder*, TensorDomain*) override; TensorDomain* runBackward(std::vector history); diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 7e584f28dc53f..95774cba4ae47 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -474,38 +474,6 @@ struct TORCH_CUDA_API Merge : public Expr { int axis_; }; -/* - * Reorder the IterDomains of a tensor domain with the map - * new2old[new_position] = old_position - */ -struct TORCH_CUDA_API Reorder : public Expr { - ~Reorder() = default; - Reorder(TensorDomain* _out, TensorDomain* _in, std::vector _new2old); - - Reorder(const Reorder& other) = delete; - Reorder& operator=(const Reorder& other) = delete; - - Reorder(Reorder&& other) = delete; - Reorder& operator=(Reorder&& other) = delete; - - TensorDomain* out() const noexcept { - return out_; - } - TensorDomain* in() const noexcept { - return in_; - } - const std::vector& new2old() const noexcept { - return new2old_; - } - - bool sameAs(const Reorder* const other) const; - - private: - TensorDomain* const out_; - TensorDomain* const in_; - const std::vector new2old_; -}; - /* * ForLoop provides scoping around an int iterator from 0 to range. Exprs placed * in its body are considered inside the scope of the for loop. 
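For reference, the Reorder node removed above is defined entirely by its new2old map, where new2old[new_position] = old_position. A minimal, self-contained sketch of that convention, using plain std::string labels as stand-ins for IterDomains (applyNew2Old is a hypothetical helper, not part of the codegen API):

#include <cassert>
#include <string>
#include <vector>

// Apply a new2old permutation: the axis at old position new2old[new_pos]
// lands at new position new_pos.
std::vector<std::string> applyNew2Old(
    const std::vector<std::string>& axes,
    const std::vector<int>& new2old) {
  std::vector<std::string> reordered(axes.size());
  for (size_t new_pos = 0; new_pos < new2old.size(); new_pos++)
    reordered[new_pos] = axes[new2old[new_pos]];
  return reordered;
}

int main() {
  // Move the last axis first: {i, j, k} -> {k, i, j}.
  auto out = applyNew2Old({"i", "j", "k"}, {2, 0, 1});
  assert(out[0] == "k" && out[1] == "i" && out[2] == "j");
  return 0;
}

Note that TensorDomain::reorder still performs this permutation after the patch; only the IR node that recorded it is removed.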
In the future diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index 4299d40a01473..2008a12fb3ff0 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -516,14 +516,6 @@ void IRPrinter::handle(const Merge* const m) { os << "\n"; } -void IRPrinter::handle(const Reorder* const ro) { - os << "Reorder: "; - handle(ro->in()); - os << " -> "; - handle(ro->out()); - os << "\n"; -} - namespace { struct ReductionOps : OptOutDispatch { diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.h b/torch/csrc/jit/codegen/cuda/ir_iostream.h index 0bd1cd81c0751..5034dbf58b502 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.h +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.h @@ -35,7 +35,6 @@ struct TensorContiguity; struct Split; struct Merge; -struct Reorder; struct Bool; struct Float; @@ -119,7 +118,6 @@ struct TORCH_CUDA_API IRPrinter : public OptInConstDispatch { virtual void handle(const Split* const) override; virtual void handle(const Merge* const) override; - virtual void handle(const Reorder* const) override; void print_inline(const Statement* const stmt) { bool prev = print_inline_; diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 3b71a432c5a55..345be55285ba2 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -531,7 +531,6 @@ TensorDomain* TensorDomain::reorder( [this](int i) -> IterDomain* { return this->axis(i); }); TensorDomain* reordered_td = new TensorDomain(reordered_domain); - new Reorder(reordered_td, this, new2old); return reordered_td; } @@ -617,32 +616,6 @@ bool Merge::sameAs(const Merge* const other) const { axis() == other->axis()); } -Reorder::Reorder( - TensorDomain* _out, - TensorDomain* _in, - std::vector _new2old) - : Expr(ExprType::Reorder), - out_(_out), - in_(_in), - new2old_(std::move(_new2old)) { - auto ndims = in_->nDims(); - TORCH_INTERNAL_ASSERT( - std::none_of( - _new2old.begin(), - _new2old.end(), - [ndims](int i) { return i < 0 || (unsigned int)i >= ndims; }), - "Invalid reorder node, axes found less than 0 or >= ndims."); - - addOutput(_out); - addInput(_in); - this->name_ = FusionGuard::getCurFusion()->registerExpr(this); -} - -bool Reorder::sameAs(const Reorder* const other) const { - // Implicitly in and out matching means new2old matches - return (out()->sameAs(other->out()) && in()->sameAs(other->in())); -} - ForLoop::ForLoop( Val* _index, IterDomain* _iter_domain, diff --git a/torch/csrc/jit/codegen/cuda/ir_printer.h b/torch/csrc/jit/codegen/cuda/ir_printer.h index d074784107b9f..5db8f3537138e 100644 --- a/torch/csrc/jit/codegen/cuda/ir_printer.h +++ b/torch/csrc/jit/codegen/cuda/ir_printer.h @@ -28,7 +28,6 @@ class TORCH_CUDA_API IRMathPrinter : public IRPrinter { void handle(const Split* const) override {} void handle(const Merge* const) override {} - void handle(const Reorder* const) override {} void handle(Fusion* f) override { IRPrinter::handle(f); diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp index ac84d11379325..9b792954cf737 100644 --- a/torch/csrc/jit/codegen/cuda/mutator.cpp +++ b/torch/csrc/jit/codegen/cuda/mutator.cpp @@ -156,17 +156,6 @@ Statement* OptOutMutator::mutate(Merge* m) { return new Merge(o, i, m->axis()); } -Statement* OptOutMutator::mutate(Reorder* ro) { - TensorDomain* o = static_cast(mutateAsVal(ro->out())); - TensorDomain* i = static_cast(mutateAsVal(ro->in())); - - if 
(o->sameAs(ro->out()) && i->sameAs(ro->in())) - return ro; - - FusionGuard::getCurFusion()->removeExpr(ro); - return new Reorder(o, i, ro->new2old()); -} - Statement* OptOutMutator::mutate(UnaryOp* uop) { Val* out = mutateAsVal(uop->out())->asVal(); Val* in = mutateAsVal(uop->in())->asVal(); diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.cpp b/torch/csrc/jit/codegen/cuda/transform_iter.cpp index a77626c195dfe..e30d6321ae294 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_iter.cpp @@ -15,12 +15,6 @@ TensorDomain* TransformIter::replayBackward(Merge* merge, TensorDomain* td) { return merge->in(); } -TensorDomain* TransformIter::replayBackward( - Reorder* reorder, - TensorDomain* td) { - return reorder->in(); -} - TensorDomain* TransformIter::replayBackward(Expr* expr, TensorDomain* td) { TORCH_INTERNAL_ASSERT( expr->isExpr(), @@ -30,8 +24,6 @@ TensorDomain* TransformIter::replayBackward(Expr* expr, TensorDomain* td) { return replayBackward(static_cast(expr), td); case (ExprType::Merge): return replayBackward(static_cast(expr), td); - case (ExprType::Reorder): - return replayBackward(static_cast(expr), td); default: TORCH_INTERNAL_ASSERT( false, "Could not detect expr type in replayBackward."); @@ -95,13 +87,6 @@ TensorDomain* TransformIter::replay(Merge* expr, TensorDomain* td) { return td->merge(expr->axis()); } -TensorDomain* TransformIter::replay(Reorder* expr, TensorDomain* td) { - std::unordered_map old2new; - for (decltype(expr->new2old().size()) i{0}; i < expr->new2old().size(); i++) - old2new[expr->new2old()[i]] = i; - return td->reorder(old2new); -} - TensorDomain* TransformIter::replay(Expr* expr, TensorDomain* td) { TORCH_INTERNAL_ASSERT(expr->isExpr()); switch (*(expr->getExprType())) { @@ -109,8 +94,6 @@ TensorDomain* TransformIter::replay(Expr* expr, TensorDomain* td) { return replay(static_cast(expr), td); case (ExprType::Merge): return replay(static_cast(expr), td); - case (ExprType::Reorder): - return replay(static_cast(expr), td); default: TORCH_INTERNAL_ASSERT(false, "Could not detect expr type in replay."); } @@ -173,25 +156,6 @@ struct Influence : public TransformIter { return merge->in(); } - TensorDomain* replayBackward(Reorder* reorder, TensorDomain* td) override { - // new2old[new_pos] = old_pos Generate new old2new map - const std::vector& new2old = reorder->new2old(); - - std::vector reorder_influence(influence.size(), false); - for (decltype(new2old.size()) i = 0; i < new2old.size(); i++) { - int new_pos = i; - int old_pos = new2old[i]; - TORCH_INTERNAL_ASSERT( - new_pos < (int)influence.size() && - old_pos < (int)reorder_influence.size(), - "Error during replay backwards, td/influence size mismatch."); - reorder_influence[old_pos] = influence[new_pos]; - } - - influence = reorder_influence; - return reorder->in(); - } - // FORWARD INFLUENCE TensorDomain* replay(Split* split, TensorDomain* td) override { @@ -213,25 +177,6 @@ struct Influence : public TransformIter { return nullptr; } - TensorDomain* replay(Reorder* reorder, TensorDomain* td) override { - // new2old[new_pos] = old_pos Generate new old2new map - const std::vector& new2old = reorder->new2old(); - - std::vector reorder_influence(influence.size(), false); - for (decltype(new2old.size()) i = 0; i < new2old.size(); i++) { - int new_pos = i; - int old_pos = new2old[i]; - TORCH_INTERNAL_ASSERT( - new_pos < (int)influence.size() && - old_pos < (int)reorder_influence.size(), - "Error during replay, td/influence size mismatch."); - 
reorder_influence[new_pos] = influence[old_pos]; - } - - influence = reorder_influence; - return nullptr; - } - // INTERFACE std::vector influence; @@ -356,163 +301,6 @@ struct Replay : public TransformIter { return td->merge(axis); } - // This transform requires reordering axes in td, then updating the axis_map - // We want to replay axes in td, not marked with -1, to match that in the - // provided reorder. This must be done because there may be a reorder that's - // required for a merge, as merge is specified by the first axes and merges - // the next consecutive axis. - // - // Once we transform td, we need to update axis_map or the mapping to provide: - // reorder->in()->axis(i) == reorder->axis(axis_map[i]) - // - // Axes not marked with -1 should be placed in the outer most dimensions in - // the relative order specified by reorder. Remaining axes should be placed in - // the inner most dimensions maintaining their original relative positioning. - TensorDomain* replay(Reorder* reorder, TensorDomain* td) override { - // convert to old2new as it makes this easier to do, and we need that map - // anyways in the end to replay reorder - const std::vector& new2old_orig = reorder->new2old(); - std::vector old2new_orig(new2old_orig.size()); - for (decltype(new2old_orig.size()) i{0}; i < new2old_orig.size(); i++) - old2new_orig[new2old_orig[i]] = i; - - // old2new_orig: reorder->in()->axis(i) == - // reorder->out()->axis(old2new_orig[i]) - - // We would like old2new: td->axis(i) == td_out->axis(old2new[i]) - // if td->axis(i) will participate in the reorder defined by "reorder" - auto extent = std::max(reorder->in()->nDims(), td->nDims()); - std::vector old2new(extent, -1); - for (decltype(old2new_orig.size()) i{0}; i < old2new_orig.size(); i++) { - int old_pos = axis_map[i]; - int new_pos = old2new_orig[i]; - if (old_pos != -1) - old2new[old_pos] = new_pos; - } - - // We want to push to the left the new positions in td_out therefore if our - // map looks like: - // - // Going to move all new_pos to the left so there's no gaps, for example if - // we have: old2new = 2 -1 4 -1 0 (0 -> 2, 2 -> 4, 4->0) we will the new - // positions down to: old2new = 1 -1 2 -1 0 (0 -> 1, 2 -> 2, 4->0) - // 0 -1 -1 3 -1 -1 6 - // 0 -1 -1 0 -1 -1 0 - // 0 -1 -2 -2 -3 -4 -4 - // offset[0] = 0 0->0 - // offset[3] = -2 3->1 - // offset[6] = -4 6->2 - // -------------------- - // -1 -1 -1 3 -1 -1 6 - // -1 -1 -1 0 -1 -1 0 - // -1 -2 -3 -3 -4 -5 -5 - // offset[3] = -3 3->0 - // offset[6] = -5 6->1 - - std::vector offset(old2new.size(), -1); - for (decltype(offset.size()) i{0}; i < offset.size(); i++) { - // If we process this axis - if (old2new[i] != -1) - // we wouldn't offset based on this value - offset[old2new[i]] = 0; - } - - // Prefix sum offset - for (decltype(offset.size()) i{1}; i < offset.size(); i++) { - offset[i] += offset[i - 1]; - } - // Offset is now computed - - // Apply offset - for (decltype(old2new.size()) i{0}; i < old2new.size(); i++) { - if (old2new[i] == -1) - continue; - old2new[i] += offset[old2new[i]]; - } - - /* - * old2new should now be the output of what we mention ^^ for offset, i.e. - * old2new = 2 -1 4 -1 0 (0 -> 2, 2 -> 4, 4->0) - * should now be: - * old2new = 1 -1 2 -1 0 (0 -> 1, 2->2, 4->0) - * OR: - * old2new = 1 -1 4 -1 -1 (0->1, 2->4) - * should now be: - * old2new = 0 -1 1 -1 -1 (0->0, 2->1) - * Now we want to fill in -1 positions in relative order, i.e. 
- * old2new = 1 -1 2 -1 0 (0 -> 1, 2->2, 4->0) - * we want as: - * old2new = 1 3 2 4 0 (0 -> 1, 1->3, 2->2, 3->4, 4->0) - * OR: - * old2new = 0 -1 1 -1 -1 (0->0, 2->1) - * we want as: - * old2new = 0 2 1 3 4 (0->0, 1->2, 2->1, 3->3, 4->4) - */ - // grab the highest index in new_pos - int max_new_pos = *std::max_element(old2new.begin(), old2new.end()); - // Fill in the -1 values in order - for (decltype(old2new.size()) i{0}; i < old2new.size(); i++) - if (old2new[i] == -1) - old2new[i] = ++max_new_pos; - old2new.erase(old2new.begin() + td->nDims(), old2new.end()); - - std::set missed_pos; - for (decltype(old2new.size()) i{0}; i < old2new.size(); i++) - missed_pos.emplace(i); - - for (decltype(old2new.size()) i{0}; i < old2new.size(); i++) { - TORCH_INTERNAL_ASSERT( - missed_pos.find(i) != missed_pos.end(), - "Duplicate entries in replayed reorder map."); - missed_pos.erase(i); - } - - TORCH_INTERNAL_ASSERT( - missed_pos.empty(), - "It's a real mystery how we ended up here. Congrats."); - - // Check if this is a null opt i.e. no actual reordering needs to be done - bool nullopt = true; - std::unordered_map old2new_map; - for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) { - if (old2new[i] != (int)i) { - nullopt = false; - } - old2new_map[i] = old2new[i]; - } - - // Even if null opt, I'm worried we could have a desynced axis_map as some - // how reorder wasn't a null opt, but after axis_map it was. I'm uncertain - // if this can happen but we can reorder axis_map anyways. - - // HAVE: - // td->axis(old2new[i]) == td_out->axis(i) - // reorder->in()->axis(old2new_orig[i]) = reorder->out()->axis(i) - // reorder->in()->axis(i) ~= td->axis(axis_map[i]) - // NEED: - // td_out->axis(reorder_axis_map[i]) ~= reorder->out()->axis(i) - decltype(axis_map) reordered_axis_map(axis_map.size(), -1); - for (decltype(axis_map.size()) i{0}; i < axis_map.size(); i++) { - int reorder_in_axis = i; - int td_axis = axis_map[i]; - if (td_axis == -1) - continue; - - int reorder_out_axis = old2new_orig[reorder_in_axis]; - int td_axis_out = old2new[td_axis]; - reordered_axis_map[reorder_out_axis] = td_axis_out; - } - - axis_map = reordered_axis_map; - - // If null opt do nothing, return td - if (nullopt) - return td; - - // Rerun reorder - return td->reorder(old2new_map); - } - std::vector axis_map; Replay(std::vector _axis_map) : axis_map(std::move(_axis_map)) {} @@ -679,155 +467,6 @@ struct ReplaySelf : public TransformIter { return replayed; } - // TODO: This is the same as Replay::replay, should work towards code reuse. - TensorDomain* replay(Reorder* reorder, TensorDomain* td) override { - // convert to old2new as it makes this easier to do, and we need that map - // anyways in the end to replay reorder - const std::vector& new2old_orig = reorder->new2old(); - std::vector old2new_orig(new2old_orig.size()); - for (decltype(new2old_orig.size()) i{0}; i < new2old_orig.size(); i++) - old2new_orig[new2old_orig[i]] = i; - - // old2new_orig: reorder->in()->axis(i) == - // reorder->out()->axis(old2new_orig[i]) - - // We would like old2new: td->axis(i) == td_out->axis(old2new[i]) - // if td->axis(i) will participate in the reorder defined by "reorder" - auto extent = reorder->in()->nDims() > td->nDims() ? 
reorder->in()->nDims() - : td->nDims(); - std::vector old2new(extent, -1); - for (decltype(old2new_orig.size()) i{0}; i < old2new_orig.size(); i++) { - int old_pos = axis_map[i]; - int new_pos = old2new_orig[i]; - if (old_pos != -1) - old2new[old_pos] = new_pos; - } - - // We want to push to the left the new positions in td_out therefore if our - // map looks like: - // - // Going to move all new_pos to the left so there's no gaps, for example if - // we have: old2new = 2 -1 4 -1 0 (0 -> 2, 2 -> 4, 4->0) we will the new - // positions down to: old2new = 1 -1 2 -1 0 (0 -> 1, 2 -> 2, 4->0) - // 0 -1 -1 3 -1 -1 6 - // 0 -1 -1 0 -1 -1 0 - // 0 -1 -2 -2 -3 -4 -4 - // offset[0] = 0 0->0 - // offset[3] = -2 3->1 - // offset[6] = -4 6->2 - // -------------------- - // -1 -1 -1 3 -1 -1 6 - // -1 -1 -1 0 -1 -1 0 - // -1 -2 -3 -3 -4 -5 -5 - // offset[3] = -3 3->0 - // offset[6] = -5 6->1 - - std::vector offset(old2new.size(), -1); - for (decltype(offset.size()) i{0}; i < offset.size(); i++) { - // If we process this axis - if (old2new[i] != -1) - // we wouldn't offset based on this value - offset[old2new[i]] = 0; - } - - // Prefix sum offset - for (decltype(offset.size()) i{1}; i < offset.size(); i++) { - offset[i] = offset[i] + offset[i - 1]; - } - // Offset is now computed - - // Apply offset - for (decltype(old2new.size()) i{0}; i < old2new.size(); i++) { - if (old2new[i] == -1) - continue; - old2new[i] += offset[old2new[i]]; - } - - /* - * old2new should now be the output of what we mention ^^ for offset, i.e. - * old2new = 2 -1 4 -1 0 (0 -> 2, 2 -> 4, 4->0) - * should now be: - * old2new = 1 -1 2 -1 0 (0 -> 1, 2->2, 4->0) - * OR: - * old2new = 1 -1 4 -1 -1 (0->1, 2->4) - * should now be: - * old2new = 0 -1 1 -1 -1 (0->0, 2->1) - * Now we want to fill in -1 positions in relative order, i.e. - * old2new = 1 -1 2 -1 0 (0 -> 1, 2->2, 4->0) - * we want as: - * old2new = 1 3 2 4 0 (0 -> 1, 1->3, 2->2, 3->4, 4->0) - * OR: - * old2new = 0 -1 1 -1 -1 (0->0, 2->1) - * we want as: - * old2new = 0 2 1 3 4 (0->0, 1->2, 2->1, 3->3, 4->4) - */ - // grab the highest index in new_pos - int max_new_pos = -1; - for (decltype(old2new.size()) i{0}; i < old2new.size(); i++) - max_new_pos = max_new_pos > old2new[i] ? max_new_pos : old2new[i]; - // Fill in the -1 values in order - for (decltype(old2new.size()) i{0}; i < old2new.size(); i++) - if (old2new[i] == -1) - old2new[i] = ++max_new_pos; - old2new.erase(old2new.begin() + td->nDims(), old2new.end()); - - std::set missed_pos; - for (decltype(old2new.size()) i{0}; i < old2new.size(); i++) - missed_pos.emplace(i); - - for (decltype(old2new.size()) i{0}; i < old2new.size(); i++) { - TORCH_INTERNAL_ASSERT( - missed_pos.find(i) != missed_pos.end(), - "Duplicate entries in replayed reorder map."); - missed_pos.erase(i); - } - - TORCH_INTERNAL_ASSERT( - missed_pos.empty(), - "It's a real mystery how we ended up here. Congrats."); - - // Check if this is a null opt i.e. no actual reordering needs to be done - bool nullopt = true; - std::unordered_map old2new_map; - for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) { - if (old2new[i] != (int)i) { - nullopt = false; - } - old2new_map[i] = old2new[i]; - } - - // Even if null opt, I'm worried we could have a desynced axis_map as some - // how reorder wasn't a null opt, but after axis_map it was. I'm uncertain - // if this can happen but we can reorder axis_map anyways. 
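The comment blocks above (duplicated between Replay and ReplaySelf, as the TODO notes) describe the same compaction trick twice. A hedged, runnable distillation of it, assuming plain std::vector<int> stands in for the axis maps (compactOld2New is an illustrative name):

#include <algorithm>
#include <cassert>
#include <vector>

std::vector<int> compactOld2New(std::vector<int> old2new) {
  // Mark which new positions are actually used.
  std::vector<int> offset(old2new.size(), -1);
  for (int v : old2new)
    if (v != -1)
      offset[v] = 0;
  // Prefix sum: offset[p] becomes minus the number of unused new positions
  // at or before p.
  for (size_t i = 1; i < offset.size(); i++)
    offset[i] += offset[i - 1];
  // Slide every used position left past the gaps.
  for (int& v : old2new)
    if (v != -1)
      v += offset[v];
  // Fill the remaining (-1) entries in relative order.
  int next = 0;
  for (int v : old2new)
    next = std::max(next, v + 1);
  for (int& v : old2new)
    if (v == -1)
      v = next++;
  return old2new;
}

int main() {
  // Worked example from the comment: {2, -1, 4, -1, 0} first compacts to
  // {1, -1, 2, -1, 0}, then the gaps fill to give {1, 3, 2, 4, 0}.
  assert(compactOld2New({2, -1, 4, -1, 0}) == (std::vector<int>{1, 3, 2, 4, 0}));
  return 0;
}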
- - // HAVE: - // td->axis(old2new[i]) == td_out->axis(i) - // reorder->in()->axis(old2new_orig[i]) = reorder->out()->axis(i) - // reorder->in()->axis(i) ~= td->axis(axis_map[i]) - // NEED: - // td_out->axis(reorder_axis_map[i]) ~= reorder->out()->axis(i) - decltype(axis_map) reordered_axis_map(axis_map.size(), -1); - for (decltype(axis_map.size()) i{0}; i < axis_map.size(); i++) { - int reorder_in_axis = i; - int td_axis = axis_map[i]; - if (td_axis == -1) - continue; - - int reorder_out_axis = old2new_orig[reorder_in_axis]; - int td_axis_out = old2new[td_axis]; - reordered_axis_map[reorder_out_axis] = td_axis_out; - } - - axis_map = reordered_axis_map; - - // If null opt do nothing, return td - if (nullopt) - return td; - - // Rerun reorder - return td->reorder(old2new_map); - } - std::vector axis_map; ReplaySelf(std::vector _axis_map) : axis_map(std::move(_axis_map)) {} @@ -968,96 +607,6 @@ struct TransformBackward : public TransformIter { return replayed_inp; } - TensorDomain* replayBackward(Reorder* reorder, TensorDomain* td) override { - const std::vector& new2old_orig = reorder->new2old(); - - // We want to convert new2old to something with td->nDims which it isn't - // guarenteed to be - std::vector new2old(td->nDims(), -1); - - for (decltype(new2old_orig.size()) i{0}; i < new2old_orig.size(); i++) { - int new_pos = axis_map[i]; // position in td - int old_pos = new2old_orig[i]; // position it should be at before td - - if (new_pos != -1) - new2old[new_pos] = old_pos; - } - - // We want to push to the RIGHT the modified positions in td_in. This is - // in comparison with forward replay which moves modified positions to the - // left. - - // Easiest to start by moving to left like forward replay - - std::vector new2old_offset(new2old_orig.size(), -1); - // Create offset map - for (decltype(new2old.size()) i{0}; i < new2old.size(); i++) - if (new2old[i] != -1) - new2old_offset[new2old[i]] = 0; - - // Prefix sum new2old_offset - for (decltype(new2old_offset.size()) i{1}; i < new2old_offset.size(); i++) - new2old_offset[i] += new2old_offset[i - 1]; - // Apply offset - for (decltype(new2old.size()) i{0}; i < new2old.size(); i++) { - if (new2old[i] == -1) - continue; - new2old[i] += new2old_offset[new2old[i]]; - } - - int max_elem = *std::max_element(new2old.begin(), new2old.end()); - // Now lets push all elements to the right - int right_offset = ((int)td->nDims()) - max_elem - 1; - TORCH_INTERNAL_ASSERT( - right_offset >= 0, - "Error during backward replay, couldn't move modified axes to the right in reorder."); - - // Move to the right - for (decltype(new2old.size()) i{0}; i < new2old.size(); i++) { - if (new2old[i] == -1) - continue; - new2old[i] += right_offset; - } - - // Fill in unmodified positions in order to the left - int it = 0; - for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) - if (new2old[i] == -1) - new2old[i] = it++; - - // Trim new2old to match td - new2old.erase(new2old.begin() + td->nDims(), new2old.end()); - - // new2old_orig[reorder->out()->pos] = reorder->in()->pos - // axis_map[reorder->out()->pos] = td->pos - // new2old[td->pos] = old_td->pos - // NEED: new_axis_map[reorder->in()->pos] = old_td->pos - - std::vector new_axis_map(axis_map.size(), -1); - for (decltype(new_axis_map.size()) i{0}; i < new_axis_map.size(); i++) { - int reorder_out_pos = i; - int reorder_in_pos = new2old_orig[reorder_out_pos]; - int td_pos = axis_map[reorder_out_pos]; - int old_td_pos = td_pos == -1 ? 
-1 : new2old[td_pos];
-
-      new_axis_map[reorder_in_pos] = old_td_pos;
-    }
-
-    axis_map = new_axis_map;
-
-    std::vector<IterDomain*> old_td(td->nDims(), nullptr);
-    for (decltype(new2old.size()) i{0}; i < new2old.size(); i++) {
-      // new2old[new] = old relative to td
-      int new_pos = i; // position in td
-      int old_pos = new2old[i]; // position it should be at before td
-      old_td[old_pos] = td->axis(new_pos);
-    }
-
-    TensorDomain* replayed_inp = new TensorDomain(old_td);
-    new Reorder(td, replayed_inp, new2old);
-    return replayed_inp;
-  }
-
   // Entry for backward influence propagation on td following record, history
   // should be present -> past as you go through the vector
   TensorDomain* replayBackward(
@@ -1101,10 +650,6 @@ struct RFactorRoot : public TransformIter {
     return merge->out();
   }
 
-  TensorDomain* replay(Reorder* reorder, TensorDomain*) final {
-    return reorder->out();
-  }
-
   // Replay forward until we hit an operation that doesn't involve an rfactor
   // axis
   TensorDomain* runReplay(TensorDomain*, const std::vector<Expr*>& history)
@@ -1120,9 +665,6 @@ struct RFactorRoot : public TransformIter {
       if (found_non_rfactor_op)
         break;
       running_op = it;
-      if ((*it)->getExprType() != ExprType::Reorder) {
-        last_rfactor_op = it;
-      }
     }
 
   // We need to make sure the rfactor root is ordered correctly.
diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.h b/torch/csrc/jit/codegen/cuda/transform_iter.h
index 4de415447d96e..23d456e13b2df 100644
--- a/torch/csrc/jit/codegen/cuda/transform_iter.h
+++ b/torch/csrc/jit/codegen/cuda/transform_iter.h
@@ -12,7 +12,7 @@ namespace jit {
 namespace fuser {
 
 /*
- * TransformIter iterates on the split/merge/reorder graph of TensorDomain
+ * TransformIter iterates on the split/merge graph of TensorDomain
 *
 * Running backward will execute these Exprs in reverse order. If you record
 * these events (generate_record=true) you can then replay them on another
@@ -22,7 +22,6 @@ struct TORCH_CUDA_API TransformIter : public IterVisitor {
 protected:
  virtual TensorDomain* replayBackward(Split*, TensorDomain*);
  virtual TensorDomain* replayBackward(Merge*, TensorDomain*);
-  virtual TensorDomain* replayBackward(Reorder*, TensorDomain*);
 
  // dispatch
  TensorDomain* replayBackward(Expr*, TensorDomain*);
@@ -35,7 +34,6 @@ struct TORCH_CUDA_API TransformIter : public IterVisitor {
 
  virtual TensorDomain* replay(Split*, TensorDomain*);
  virtual TensorDomain* replay(Merge*, TensorDomain*);
-  virtual TensorDomain* replay(Reorder*, TensorDomain*);
 
  // dispatch
  virtual TensorDomain* replay(Expr*, TensorDomain*);
diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
index 78912a604635d..909b462e9c9fa 100644
--- a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
+++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
@@ -79,8 +79,8 @@ TensorDomain* TransformRFactor::runReplay(
   // running_td has iteration domains on the right, but to find a valid rfactor
   // root, we want those to be on the right. If we continued to replay backward
-  // we likely won't have a valid rfactor root. Lets manually insert a reorder
-  // so we have a valid rfactor root.
+  // we likely won't have a valid rfactor root. Let's manually reorder the
+  // axes so we have a valid rfactor root.
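The context lines that follow build such a permutation over running_td. As a rough, hypothetical illustration of the idea (plain std::vector, no IR types, and not the exact code), a stable partition that leaves rfactor-product axes on the right can be written as a new2old map:

#include <vector>

// new2old[new_pos] = old_pos; non-rfactor axes keep their relative order on
// the left, rfactor-product axes keep theirs on the right.
std::vector<int> rfactorAxesLast(const std::vector<bool>& is_rfactor) {
  std::vector<int> new2old;
  new2old.reserve(is_rfactor.size());
  for (int i = 0; i < (int)is_rfactor.size(); i++)
    if (!is_rfactor[i])
      new2old.push_back(i);
  for (int i = 0; i < (int)is_rfactor.size(); i++)
    if (is_rfactor[i])
      new2old.push_back(i);
  return new2old;
}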
   std::vector<int> new2old(running_td->nDims());
   {
@@ -93,15 +93,6 @@ TensorDomain* TransformRFactor::runReplay(
       if (running_td->axis(i)->isRFactorProduct())
         new2old[i] = running_pos++;
   }
-  std::vector<int> reorder_axis_map(running_td->nDims());
-  std::iota(reorder_axis_map.begin(), reorder_axis_map.end(), 0);
-
-  running_td = TransformIter::replayBackward(
-      running_td,
-      // include dummy reorder
-      {new Reorder(
-          new TensorDomain(running_td->domain()), running_td, new2old)},
-      reorder_axis_map);
 
   // how do we find axes
   // Need axis map from rfactor axes in running_td to corresponding axes in
diff --git a/torch/csrc/jit/codegen/cuda/type.cpp b/torch/csrc/jit/codegen/cuda/type.cpp
index e81fcbecb4a35..1166058209824 100644
--- a/torch/csrc/jit/codegen/cuda/type.cpp
+++ b/torch/csrc/jit/codegen/cuda/type.cpp
@@ -64,7 +64,8 @@ static _enum_unordered_map<ExprType, std::string> expr_type_string_map{
     {ExprType::Allocate, "Allocate"},
     {ExprType::Split, "Split"},
     {ExprType::Merge, "Merge"},
-    {ExprType::Reorder, "Reorder"}};
+};
+
 static _enum_unordered_map<UnaryOpType, std::string> unary_op_type_string_map{
     {UnaryOpType::Abs, "fabs"},
     {UnaryOpType::Acos, "acosf"},
diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h
index 1dfdaa7961dfc..760ff58d73479 100644
--- a/torch/csrc/jit/codegen/cuda/type.h
+++ b/torch/csrc/jit/codegen/cuda/type.h
@@ -35,8 +35,7 @@ enum class ExprType {
   IfThenElse,
   Allocate,
   Split,
-  Merge,
-  Reorder
+  Merge
 };
 
 enum class UnaryOpType {

From a3f34d7bf84731e3b5ec6bfc6c7e3a995fe42759 Mon Sep 17 00:00:00 2001
From: Christian Sarofeen
Date: Sat, 23 May 2020 14:39:30 -0400
Subject: [PATCH 2/8] Remove incorrect references to reorder.

---
 torch/csrc/jit/codegen/cuda/arith.cpp            | 2 +-
 torch/csrc/jit/codegen/cuda/index_compute.h      | 4 ++--
 torch/csrc/jit/codegen/cuda/ir_interface_nodes.h | 2 +-
 torch/csrc/jit/codegen/cuda/ir_internal_nodes.h  | 4 ++--
 torch/csrc/jit/codegen/cuda/ir_printer.h         | 4 ++--
 torch/csrc/jit/codegen/cuda/lower2device.cpp     | 2 +-
 6 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp
index 22574abc0428c..dc21dbdda1ab1 100644
--- a/torch/csrc/jit/codegen/cuda/arith.cpp
+++ b/torch/csrc/jit/codegen/cuda/arith.cpp
@@ -282,7 +282,7 @@ Val* reductionOp(
 
   TORCH_CHECK(
       tv->getRootDomain() == tv->domain(),
-      "Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/reorder/computeAt.");
+      "Reducing a tensor once it has undergone transformations is not permitted at this time. Please set reductions before calling split/merge/computeAt.");
 
   std::vector<unsigned int> uint_axes;
   for (int axis : axes) {
diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h
index ecbeb2eb25ee8..c237c837aa125 100644
--- a/torch/csrc/jit/codegen/cuda/index_compute.h
+++ b/torch/csrc/jit/codegen/cuda/index_compute.h
@@ -8,8 +8,8 @@
  * Index compute takes in a list of indices typically generated from the
  * surrounding for loop nest. The number of indices is intended to match the
  * number of dimensions of the incoming TensorView which may have less or more
- * dimensions than its root due to split/merge/reorder operations.
- * Split/merge/reorder operations are then replayed backwards produce resulting
+ * dimensions than its root due to split/merge operations.
+ * Split/merge operations are then replayed backwards to produce the resulting
  * indices (based on input indices) that match the root dimension.
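Concretely, the backward replay described here is plain integer arithmetic. A hedged sketch with ints instead of IR nodes, mirroring the replayBackward logic PATCH 1 left in index_compute.cpp: folding a Split multiplies the outer index by the split factor and adds the inner index; folding a Merge divides and takes the remainder by the inner extent.

#include <cassert>

int main() {
  // Backward through Split(axis i, factor 4): indices {outer=3, inner=2}
  // fold back to the root index 3 * 4 + 2 = 14.
  int outer = 3, inner = 2, factor = 4;
  int root_idx = outer * factor + inner;
  assert(root_idx == 14);

  // Backward through Merge (inner extent 5): the merged index 14 unfolds to
  // {outer = 14 / 5 = 2, inner = 14 % 5 = 4}.
  int merged = 14, inner_extent = 5;
  assert(merged / inner_extent == 2 && merged % inner_extent == 4);
  return 0;
}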
 *
 * For example with GLOBAL tensor:
diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
index e461340a286ca..80a9517803cc2 100644
--- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -160,7 +160,7 @@ struct GPULower;
 /*
  * TensorView is our primitive Tensor Type used in code generation. It can be
  * thought of as representing physical memory, however, its dimensionality is
- * modifed as split/merge/reorder/computeAt functions are called. The history of
+ * modified as split/merge/computeAt functions are called. The history of
 * these transformations is kept and used for generating actual code referencing
 * physical memory. Generally when users are thinking of code generation in
 * reference to a Tensor, this is the class they should be interacting with.
diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
index 95774cba4ae47..af4b6feae533c 100644
--- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -28,7 +28,7 @@ namespace fuser {
 * 1) Casting operation i.e. float(a_val)
 * 2) Negation i.e. val * -1
 * 3) Reduction across a dimension i.e. val.sum(axis=2)
- * 4) split/merge/reorder
+ * 4) split/merge
 */
struct TORCH_CUDA_API UnaryOp : public Expr {
  ~UnaryOp() = default;
@@ -340,7 +340,7 @@ struct TORCH_CUDA_API IterDomain : public Val {
 * This is done through the normal interaction of Expr/Val in Fusion. i.e. if we
 * want to know the previous operation generating a particular TensorDomain we
 * can simply call FusionGuard::getCurFusion()->origin(a_tensor_domain) which
- * should give us an operation in the list [split, merge, reorder] or similar
+ * should give us an operation in the list [split, merge] or similar
 * operations that take in a TensorDomain, apply a transformation, and output
 * a tensor domain.
 */
diff --git a/torch/csrc/jit/codegen/cuda/ir_printer.h b/torch/csrc/jit/codegen/cuda/ir_printer.h
index 5db8f3537138e..60c21332edb7e 100644
--- a/torch/csrc/jit/codegen/cuda/ir_printer.h
+++ b/torch/csrc/jit/codegen/cuda/ir_printer.h
@@ -10,10 +10,10 @@
 * IRMathPrinter and IRTransformPrinter allow the splitting up of fusion print
 * functions. IRMathPrinter as its name implies focuses solely on what tensor
 * computations are taking place. Resulting TensorView math will reflect the
- * series of split/merge/reorder/computeAts that have taken place, however these
+ * series of split/merge/computeAts that have taken place, however these
 * nodes will not be displayed in what is printed. IRTransformPrinter does not
 * print any mathematical functions and only lists the series of
- * split/merge/reorder calls that were made. Both of these printing methods are
+ * split/merge calls that were made. Both of these printing methods are
 * quite verbose on purpose so as to show accurately what is represented in the
 * IR of a fusion.
 */
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp
index ff3e8386973b4..e33bdae37996d 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -260,7 +260,7 @@ void GPULower::replaceSizes() {
   // view. For example T0 may be translated to T9. We don't want our new
   // variable to be T0->size[...], we need it to be T9->size[...]
// - // This could be done in a better way but changing split/merge/reorder to be a + // This could be done in a better way but changing split/merge to be a // TensorDomain focused operation, then we could simply do this process on // domains, instead of tensorviews. This would have the benefit that the // TensorView wouldn't change, so users pointers will remain valid. The other From a21b41ada02ffb77090b1dd5c8af073fd08e71e1 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 23 May 2020 14:58:06 -0400 Subject: [PATCH 3/8] Disconnect the world, change merge and split to operate on IterDomain. --- torch/csrc/jit/codegen/cuda/index_compute.cpp | 763 ++++---- torch/csrc/jit/codegen/cuda/index_compute.h | 56 +- .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 47 +- torch/csrc/jit/codegen/cuda/ir_iostream.cpp | 13 +- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 40 +- torch/csrc/jit/codegen/cuda/mutator.cpp | 19 +- .../csrc/jit/codegen/cuda/transform_iter.cpp | 1608 +++++++++-------- torch/csrc/jit/codegen/cuda/transform_iter.h | 73 +- .../jit/codegen/cuda/transform_replay.cpp | 706 ++++---- .../csrc/jit/codegen/cuda/transform_replay.h | 20 +- 10 files changed, 1725 insertions(+), 1620 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 631b9432b2817..4db45509439af 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1,375 +1,388 @@ -#include -#include -#include -#include - -namespace torch { -namespace jit { -namespace fuser { - -TensorDomain* IndexCompute::replayBackward(Split* split, TensorDomain*) { - int ax = split->axis(); - TORCH_INTERNAL_ASSERT( - ax >= 0 && (unsigned int)(ax + 1) < indices.size(), - "Hit an invalid Split transformation during IndexCompute, axis is not within bounds."); - indices[ax] = add(mul(indices[ax], split->factor()), indices[ax + 1]); - indices.erase(indices.begin() + ax + 1); - return split->in(); -} - -TensorDomain* IndexCompute::replayBackward(Merge* merge, TensorDomain*) { - int ax = merge->axis(); - TORCH_INTERNAL_ASSERT( - ax >= 0 && (unsigned int)ax < indices.size(), - "Hit an invalid MERGE transformation during IndexCompute, axis is not within bounds."); - - Val* I = merge->in()->axis(ax + 1)->extent(); - Val* ind = indices[ax]; - indices[ax] = div(ind, I); - indices.insert(indices.begin() + ax + 1, mod(ind, I)); - return merge->in(); -} - -TensorDomain* IndexCompute::runBackward(std::vector history) { - TensorDomain* running_td = nullptr; - for (auto it = history.rbegin(); it != history.rend(); it++) - running_td = TransformIter::replayBackward(*it, running_td); - - return running_td; -} - -IndexCompute::IndexCompute(TensorDomain* td, std::vector _indices) - : indices(std::move(_indices)) { - bool exclude_reduction = td->nDims() > indices.size(); - - TORCH_INTERNAL_ASSERT( - td->noReductions().size() == indices.size() || - td->nDims() == indices.size(), - "For IndexCompute the number of axes should match the number of dimensions" - " in the TensorDomain."); - - // If we need to ignore the reduction dimensions because a tensor is - // being consumed, not produced, then insert dummy dimensions in the - // indices for bookkeeping while replaying split/merge operations. 
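A hypothetical standalone rendering of the bookkeeping described in the comment above, with plain ints standing in for the Int(-1) placeholder nodes (padReductionAxes is an illustrative name, not codegen API):

#include <cassert>
#include <vector>

// Splice a dummy index in at each reduction axis so the index list lines up
// with every axis of the TensorDomain.
std::vector<int> padReductionAxes(
    std::vector<int> indices, // one index per non-reduction axis
    const std::vector<bool>& is_reduction) { // one flag per axis
  for (size_t i = 0; i < is_reduction.size(); i++)
    if (is_reduction[i])
      indices.insert(indices.begin() + i, -1); // dummy index, never read
  return indices;
}

int main() {
  // Axes {i, r, j} with r a reduction: indices {3, 7} become {3, -1, 7}.
  auto padded = padReductionAxes({3, 7}, {false, true, false});
  assert(padded.size() == 3 && padded[1] == -1);
  return 0;
}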
- if (exclude_reduction) - for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) - if (td->axis(i)->isReduction()) - indices.insert(indices.begin() + i, new Int(-1)); - - TORCH_INTERNAL_ASSERT( - indices.size() == td->nDims(), - "Attempted to modify indices for IndexCompute, but didn't work."); - - // Run the split/merge operations backwards. This will - // Modify std::vector indices so it can be used to index - // the root TensorDomain which should now match the physical axes. - TensorDomain* root = TransformIter::getRoot(td); - auto history = TransformIter::getHistory(td); - if (exclude_reduction && td->hasRFactor()) { - root = TransformIter::getRFactorRoot(td); - auto rfactor_history = TransformIter::getHistory(root); - history.erase(history.begin(), history.begin() + rfactor_history.size()); - } - - runBackward(history); - - TORCH_INTERNAL_ASSERT( - root->nDims() == indices.size(), - "Error during IndexCompute. The number of indices generated" - " after running the transformations backwards should match" - " the number of dimensions of the root TensorDomain."); -} - -std::vector IndexCompute::get( - TensorDomain* td, - const std::vector& _indices) { - IndexCompute ic(td, _indices); - return ic.indices; -} - -TensorIndex* Index::getGlobalProducerIndex( - TensorView* producer, - TensorView* consumer, - const std::vector& loops) { - // This replay will ignore reduction dimensions on the producer - auto pind = - TransformReplay::replayPasC(producer->domain(), consumer->domain(), -1); - - TORCH_INTERNAL_ASSERT( - loops.size() == consumer->nDims(), - "Dimensionality error in code generator while computing tensor indexes."); - - std::vector loops_adjusted; - size_t it_c = 0, it_p = 0; - while (it_c < consumer->nDims() && it_p < pind->noReductions().size()) { - if (consumer->axis(it_c)->isBroadcast() && - !pind->noReductions()[it_p]->isBroadcast()) { - it_c++; - } else { - loops_adjusted.push_back(loops[it_c]); - it_c++; - it_p++; - } - } - - TORCH_INTERNAL_ASSERT( - loops_adjusted.size() == pind->noReductions().size(), - "Dimensionality error in code generator while computing tensor indexes."); - - std::vector indices(loops_adjusted.size()); - std::transform( - loops_adjusted.begin(), - loops_adjusted.end(), - indices.begin(), - [](ForLoop* fl) { return fl->index(); }); - std::vector computed_inds = IndexCompute::get(pind, indices); - - auto root_domain = producer->getRootDomain(); - - TORCH_INTERNAL_ASSERT( - computed_inds.size() == root_domain->nDims(), - "Dimensionality error in code generator while computing indexing."); - - for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) { - if (root_domain->axis(i)->isReduction() || - root_domain->axis(i)->isBroadcast()) - computed_inds.erase(computed_inds.begin() + i); - } - - std::vector strided_inds; - for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) { - std::stringstream ss; - ss << "T" << producer->name() << ".stride[" << i << "]"; - strided_inds.push_back( - mul(computed_inds[i], new NamedScalar(ss.str(), DataType::Int))); - } - - // Probably shouldn't ever hit this - if (strided_inds.size() == 0) - strided_inds.push_back(new Int(0)); - - return new TensorIndex(producer, strided_inds); -} - -// Producer index for either shared or local memory -TensorIndex* Index::getProducerIndex_impl( - TensorView* producer, - TensorView* consumer, - const std::vector& loops) { - TORCH_INTERNAL_ASSERT( - loops.size() == consumer->nDims(), - "Dimensionality error in code generator while computing tensor indexes."); 
- - std::vector loops_adjusted; - size_t it_c = 0, it_p = 0; - while (it_c < consumer->nDims() && it_p < producer->nDims()) { - if (consumer->axis(it_c)->isBroadcast() && - !producer->axis(it_p)->isBroadcast()) { - it_c++; - } else { - loops_adjusted.push_back(loops[it_c]); - it_c++; - it_p++; - } - } - - TORCH_INTERNAL_ASSERT( - loops_adjusted.size() == producer->domain()->noReductions().size(), - "Expected a tensor with ", - loops_adjusted.size(), - " dimensions but got one with ", - producer->nDims()); - - std::vector ranges(loops_adjusted.size()); - std::transform( - loops_adjusted.begin(), - loops_adjusted.end(), - ranges.begin(), - [](ForLoop* fl) { return fl->iter_domain(); }); - - std::vector indices(loops_adjusted.size()); - std::transform( - loops_adjusted.begin(), - loops_adjusted.end(), - indices.begin(), - [](ForLoop* fl) { - return fl->iter_domain()->isBroadcast() ? new Int(0) : fl->index(); - }); - - std::vector used_inds; - std::vector used_ranges; - bool unrolled = false; - for (decltype(loops_adjusted.size()) i{0}; i < loops_adjusted.size(); i++) { - if (ranges[i]->parallel_method() == ParallelType::Unroll) - unrolled = true; - if (!unrolled && producer->hasComputeAt() && - i < producer->getThisComputeAtAxis()) - continue; - if (producer->getMemoryType() == MemoryType::Shared && - ranges[i]->isBlockDim()) - continue; - if (producer->getMemoryType() == MemoryType::Local && ranges[i]->isThread()) - continue; - if (ranges[i]->isBroadcast()) - continue; - - used_inds.push_back(indices[i]); - used_ranges.push_back(ranges[i]); - } - - for (decltype(used_inds.size()) i{0}; i < used_inds.size(); i++) { - Val* ind = used_inds[i]; - for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++) - ind = mul(ind, used_ranges[j]->extent()); - used_inds[i] = ind; - } - if (used_inds.size() == 0) - used_inds.push_back(new Int(0)); - - return new TensorIndex(producer, used_inds); -} - -TensorIndex* Index::getGlobalConsumerIndex( - TensorView* consumer, - const std::vector& loops) { - // If we're initializing a reduction buffer, we won't have the reduction - // loops. If we're actually performing the reduction, we will. 
- - std::vector indices(loops.size()); - std::transform(loops.begin(), loops.end(), indices.begin(), [](ForLoop* fl) { - return fl->index(); - }); - - std::vector computed_inds = - IndexCompute::get(consumer->domain(), indices); - - TensorDomain* root_dom = consumer->getRootDomain(); - TORCH_INTERNAL_ASSERT( - computed_inds.size() == root_dom->nDims(), - "Dimensionality error in code generator while computing indexing."); - - for (decltype(root_dom->nDims()) i{0}; i < root_dom->nDims(); i++) { - // Do this backwards so erase offset will be right - auto axis = root_dom->nDims() - i - 1; - if (root_dom->axis(axis)->isReduction() || root_dom->axis(i)->isBroadcast()) - computed_inds.erase(computed_inds.begin() + axis); - } - - std::vector strided_inds; - for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) { - std::stringstream ss; - ss << "T" << consumer->name() << ".stride[" << i << "]"; - strided_inds.push_back( - mul(computed_inds[i], new NamedScalar(ss.str(), DataType::Int))); - } - - // Probably shouldn't ever hit this - if (strided_inds.size() == 0) - strided_inds.push_back(new Int(0)); - - return new TensorIndex(consumer, strided_inds); -} - -// Consumer index for either shared or local memory -TensorIndex* Index::getConsumerIndex_impl( - TensorView* consumer, - const std::vector& loops) { - // If we're initializing a reduction buffer, we won't have the reduction - // loops. If we're actually performing the reduction, we will. - - bool have_reduction_iters = loops.size() == consumer->nDims(); - - if (!have_reduction_iters) { - TORCH_INTERNAL_ASSERT( - // Init reduction space - loops.size() == consumer->domain()->noReductions().size(), - "Expected a tensor with ", - loops.size(), - " dimensions but got one with ", - consumer->domain()->noReductions().size()); - } else { - TORCH_INTERNAL_ASSERT( - // Calling the reduction op - loops.size() == consumer->nDims(), - "Expected a tensor with ", - loops.size(), - " dimensions but got one with ", - consumer->nDims()); - } - - std::vector ranges(loops.size()); - std::transform(loops.begin(), loops.end(), ranges.begin(), [](ForLoop* fl) { - return fl->iter_domain(); - }); - - std::vector indices(loops.size()); - std::transform(loops.begin(), loops.end(), indices.begin(), [](ForLoop* fl) { - return fl->iter_domain()->isBroadcast() ? 
new Int(0) : fl->index(); - }); - - std::vector used_inds; - std::vector used_ranges; - bool unrolled = false; - for (decltype(loops.size()) i{0}; i < loops.size(); i++) { - if (have_reduction_iters && consumer->axis(i)->isReduction()) - continue; - if (ranges[i]->parallel_method() == ParallelType::Unroll) - unrolled = true; - if (!unrolled && consumer->hasComputeAt() && - i < consumer->getThisComputeAtAxis()) - continue; - if (consumer->getMemoryType() == MemoryType::Shared && - ranges[i]->isBlockDim()) - continue; - if (consumer->getMemoryType() == MemoryType::Local && ranges[i]->isThread()) - continue; - if (ranges[i]->isBroadcast()) - continue; - - used_inds.push_back(indices[i]); - used_ranges.push_back(ranges[i]); - } - - for (decltype(used_inds.size()) i{0}; i < used_inds.size(); i++) { - Val* ind = used_inds[i]; - for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++) - ind = mul(ind, used_ranges[j]->extent()); - used_inds[i] = ind; - } - - if (used_inds.size() == 0) - used_inds.push_back(new Int(0)); - - return new TensorIndex(consumer, used_inds); -} - -// Producer is the inputs of an expression -TensorIndex* Index::getProducerIndex( - TensorView* producer, - TensorView* consumer, - const std::vector& loops) { - TORCH_INTERNAL_ASSERT( - loops.size() == consumer->nDims() || - loops.size() == consumer->domain()->noReductions().size()); - - if (producer->getMemoryType() == MemoryType::Global) - return getGlobalProducerIndex(producer, consumer, loops); - return getProducerIndex_impl(producer, consumer, loops); -} - -// Consumer is the output of an expression -TensorIndex* Index::getConsumerIndex( - TensorView* consumer, - const std::vector& loops) { - TORCH_INTERNAL_ASSERT( - loops.size() == consumer->nDims() || - loops.size() == consumer->domain()->noReductions().size()); - - if (consumer->getMemoryType() == MemoryType::Global) - return getGlobalConsumerIndex(consumer, loops); - return getConsumerIndex_impl(consumer, loops); -} - -} // namespace fuser -} // namespace jit -} // namespace torch +// #include +// #include +// #include +// #include + +// namespace torch { +// namespace jit { +// namespace fuser { + +// TensorDomain* IndexCompute::replayBackward(Split* split, TensorDomain*) { +// int ax = split->axis(); +// TORCH_INTERNAL_ASSERT( +// ax >= 0 && (unsigned int)(ax + 1) < indices.size(), +// "Hit an invalid Split transformation during IndexCompute, axis is not +// within bounds."); +// indices[ax] = add(mul(indices[ax], split->factor()), indices[ax + 1]); +// indices.erase(indices.begin() + ax + 1); +// return split->in(); +// } + +// TensorDomain* IndexCompute::replayBackward(Merge* merge, TensorDomain*) { +// int ax = merge->axis(); +// TORCH_INTERNAL_ASSERT( +// ax >= 0 && (unsigned int)ax < indices.size(), +// "Hit an invalid MERGE transformation during IndexCompute, axis is not +// within bounds."); + +// Val* I = merge->in()->axis(ax + 1)->extent(); +// Val* ind = indices[ax]; +// indices[ax] = div(ind, I); +// indices.insert(indices.begin() + ax + 1, mod(ind, I)); +// return merge->in(); +// } + +// TensorDomain* IndexCompute::runBackward(std::vector history) { +// TensorDomain* running_td = nullptr; +// for (auto it = history.rbegin(); it != history.rend(); it++) +// running_td = TransformIter::replayBackward(*it, running_td); + +// return running_td; +// } + +// IndexCompute::IndexCompute(TensorDomain* td, std::vector _indices) +// : indices(std::move(_indices)) { +// bool exclude_reduction = td->nDims() > indices.size(); + +// 
TORCH_INTERNAL_ASSERT( +// td->noReductions().size() == indices.size() || +// td->nDims() == indices.size(), +// "For IndexCompute the number of axes should match the number of +// dimensions" " in the TensorDomain."); + +// // If we need to ignore the reduction dimensions because a tensor is +// // being consumed, not produced, then insert dummy dimensions in the +// // indices for bookkeeping while replaying split/merge operations. +// if (exclude_reduction) +// for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) +// if (td->axis(i)->isReduction()) +// indices.insert(indices.begin() + i, new Int(-1)); + +// TORCH_INTERNAL_ASSERT( +// indices.size() == td->nDims(), +// "Attempted to modify indices for IndexCompute, but didn't work."); + +// // Run the split/merge operations backwards. This will +// // Modify std::vector indices so it can be used to index +// // the root TensorDomain which should now match the physical axes. +// TensorDomain* root = TransformIter::getRoot(td); +// auto history = TransformIter::getHistory(td); +// if (exclude_reduction && td->hasRFactor()) { +// root = TransformIter::getRFactorRoot(td); +// auto rfactor_history = TransformIter::getHistory(root); +// history.erase(history.begin(), history.begin() + rfactor_history.size()); +// } + +// runBackward(history); + +// TORCH_INTERNAL_ASSERT( +// root->nDims() == indices.size(), +// "Error during IndexCompute. The number of indices generated" +// " after running the transformations backwards should match" +// " the number of dimensions of the root TensorDomain."); +// } + +// std::vector IndexCompute::get( +// TensorDomain* td, +// const std::vector& _indices) { +// IndexCompute ic(td, _indices); +// return ic.indices; +// } + +// TensorIndex* Index::getGlobalProducerIndex( +// TensorView* producer, +// TensorView* consumer, +// const std::vector& loops) { +// // This replay will ignore reduction dimensions on the producer +// auto pind = +// TransformReplay::replayPasC(producer->domain(), consumer->domain(), +// -1); + +// TORCH_INTERNAL_ASSERT( +// loops.size() == consumer->nDims(), +// "Dimensionality error in code generator while computing tensor +// indexes."); + +// std::vector loops_adjusted; +// size_t it_c = 0, it_p = 0; +// while (it_c < consumer->nDims() && it_p < pind->noReductions().size()) { +// if (consumer->axis(it_c)->isBroadcast() && +// !pind->noReductions()[it_p]->isBroadcast()) { +// it_c++; +// } else { +// loops_adjusted.push_back(loops[it_c]); +// it_c++; +// it_p++; +// } +// } + +// TORCH_INTERNAL_ASSERT( +// loops_adjusted.size() == pind->noReductions().size(), +// "Dimensionality error in code generator while computing tensor +// indexes."); + +// std::vector indices(loops_adjusted.size()); +// std::transform( +// loops_adjusted.begin(), +// loops_adjusted.end(), +// indices.begin(), +// [](ForLoop* fl) { return fl->index(); }); +// std::vector computed_inds = IndexCompute::get(pind, indices); + +// auto root_domain = producer->getRootDomain(); + +// TORCH_INTERNAL_ASSERT( +// computed_inds.size() == root_domain->nDims(), +// "Dimensionality error in code generator while computing indexing."); + +// for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) { +// if (root_domain->axis(i)->isReduction() || +// root_domain->axis(i)->isBroadcast()) +// computed_inds.erase(computed_inds.begin() + i); +// } + +// std::vector strided_inds; +// for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) { +// std::stringstream ss; +// ss << "T" << producer->name() 
<< ".stride[" << i << "]"; +// strided_inds.push_back( +// mul(computed_inds[i], new NamedScalar(ss.str(), DataType::Int))); +// } + +// // Probably shouldn't ever hit this +// if (strided_inds.size() == 0) +// strided_inds.push_back(new Int(0)); + +// return new TensorIndex(producer, strided_inds); +// } + +// // Producer index for either shared or local memory +// TensorIndex* Index::getProducerIndex_impl( +// TensorView* producer, +// TensorView* consumer, +// const std::vector& loops) { +// TORCH_INTERNAL_ASSERT( +// loops.size() == consumer->nDims(), +// "Dimensionality error in code generator while computing tensor +// indexes."); + +// std::vector loops_adjusted; +// size_t it_c = 0, it_p = 0; +// while (it_c < consumer->nDims() && it_p < producer->nDims()) { +// if (consumer->axis(it_c)->isBroadcast() && +// !producer->axis(it_p)->isBroadcast()) { +// it_c++; +// } else { +// loops_adjusted.push_back(loops[it_c]); +// it_c++; +// it_p++; +// } +// } + +// TORCH_INTERNAL_ASSERT( +// loops_adjusted.size() == producer->domain()->noReductions().size(), +// "Expected a tensor with ", +// loops_adjusted.size(), +// " dimensions but got one with ", +// producer->nDims()); + +// std::vector ranges(loops_adjusted.size()); +// std::transform( +// loops_adjusted.begin(), +// loops_adjusted.end(), +// ranges.begin(), +// [](ForLoop* fl) { return fl->iter_domain(); }); + +// std::vector indices(loops_adjusted.size()); +// std::transform( +// loops_adjusted.begin(), +// loops_adjusted.end(), +// indices.begin(), +// [](ForLoop* fl) { +// return fl->iter_domain()->isBroadcast() ? new Int(0) : fl->index(); +// }); + +// std::vector used_inds; +// std::vector used_ranges; +// bool unrolled = false; +// for (decltype(loops_adjusted.size()) i{0}; i < loops_adjusted.size(); i++) +// { +// if (ranges[i]->parallel_method() == ParallelType::Unroll) +// unrolled = true; +// if (!unrolled && producer->hasComputeAt() && +// i < producer->getThisComputeAtAxis()) +// continue; +// if (producer->getMemoryType() == MemoryType::Shared && +// ranges[i]->isBlockDim()) +// continue; +// if (producer->getMemoryType() == MemoryType::Local && +// ranges[i]->isThread()) +// continue; +// if (ranges[i]->isBroadcast()) +// continue; + +// used_inds.push_back(indices[i]); +// used_ranges.push_back(ranges[i]); +// } + +// for (decltype(used_inds.size()) i{0}; i < used_inds.size(); i++) { +// Val* ind = used_inds[i]; +// for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++) +// ind = mul(ind, used_ranges[j]->extent()); +// used_inds[i] = ind; +// } +// if (used_inds.size() == 0) +// used_inds.push_back(new Int(0)); + +// return new TensorIndex(producer, used_inds); +// } + +// TensorIndex* Index::getGlobalConsumerIndex( +// TensorView* consumer, +// const std::vector& loops) { +// // If we're initializing a reduction buffer, we won't have the reduction +// // loops. If we're actually performing the reduction, we will. 
+ +// std::vector indices(loops.size()); +// std::transform(loops.begin(), loops.end(), indices.begin(), [](ForLoop* fl) +// { +// return fl->index(); +// }); + +// std::vector computed_inds = +// IndexCompute::get(consumer->domain(), indices); + +// TensorDomain* root_dom = consumer->getRootDomain(); +// TORCH_INTERNAL_ASSERT( +// computed_inds.size() == root_dom->nDims(), +// "Dimensionality error in code generator while computing indexing."); + +// for (decltype(root_dom->nDims()) i{0}; i < root_dom->nDims(); i++) { +// // Do this backwards so erase offset will be right +// auto axis = root_dom->nDims() - i - 1; +// if (root_dom->axis(axis)->isReduction() || +// root_dom->axis(i)->isBroadcast()) +// computed_inds.erase(computed_inds.begin() + axis); +// } + +// std::vector strided_inds; +// for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) { +// std::stringstream ss; +// ss << "T" << consumer->name() << ".stride[" << i << "]"; +// strided_inds.push_back( +// mul(computed_inds[i], new NamedScalar(ss.str(), DataType::Int))); +// } + +// // Probably shouldn't ever hit this +// if (strided_inds.size() == 0) +// strided_inds.push_back(new Int(0)); + +// return new TensorIndex(consumer, strided_inds); +// } + +// // Consumer index for either shared or local memory +// TensorIndex* Index::getConsumerIndex_impl( +// TensorView* consumer, +// const std::vector& loops) { +// // If we're initializing a reduction buffer, we won't have the reduction +// // loops. If we're actually performing the reduction, we will. + +// bool have_reduction_iters = loops.size() == consumer->nDims(); + +// if (!have_reduction_iters) { +// TORCH_INTERNAL_ASSERT( +// // Init reduction space +// loops.size() == consumer->domain()->noReductions().size(), +// "Expected a tensor with ", +// loops.size(), +// " dimensions but got one with ", +// consumer->domain()->noReductions().size()); +// } else { +// TORCH_INTERNAL_ASSERT( +// // Calling the reduction op +// loops.size() == consumer->nDims(), +// "Expected a tensor with ", +// loops.size(), +// " dimensions but got one with ", +// consumer->nDims()); +// } + +// std::vector ranges(loops.size()); +// std::transform(loops.begin(), loops.end(), ranges.begin(), [](ForLoop* fl) +// { +// return fl->iter_domain(); +// }); + +// std::vector indices(loops.size()); +// std::transform(loops.begin(), loops.end(), indices.begin(), [](ForLoop* fl) +// { +// return fl->iter_domain()->isBroadcast() ? 
new Int(0) : fl->index(); +// }); + +// std::vector used_inds; +// std::vector used_ranges; +// bool unrolled = false; +// for (decltype(loops.size()) i{0}; i < loops.size(); i++) { +// if (have_reduction_iters && consumer->axis(i)->isReduction()) +// continue; +// if (ranges[i]->parallel_method() == ParallelType::Unroll) +// unrolled = true; +// if (!unrolled && consumer->hasComputeAt() && +// i < consumer->getThisComputeAtAxis()) +// continue; +// if (consumer->getMemoryType() == MemoryType::Shared && +// ranges[i]->isBlockDim()) +// continue; +// if (consumer->getMemoryType() == MemoryType::Local && +// ranges[i]->isThread()) +// continue; +// if (ranges[i]->isBroadcast()) +// continue; + +// used_inds.push_back(indices[i]); +// used_ranges.push_back(ranges[i]); +// } + +// for (decltype(used_inds.size()) i{0}; i < used_inds.size(); i++) { +// Val* ind = used_inds[i]; +// for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++) +// ind = mul(ind, used_ranges[j]->extent()); +// used_inds[i] = ind; +// } + +// if (used_inds.size() == 0) +// used_inds.push_back(new Int(0)); + +// return new TensorIndex(consumer, used_inds); +// } + +// // Producer is the inputs of an expression +// TensorIndex* Index::getProducerIndex( +// TensorView* producer, +// TensorView* consumer, +// const std::vector& loops) { +// TORCH_INTERNAL_ASSERT( +// loops.size() == consumer->nDims() || +// loops.size() == consumer->domain()->noReductions().size()); + +// if (producer->getMemoryType() == MemoryType::Global) +// return getGlobalProducerIndex(producer, consumer, loops); +// return getProducerIndex_impl(producer, consumer, loops); +// } + +// // Consumer is the output of an expression +// TensorIndex* Index::getConsumerIndex( +// TensorView* consumer, +// const std::vector& loops) { +// TORCH_INTERNAL_ASSERT( +// loops.size() == consumer->nDims() || +// loops.size() == consumer->domain()->noReductions().size()); + +// if (consumer->getMemoryType() == MemoryType::Global) +// return getGlobalConsumerIndex(consumer, loops); +// return getConsumerIndex_impl(consumer, loops); +// } + +// } // namespace fuser +// } // namespace jit +// } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h index c237c837aa125..cb11932f600fc 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.h +++ b/torch/csrc/jit/codegen/cuda/index_compute.h @@ -57,48 +57,54 @@ namespace fuser { struct IndexCompute : public TransformIter { private: - TensorDomain* replayBackward(Split*, TensorDomain*) override; - TensorDomain* replayBackward(Merge*, TensorDomain*) override; + // TensorDomain* replayBackward(Split*, TensorDomain*) override; + // TensorDomain* replayBackward(Merge*, TensorDomain*) override; - TensorDomain* runBackward(std::vector history); + // TensorDomain* runBackward(std::vector history); - // Otherwise warning on runBackward as it hides an overloaded virtual function - using TransformIter::runBackward; + // // Otherwise warning on runBackward as it hides an overloaded virtual + // function using TransformIter::runBackward; - IndexCompute(TensorDomain* td, std::vector _indices); - std::vector indices; + // IndexCompute(TensorDomain* td, std::vector _indices); + // std::vector indices; public: static std::vector get( TensorDomain* td, - const std::vector& _indices); + const std::vector& _indices) { + TORCH_INTERNAL_ASSERT(false, "NIY."); + } }; // Simple interface for IndexCompute struct Index : public TransformIter { private: - // Producer 
indexing if it's in shared or local memory
-  static TensorIndex* getProducerIndex_impl(
-      TensorView* producer,
-      TensorView* consumer,
-      const std::vector<ForLoop*>& loops);
+  // // Producer indexing if it's in shared or local memory
+  // static TensorIndex* getProducerIndex_impl(
+  //     TensorView* producer,
+  //     TensorView* consumer,
+  //     const std::vector<ForLoop*>& loops);

-  // Consumer indexing if it's in shared or local memory
-  static TensorIndex* getConsumerIndex_impl(
-      TensorView* consumer,
-      const std::vector<ForLoop*>& loops);
+  // // Consumer indexing if it's in shared or local memory
+  // static TensorIndex* getConsumerIndex_impl(
+  //     TensorView* consumer,
+  //     const std::vector<ForLoop*>& loops);

 public:
  // Producer if it's in global memory
  static TensorIndex* getGlobalProducerIndex(
      TensorView* producer,
      TensorView* consumer,
-      const std::vector<ForLoop*>& loops);
+      const std::vector<ForLoop*>& loops) {
+    TORCH_INTERNAL_ASSERT(false, "NIY.");
+  }

  // Consumer indexing if it's in global memory
  static TensorIndex* getGlobalConsumerIndex(
      TensorView* consumer,
-      const std::vector<ForLoop*>& loops);
+      const std::vector<ForLoop*>& loops) {
+    TORCH_INTERNAL_ASSERT(false, "NIY.");
+  }

  // Indexing functions
  // Consumer = Producer
@@ -107,15 +113,21 @@ struct Index : public TransformIter {
  static TensorIndex* getProducerIndex(
      TensorView* producer,
      TensorView* consumer,
-      const std::vector<ForLoop*>& loops);
+      const std::vector<ForLoop*>& loops) {
+    TORCH_INTERNAL_ASSERT(false, "NIY.");
+  }

  // Consumer index dispatch
  static TensorIndex* getConsumerIndex(
      TensorView* consumer,
-      const std::vector<ForLoop*>& loops);
+      const std::vector<ForLoop*>& loops) {
+    TORCH_INTERNAL_ASSERT(false, "NIY.");
+  }

  // Will run inds through back prop index computation for tv
-  static TensorIndex* manualBackprop(TensorView tv, std::vector<Val*> inds);
+  static TensorIndex* manualBackprop(TensorView tv, std::vector<Val*> inds) {
+    TORCH_INTERNAL_ASSERT(false, "NIY.");
+  }
};

} // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
index af4b6feae533c..3c87025fe479d 100644
--- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -403,8 +403,7 @@ struct TORCH_CUDA_API TensorDomain : public Val {
 };

 /*
- * Representation for a split on IterDomain = axis in a TensorDomain, by factor
- * = factor
+ * Representation of a split on an IterDomain by "factor"
  * TODO: Implement split by nparts
  */
 struct TORCH_CUDA_API Split : public Expr {
@@ -416,17 +415,16 @@ struct TORCH_CUDA_API Split : public Expr {
   Split(Split&& other) = delete;
   Split& operator=(Split&& other) = delete;

-  Split(TensorDomain* _out, TensorDomain* _in, int _axis, Int* _factor);
+  Split(IterDomain* _outer, IterDomain* _inner, IterDomain* _in, Int* _factor);

-  TensorDomain* out() const noexcept {
-    return out_;
+  IterDomain* outer() const noexcept {
+    return outer_;
   }
-  TensorDomain* in() const noexcept {
-    return in_;
+  IterDomain* inner() const noexcept {
+    return inner_;
   }
-
-  int axis() const noexcept {
-    return axis_;
+  IterDomain* in() const noexcept {
+    return in_;
   }
   Int* factor() const noexcept {
     return factor_;
@@ -434,21 +432,22 @@ struct TORCH_CUDA_API Split : public Expr {
   bool sameAs(const Split* const other) const;

  private:
-  TensorDomain* const out_;
-  TensorDomain* const in_;
-  const int axis_;
+  IterDomain* const outer_;
+  IterDomain* const inner_;
+  IterDomain* const in_;
   Int* const factor_;
 };

 /*
- * Merge Iterdomain _axis in TensorDomain with the following IterDomain.
Both
- * IterDomains must be of the same iter or reduction type, as well as the same
- * parallelization strategy if there is one.
+ * Merge the IterDomains outer and inner into one domain; the names dictate
+ * traversal order, with inner traversed first. Both IterDomains must be of
+ * the same iter or reduction type, as well as the same parallelization
+ * strategy if there is one.
  * TODO: Should this be a unary op type?
  */
 struct TORCH_CUDA_API Merge : public Expr {
   ~Merge() = default;

-  Merge(TensorDomain* _out, TensorDomain* _in, int _axis);
+  Merge(IterDomain* _out, IterDomain* _outer, IterDomain* _inner);

   Merge(const Merge& other) = delete;
   Merge& operator=(const Merge& other) = delete;
@@ -456,11 +455,14 @@ struct TORCH_CUDA_API Merge : public Expr {
   Merge(Merge&& other) = delete;
   Merge& operator=(Merge&& other) = delete;

-  TensorDomain* out() const noexcept {
+  IterDomain* out() const noexcept {
     return out_;
   }
-  TensorDomain* in() const noexcept {
-    return in_;
+  IterDomain* outer() const noexcept {
+    return outer_;
+  }
+  IterDomain* inner() const noexcept {
+    return inner_;
   }
   int axis() const noexcept {
     return axis_;
@@ -469,8 +471,9 @@ struct TORCH_CUDA_API Merge : public Expr {
   bool sameAs(const Merge* const other) const;

  private:
-  TensorDomain* const out_;
-  TensorDomain* const in_;
+  IterDomain* const out_;
+  IterDomain* const outer_;
+  IterDomain* const inner_;
   int axis_;
 };

diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
index 2008a12fb3ff0..aab818d386537 100644
--- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -504,14 +504,19 @@ void IRPrinter::handle(const Allocate* const a) {
 void IRPrinter::handle(const Split* const s) {
   os << "Split: ";
   handle(s->in());
-  os << " axis " << s->axis() << " by factor " << s->factor() << " -> ";
-  handle(s->out());
+  os << " by factor " << s->factor() << " -> ";
+  handle(s->outer());
+  os << ", ";
+  handle(s->inner());
   os << "\n";
 }

 void IRPrinter::handle(const Merge* const m) {
-  os << "Merge: " << m->in() << " axis " << m->axis()
-     << " with the following -> ";
+  os << "Merge: ";
+  handle(m->outer());
+  os << " and ";
+  handle(m->inner());
+  os << " -> ";
   handle(m->out());
   os << "\n";
 }
diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
index 345be55285ba2..6f16dce6a52ae 100644
--- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
+++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -366,7 +366,8 @@ TensorDomain* TensorDomain::split(int axis_, int factor) {
   }
   TensorDomain* split_td = new TensorDomain(new_domain);
-  new Split(split_td, this, axis_, fact); // For record keeping
+  { TORCH_INTERNAL_ASSERT(false, "NIY."); }
+  // new Split(split_td, this, axis_, fact); // For record keeping
   return split_td;
 }

@@ -410,7 +411,8 @@ TensorDomain* TensorDomain::merge(int axis_) {
     }
   }
   TensorDomain* merged_td = new TensorDomain(new_domain);
-  new Merge(merged_td, this, axis_); // For record keeping
+  { TORCH_INTERNAL_ASSERT(false, "NIY."); }
+  // new Merge(merged_td, this, axis_); // For record keeping
   return merged_td;
 }

@@ -580,40 +582,40 @@ TensorDomain* TensorDomain::rootDomain() {
   return TransformIter::getRoot(this);
 }

-Split::Split(TensorDomain* _out, TensorDomain* _in, int _axis, Int* _factor)
+Split::Split(
+    IterDomain* _outer,
+    IterDomain* _inner,
+    IterDomain* _in,
+    Int* _factor)
     : Expr(ExprType::Split),
-      out_{_out},
+      outer_{_outer},
+      inner_{_inner},
       in_{_in},
-      axis_{_axis},
      factor_{_factor} {
-  TORCH_INTERNAL_ASSERT(
-      _axis >= 0 && _axis < _in->nDims(),
-      "Invalid split node, axis < 0 or >= in->nDims().");
-  addOutput(_out);
+  addOutput(_outer);
+  addOutput(_inner);
   addInput(_in);
   this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }

 bool Split::sameAs(const Split* const other) const {
   return (
-      out()->sameAs(other->out()) && in()->sameAs(other->in()) &&
-      axis() == other->axis() && factor()->sameAs(other->factor()));
+      outer()->sameAs(other->outer()) && inner()->sameAs(other->inner()) &&
+      in()->sameAs(other->in()) && factor()->sameAs(other->factor()));
 }

-Merge::Merge(TensorDomain* _out, TensorDomain* _in, int _axis)
-    : Expr(ExprType::Merge), out_{_out}, in_{_in}, axis_{_axis} {
-  TORCH_INTERNAL_ASSERT(
-      _axis >= 0 && _axis < _in->nDims(),
-      "Invalid merge node, axis < 0 or >= in->nDims().");
+Merge::Merge(IterDomain* _out, IterDomain* _outer, IterDomain* _inner)
+    : Expr(ExprType::Merge), out_{_out}, outer_{_outer}, inner_{_inner} {
   addOutput(_out);
-  addInput(_in);
+  addInput(_outer);
+  addInput(_inner);
   this->name_ = FusionGuard::getCurFusion()->registerExpr(this);
 }

 bool Merge::sameAs(const Merge* const other) const {
   return (
-      out()->sameAs(other->out()) && in()->sameAs(other->in()) &&
-      axis() == other->axis());
+      out()->sameAs(other->out()) && outer()->sameAs(other->outer()) &&
+      inner()->sameAs(other->inner()));
 }

 ForLoop::ForLoop(
diff --git a/torch/csrc/jit/codegen/cuda/mutator.cpp b/torch/csrc/jit/codegen/cuda/mutator.cpp
index 9b792954cf737..f6cf432d01139 100644
--- a/torch/csrc/jit/codegen/cuda/mutator.cpp
+++ b/torch/csrc/jit/codegen/cuda/mutator.cpp
@@ -135,25 +135,28 @@ Statement* OptOutMutator::mutate(Allocate* a) {
 }

 Statement* OptOutMutator::mutate(Split* s) {
-  TensorDomain* o = static_cast<TensorDomain*>(mutateAsVal(s->out()));
-  TensorDomain* i = static_cast<TensorDomain*>(mutateAsVal(s->in()));
+  IterDomain* ot = static_cast<IterDomain*>(mutateAsVal(s->outer()));
+  IterDomain* inr = static_cast<IterDomain*>(mutateAsVal(s->inner()));
+  IterDomain* in = static_cast<IterDomain*>(mutateAsVal(s->in()));
   Int* fact = static_cast<Int*>(mutateAsVal(s->factor()));

-  if (o->sameAs(s->out()) && i->sameAs(s->in()) && fact->sameAs(s->factor()))
+  if (ot->sameAs(s->outer()) && inr->sameAs(s->inner()) &&
+      in->sameAs(s->in()) && fact->sameAs(s->factor()))
     return s;
   FusionGuard::getCurFusion()->removeExpr(s);
-  return new Split(o, i, s->axis(), fact);
+  return new Split(ot, inr, in, fact);
 }

 Statement* OptOutMutator::mutate(Merge* m) {
-  TensorDomain* o = static_cast<TensorDomain*>(mutateAsVal(m->out()));
-  TensorDomain* i = static_cast<TensorDomain*>(mutateAsVal(m->in()));
+  IterDomain* ot = static_cast<IterDomain*>(mutateAsVal(m->out()));
+  IterDomain* otr = static_cast<IterDomain*>(mutateAsVal(m->outer()));
+  IterDomain* in = static_cast<IterDomain*>(mutateAsVal(m->inner()));

-  if (o->sameAs(m->out()) && i->sameAs(m->in()))
+  if (ot->sameAs(m->out()) && otr->sameAs(m->outer()) && in->sameAs(m->inner()))
     return m;
   FusionGuard::getCurFusion()->removeExpr(m);
-  return new Merge(o, i, m->axis());
+  return new Merge(ot, otr, in);
 }

 Statement* OptOutMutator::mutate(UnaryOp* uop) {
diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.cpp b/torch/csrc/jit/codegen/cuda/transform_iter.cpp
index e30d6321ae294..6d00a71bfed32 100644
--- a/torch/csrc/jit/codegen/cuda/transform_iter.cpp
+++ b/torch/csrc/jit/codegen/cuda/transform_iter.cpp
@@ -1,796 +1,812 @@
-#include
-#include
-#include
-#include
-
-namespace torch {
-namespace jit {
-namespace fuser {
-
-TensorDomain* TransformIter::replayBackward(Split* split, TensorDomain* td) {
-  return split->in();
-}
-
-TensorDomain*
TransformIter::replayBackward(Merge* merge, TensorDomain* td) { - return merge->in(); -} - -TensorDomain* TransformIter::replayBackward(Expr* expr, TensorDomain* td) { - TORCH_INTERNAL_ASSERT( - expr->isExpr(), - "Dispatch in transform iteration is expecting Exprs only."); - switch (*(expr->getExprType())) { - case (ExprType::Split): - return replayBackward(static_cast(expr), td); - case (ExprType::Merge): - return replayBackward(static_cast(expr), td); - default: - TORCH_INTERNAL_ASSERT( - false, "Could not detect expr type in replayBackward."); - } -} - -std::vector TransformIter::getHistory(TensorDomain* td) { - std::vector ops; - TensorDomain* root = td; // backward running td - Fusion* fusion = FusionGuard::getCurFusion(); - - // Get my origin - Expr* orig = fusion->origin(root); - std::set visited_exprs; - - // If I'm not back to the original td - while (orig != nullptr) { - if (visited_exprs.find(orig) != visited_exprs.end()) - TORCH_INTERNAL_ASSERT( - false, - "TransformReplay::runBackward is not traversing a correct history."); - ops.push_back(orig); - visited_exprs.emplace(orig); - TensorDomain* previous_td = nullptr; - // Check inputs of this operation, make sure there isn't more than one TD - // I can only record operations that only take this TD as an input. - for (Val* inp : orig->inputs()) - if (inp->getValType() == ValType::TensorDomain) { - if (previous_td != nullptr) - TORCH_INTERNAL_ASSERT( - false, - "TransformReplay::runBackward could not decifer transform history of a TensorDomain."); - - // Traverse back - root = static_cast(inp); - orig = fusion->origin(root); - } - } - return std::vector(ops.rbegin(), ops.rend()); -} - -TensorDomain* TransformIter::runBackward(TensorDomain* td) { - std::vector ops = getHistory(td); - - // We want to iterate backwards, reverse history. - ops = std::vector(ops.rbegin(), ops.rend()); - - TensorDomain* running_td = td; - for (Expr* op : ops) - running_td = replayBackward(op, running_td); - - return running_td; -} - -TensorDomain* TransformIter::replay(Split* expr, TensorDomain* td) { - return td->split( - expr->axis(), static_cast(expr->factor())->value().value()); -} - -TensorDomain* TransformIter::replay(Merge* expr, TensorDomain* td) { - return td->merge(expr->axis()); -} - -TensorDomain* TransformIter::replay(Expr* expr, TensorDomain* td) { - TORCH_INTERNAL_ASSERT(expr->isExpr()); - switch (*(expr->getExprType())) { - case (ExprType::Split): - return replay(static_cast(expr), td); - case (ExprType::Merge): - return replay(static_cast(expr), td); - default: - TORCH_INTERNAL_ASSERT(false, "Could not detect expr type in replay."); - } -} - -TensorDomain* TransformIter::runReplay( - TensorDomain* td, - const std::vector& history) { - for (Expr* op : history) - td = TransformIter::replay(op, td); - return td; -} - -namespace { - -void validate_axis_map(int nDims, const std::vector& axis_map) { - TORCH_INTERNAL_ASSERT( - axis_map.size() == (unsigned int)nDims, - "Invalid axis map in replay transform. 
NDims doesn't match."); - - TORCH_INTERNAL_ASSERT( - !std::any_of( - axis_map.begin(), - axis_map.end(), - [nDims](int i) { return i < -1 || i >= nDims; }), - "Invalid axis map in replay transform, map goes outside domains of provided TensorDomain."); -} - -void validate_history_entry(Expr* expr, int nDims) { - TORCH_INTERNAL_ASSERT( - expr->input(0)->getValType().value() == ValType::TensorDomain && - static_cast(expr->input(0))->nDims() == - (unsigned int)nDims, - "Invalid history, or invalid axis_map in TransformIter."); -} - -struct Influence : public TransformIter { - private: - // BACKWARD INFLUENCE - - TensorDomain* replayBackward(Split* split, TensorDomain* td) override { - int axis = split->axis(); - - TORCH_INTERNAL_ASSERT( - (unsigned int)(axis + 1) < influence.size(), - "Error during replay backwards, td/influence size mismatch."); - influence[axis] = influence[axis] | influence[axis + 1]; - influence.erase(influence.begin() + axis + 1); - - return split->in(); - } - - TensorDomain* replayBackward(Merge* merge, TensorDomain* td) override { - int axis = merge->axis(); - TORCH_INTERNAL_ASSERT( - (unsigned int)axis < influence.size(), - "Error during replay backwards, td/influence size mismatch."); - influence.insert(influence.begin() + axis + 1, influence[axis]); - - return merge->in(); - } - - // FORWARD INFLUENCE - - TensorDomain* replay(Split* split, TensorDomain* td) override { - int axis = split->axis(); - TORCH_INTERNAL_ASSERT( - (unsigned int)axis < influence.size(), - "Error during replay, td/influence size mismatch."); - influence.insert(influence.begin() + axis + 1, influence[axis]); - return nullptr; - } - - TensorDomain* replay(Merge* merge, TensorDomain* td) override { - int axis = merge->axis(); - TORCH_INTERNAL_ASSERT( - axis >= 0 && (unsigned int)(axis + 1) < influence.size(), - "Error during replay, td/influence size mismatch."); - influence[axis] = influence[axis] | influence[axis + 1]; - influence.erase(influence.begin() + axis + 1); - return nullptr; - } - - // INTERFACE - - std::vector influence; - - Influence(std::vector td_influence) - : influence(std::move(td_influence)) {} - - using TransformIter::replayBackward; - using TransformIter::runReplay; - - public: - static std::vector computeBackward( - const std::vector& history, - const std::vector& td_influence) { - if (history.empty()) - return td_influence; - - Val* last_val = history[history.size() - 1]->output(0); - TORCH_INTERNAL_ASSERT( - last_val->getValType().value() == ValType::TensorDomain && - static_cast(last_val)->nDims() == - td_influence.size(), - "Tried to compute influence, but recieved an influence vector that does not match the expected size."); - - Influence inf(td_influence); - std::vector ops(history.rbegin(), history.rend()); - for (Expr* op : ops) - inf.replayBackward(op, nullptr); - return inf.influence; - } - - static std::vector computeForward( - const std::vector& history, - const std::vector& td_influence) { - if (history.empty()) - return td_influence; - - TORCH_INTERNAL_ASSERT( - history[0]->input(0)->getValType().value() == ValType::TensorDomain && - static_cast(history[0]->input(0))->nDims() == - td_influence.size(), - "Tried to compute influence, but recieved an influence vector that does not match the expected size."); - Influence inf(td_influence); - inf.runReplay(nullptr, history); - return inf.influence; - } - -}; // struct Influence - -struct Replay : public TransformIter { - /* - * Replay functions, takes a TensorDomain and steps through the operations in - * "record" 
based on influence axes. Will also update influence and propagate - * it forward. - */ - TensorDomain* replay(Split* split, TensorDomain* td) override { - int saxis = split->axis(); - - TORCH_INTERNAL_ASSERT( - saxis >= 0 && (unsigned int)saxis < axis_map.size(), - "TransformReplay tried to modify an axis out of range, recieved ", - saxis, - " but this value should be >=0 and <", - axis_map.size()); - - // Axis relative to td - int axis = axis_map[saxis]; - - if (axis == -1) { - // don't modify path, we need an extra axis as there would have been one - // there, but we shouldn't modify it. - axis_map.insert(axis_map.begin() + saxis + 1, -1); - return td; - } - - // Move indices up as we now have an extra axis - std::transform( - axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { - return i > axis ? i + 1 : i; - }); - - // Insert new axis in map - axis_map.insert(axis_map.begin() + saxis + 1, axis + 1); - - TORCH_INTERNAL_ASSERT( - split->factor()->isConst(), - "Cannot replay split as it's not based on a const value."); - td = td->split(axis, split->factor()->value().value()); - - return td; - } - - TensorDomain* replay(Merge* merge, TensorDomain* td) override { - int maxis = merge->axis(); - - TORCH_INTERNAL_ASSERT( - maxis >= 0 && (unsigned int)(maxis + 1) < axis_map.size(), - "TransformReplay tried to modify an axis out of range, recieved ", - maxis, - " but this value should be >= 0 and < axis_map.size()"); - - // Get axis relative to what we actually have in td. - int axis = axis_map[maxis]; - int axis_p_1 = axis_map[maxis + 1]; - // If either dim is not to be touch, set both not to be touched - axis = axis_p_1 == -1 ? -1 : axis; - axis_map[maxis] = axis; - - // Remove axis from axis_map as in original transformations it didn't exist - axis_map.erase(axis_map.begin() + maxis + 1); - - // Don't modify: - if (axis == -1) - return td; - - // Move indices down as we're removing an axis - std::transform( - axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { - return i > axis ? i - 1 : i; - }); - - return td->merge(axis); - } - - std::vector axis_map; - Replay(std::vector _axis_map) : axis_map(std::move(_axis_map)) {} - - public: - // Replays history provided on td, axis_map is the mapping from td axes to - // those expected in history, if an axis shouldn't be transformed, it needs to - // be marked as -1 in the axis_map - static TensorDomain* replay( - TensorDomain* td, - const std::vector& history, - const std::vector& axis_map) { - if (history.empty()) - return td; - - Replay r(axis_map); - return r.runReplay(td, history); - } - -}; // struct Replay - -struct ReplaySelf : public TransformIter { - /* - * Replay functions, takes a TensorDomain and steps through its own history - * and reapplies it based on influence axes. Will replay rfactor axes - * correctly as well. - */ - TensorDomain* replay(Split* split, TensorDomain* td) override { - int saxis = split->axis(); - - TORCH_INTERNAL_ASSERT( - saxis >= 0 && (unsigned int)saxis < axis_map.size(), - "TransformReplay tried to modify an axis out of range, recieved ", - saxis, - " but this value should be >=0 and <", - axis_map.size()); - - // Axis relative to td - int axis = axis_map[saxis]; - - if (axis == -1) { - // don't modify path, we need an extra axis as there would have been one - // there, but we shouldn't modify it. 
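// The axis_map bookkeeping used throughout these replay routines, condensed
// into a standalone sketch; recordSplitInAxisMap is a hypothetical name, and
// -1 marks an axis that must not be transformed:
//
//   #include <algorithm>
//   #include <vector>
//
//   void recordSplitInAxisMap(std::vector<int>& axis_map, int saxis) {
//     int axis = axis_map[saxis];
//     if (axis == -1) {
//       // The untouched path still widens by one axis.
//       axis_map.insert(axis_map.begin() + saxis + 1, -1);
//       return;
//     }
//     // A split adds an axis, so shift every mapped position past it up...
//     std::transform(axis_map.begin(), axis_map.end(), axis_map.begin(),
//                    [axis](int i) { return i > axis ? i + 1 : i; });
//     // ...and record where the new inner axis landed.
//     axis_map.insert(axis_map.begin() + saxis + 1, axis + 1);
//   }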
- axis_map.insert(axis_map.begin() + saxis + 1, -1); - return td; - } - - // Move indices up as we now have an extra axis - std::transform( - axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { - return i > axis ? i + 1 : i; - }); - - // Insert new axis in map - axis_map.insert(axis_map.begin() + saxis + 1, axis + 1); - - TORCH_INTERNAL_ASSERT( - split->factor()->isConst(), - "Cannot replay split as it's not based on a const value."); - - // Create new domain reflecting split - std::vector new_domain; - for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) { - if ((int)i == axis) { - // We want to support cases where our root domain has changed sizes - // this happens in lowering when we replace sizes with runtime look ups - IterDomain* td_axis = td->axis(axis); - IterDomain* saxis_1 = split->out()->axis(saxis); - IterDomain* saxis_2 = split->out()->axis(saxis + 1); - // manually replay split domains using td extent, otherwise matching - // split axes params. - TORCH_CHECK( - td_axis->start()->isZeroInt(), - "Splitting IterDomains with starting values that aren't 0, is not supported at this time."); - - IterDomain* ido = new IterDomain( - new Int(0), - ceilDiv(td_axis->extent(), split->factor()), - saxis_1->parallel_method(), - saxis_1->isReduction(), - saxis_1->isRFactorProduct(), - saxis_1->isBroadcast()); - new_domain.push_back(ido); - - // inner loop IterDomain - IterDomain* idi = new IterDomain( - new Int(0), - split->factor(), - saxis_2->parallel_method(), - saxis_2->isReduction(), - saxis_2->isRFactorProduct(), - saxis_1->isBroadcast()); - new_domain.push_back(idi); - } else { - // Add in all other axes, these may not match the input td to the split. - new_domain.push_back(td->axis(i)); - } - } - - TensorDomain* replayed = new TensorDomain(new_domain); - new Split(replayed, td, axis, split->factor()); - return replayed; - } - - TensorDomain* replay(Merge* merge, TensorDomain* td) override { - int maxis = merge->axis(); - - TORCH_INTERNAL_ASSERT( - maxis >= 0 && (unsigned int)(maxis + 1) < axis_map.size(), - "TransformReplay tried to modify an axis out of range, recieved ", - maxis, - " but this value should be >= 0 and < axis_map.size()"); - - // Get axis relative to what we actually have in td. - int axis = axis_map[maxis]; - int axis_p_1 = axis_map[maxis + 1]; - // If either dim is not to be touch, set both not to be touched - axis = axis_p_1 == -1 ? -1 : axis; - axis_map[maxis] = axis; - - // Remove axis from axis_map as in original transformations it didn't exist - axis_map.erase(axis_map.begin() + maxis + 1); - - // Don't modify: - if (axis == -1) - return td; - - // Move indices down as we're removing an axis - std::transform( - axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { - return i > axis ? 
i - 1 : i; - }); - - // Create new domain reflecting post-merge - std::vector new_domain; - for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) { - if ((int)i == axis) { - // We want to support cases where our root domain has changed sizes - // this happens in lowering when we replace sizes with runtime look ups - IterDomain* td_axis1 = td->axis(axis); - IterDomain* td_axis2 = td->axis(axis_p_1); - IterDomain* m_axis = merge->out()->axis(maxis); - - TORCH_INTERNAL_ASSERT( - td_axis1->start()->isZeroInt() && td_axis2->start()->isZeroInt(), - "Splitting IterDomains with starting values that aren't 0, is not supported at this time."); - - IterDomain* merged = new IterDomain( - new Int(0), - mul(td_axis1->extent(), td_axis2->extent()), - m_axis->parallel_method(), - m_axis->isReduction(), - m_axis->isRFactorProduct(), - m_axis->isBroadcast()); - new_domain.push_back(merged); - - } else if ((int)i != axis_p_1) { - // Add in all other axes, these may not match the input td to the split. - new_domain.push_back(td->axis(i)); - } - } - - TensorDomain* replayed = new TensorDomain(new_domain); - new Merge(replayed, td, axis); - return replayed; - } - - std::vector axis_map; - ReplaySelf(std::vector _axis_map) : axis_map(std::move(_axis_map)) {} - - public: - // Replays history provided on td, axis_map is the mapping from td axes to - // those expected in history, if an axis shouldn't be transformed, it needs to - // be marked as -1 in the axis_map - static TensorDomain* replay( - TensorDomain* td, - const std::vector& history, - const std::vector& axis_map) { - ReplaySelf r(axis_map); - return r.runReplay(TransformIter::getRoot(td), history); - } - -}; // struct ReplaySelf - -struct TransformBackward : public TransformIter { - private: - // axis_map goes from the transform position to the position in our modified - // td. - TensorDomain* replayBackward(Split* split, TensorDomain* td) override { - int saxis = split->axis(); - - TORCH_INTERNAL_ASSERT( - saxis >= 0 && (unsigned int)saxis < axis_map.size(), - "TransformBackward tried to modify an axis out of range, recieved ", - saxis, - " but this value should be >= 0 and < axis_map.size()"); - - // Get axis relative to what we actually have in td. - int axis = axis_map[saxis]; - int axis_p_1 = axis_map[saxis + 1]; - // If either dim is not to be touch, set both not to be touched - axis = axis_p_1 == -1 ? -1 : axis; - axis_map[saxis] = axis; - - // Remove axis from axis_map as in original transformations it didn't exist - axis_map.erase(axis_map.begin() + saxis + 1); - - // Don't modify: - if (axis == -1) - return td; - - // Move indices down as previously we didn't have the split axis - std::transform( - axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { - return i > axis ? i - 1 : i; - }); - - // Create new domain reflecting pre-split - std::vector new_domain; - for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) { - if ((int)i == axis) { - IterDomain* orig_axis = split->in()->axis(saxis); - // Insert pre-split axis, make sure isReduction matches what is expected - new_domain.push_back(new IterDomain( - orig_axis->start(), - orig_axis->extent(), - orig_axis->parallel_method(), - td->axis(axis)->isReduction(), - td->axis(axis)->isRFactorProduct(), - td->axis(axis)->isBroadcast())); - } else if ((int)i != axis_p_1) { - // Add in all other axes, these may not match the input td to the split. 
- new_domain.push_back(td->axis(i)); - } - } - - TensorDomain* replayed_inp = new TensorDomain(new_domain); - new Split(td, replayed_inp, axis, split->factor()); - return replayed_inp; - } - - TensorDomain* replayBackward(Merge* merge, TensorDomain* td) override { - /* - * Remember axis_map goes from merge information -> how it's stored in td - * When we're done we want axis_map to match the returned td before or not - * before the merge depending on should_modify. - */ - - int maxis = merge->axis(); - - TORCH_INTERNAL_ASSERT( - maxis >= 0 && (unsigned int)maxis < axis_map.size(), - "TransformBackward tried to modify an axis out of range, recieved ", - maxis, - " but this value should be >=0 and <", - axis_map.size()); - - if (axis_map[maxis] == -1) { - // don't modify path, we need an extra axis as there was previously one - // there, but we shouldn't modify it. - axis_map.insert(axis_map.begin() + maxis + 1, -1); - return td; - } - - // Recreate the merge, axis is relative to the td - int axis = axis_map[maxis]; - // Move indices up as previously we had an extra axis - std::transform( - axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { - return i > axis ? i + 1 : i; - }); - - // Insert pre-merged axis back into map - axis_map.insert(axis_map.begin() + maxis + 1, axis_map[maxis] + 1); - - // Create new domain reflecting pre-merge - std::vector new_domain; - for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) { - if ((int)i == axis) { - IterDomain* td_axis = td->axis(axis); - IterDomain* maxis_1 = merge->in()->axis(maxis); - IterDomain* maxis_2 = merge->in()->axis(maxis + 1); - new_domain.push_back(new IterDomain( - maxis_1->start(), - maxis_1->extent(), - ParallelType::Serial, - td_axis->isReduction(), - td_axis->isRFactorProduct(), - td_axis->isBroadcast())); - new_domain.push_back(new IterDomain( - maxis_2->start(), - maxis_2->extent(), - ParallelType::Serial, - td_axis->isReduction(), - td_axis->isRFactorProduct(), - td_axis->isBroadcast())); - } else { - // Add in all other axes, these may not match the input td to the split. 
- new_domain.push_back(td->axis(i)); - } - } - - TensorDomain* replayed_inp = new TensorDomain(new_domain); - new Merge(td, replayed_inp, axis); - return replayed_inp; - } - - // Entry for backward influence propagation on td following record, history - // should be present -> past as you go through the vector - TensorDomain* replayBackward( - TensorDomain* td, - const std::vector& history) { - TensorDomain* running_td = td; - - std::vector rev_history(history.rbegin(), history.rend()); - for (Expr* op : rev_history) - running_td = TransformIter::replayBackward(op, running_td); - return running_td; - } - - std::vector axis_map; - - TransformBackward(std::vector _axis_map) - : axis_map(std::move(_axis_map)){}; - - public: - static TensorDomain* replay( - TensorDomain* td, - const std::vector& history, - const std::vector& axis_map) { - TransformBackward tb(axis_map); - return tb.replayBackward(td, history); - } -}; - -struct RFactorRoot : public TransformIter { - bool found_non_rfactor_op = false; - - TensorDomain* replay(Split* split, TensorDomain*) final { - if (!split->in()->axis(split->axis())->isRFactorProduct()) - found_non_rfactor_op = true; - return split->out(); - } - - TensorDomain* replay(Merge* merge, TensorDomain*) final { - if (!merge->in()->axis(merge->axis())->isRFactorProduct()) - found_non_rfactor_op = true; - return merge->out(); - } - - // Replay forward until we hit an operation that doesn't involve an rfactor - // axis - TensorDomain* runReplay(TensorDomain*, const std::vector& history) - final { - TORCH_INTERNAL_ASSERT( - !history.empty(), "No history provided to find rfactor root domain."); - - auto last_rfactor_op = history.begin(); - auto running_op = history.begin(); - - for (auto it = history.begin(); it != history.end(); it++) { - TransformIter::replay(*it, nullptr); - if (found_non_rfactor_op) - break; - running_op = it; - } - - // We need to make sure the rfactor root is ordered correctly. 
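// The validity scan below amounts to this predicate, sketched with the
// hypothetical helper isValidRFactorRoot: a candidate rfactor root is
// ordered correctly iff its rfactor-product axes, if any, form a single
// contiguous run that reaches the last dimension.
//
//   #include <vector>
//
//   bool isValidRFactorRoot(const std::vector<bool>& is_rfactor_axis) {
//     bool seen = false;
//     for (bool r : is_rfactor_axis) {
//       if (seen && !r)
//         return false; // run of rfactor axes interrupted before the end
//       seen = seen || r;
//     }
//     return true; // no rfactor axes at all, or a run reaching the end
//   }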
- bool found_valid_rfactor_root = false; - - Val* val; - - while (!found_valid_rfactor_root && last_rfactor_op != history.end()) { - // Try next val - val = (*last_rfactor_op++)->output(0); - TORCH_INTERNAL_ASSERT( - val->getValType().value() == ValType::TensorDomain, - "Invalid history to find rfactor root."); - - TensorDomain* td = static_cast(val); - bool found_rfactor_dim = false; - for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) { - if (found_rfactor_dim) { - if (!td->axis(i)->isRFactorProduct()) - break; - } else { - if (td->axis(i)->isRFactorProduct()) - found_rfactor_dim = true; - } - if (i == td->nDims() - 1) - found_valid_rfactor_root = true; - } - } - TORCH_INTERNAL_ASSERT( - found_valid_rfactor_root, "Could not find a valid rfactor root."); - return static_cast(val); - } - - public: - static TensorDomain* get(TensorDomain* td) { - auto history = TransformIter::getHistory(td); - if (history.empty()) - return td; - RFactorRoot rfr; - return rfr.runReplay(nullptr, history); - } -}; - -} // namespace - -// API INTO TRANSFORM ITER - -std::vector TransformIter::getRootInfluence( - TensorDomain* td, - const std::vector& td_influence) { - return Influence::computeBackward( - TransformIter::getHistory(td), td_influence); -} - -std::vector TransformIter::replayBackwardInfluence( - const std::vector& history, - const std::vector& td_influence) { - return Influence::computeBackward(history, td_influence); -} - -std::vector TransformIter::replayInfluence( - const std::vector& history, - const std::vector& td_influence) { - if (history.empty()) - return td_influence; - - return Influence::computeForward(history, td_influence); -} - -TensorDomain* TransformIter::replay( - TensorDomain* td, - const std::vector& history, - const std::vector& axis_map) { - if (history.empty()) - return td; - if (std::none_of( - axis_map.begin(), axis_map.end(), [](int i) { return i > -1; })) - return td; - - validate_history_entry(history[0], axis_map.size()); - return Replay::replay(td, history, axis_map); -} - -TensorDomain* TransformIter::replaySelf( - TensorDomain* td, - const std::vector& history, - const std::vector& axis_map) { - if (std::none_of( - axis_map.begin(), axis_map.end(), [](int i) { return i > -1; })) - return TransformIter::getRoot(td); - - validate_axis_map(TransformIter::getRoot(td)->nDims(), axis_map); - return ReplaySelf::replay(td, history, axis_map); -} - -TensorDomain* TransformIter::replayBackward( - TensorDomain* td, - const std::vector& history, - const std::vector& axis_map) { - if (history.empty()) - return td; - if (std::none_of( - axis_map.begin(), axis_map.end(), [](int i) { return i > -1; })) - return td; - - TORCH_INTERNAL_ASSERT( - history[history.size() - 1]->output(0)->getValType().value() == - ValType::TensorDomain && - static_cast(history[history.size() - 1]->output(0)) - ->nDims() == axis_map.size(), - "Invalid history, or invalid axis_map in TransformIter."); - - return TransformBackward::replay(td, history, axis_map); -} - -TensorDomain* TransformIter::getRFactorRoot(TensorDomain* td) { - auto td_root = TransformIter::getRoot(td); - if (std::none_of( - td_root->domain().begin(), - td_root->domain().end(), - [](IterDomain* id) { return id->isRFactorProduct(); })) - return td_root; - - return RFactorRoot::get(td); -} - -} // namespace fuser -} // namespace jit -} // namespace torch \ No newline at end of file +// #include +// #include +// #include +// #include + +// namespace torch { +// namespace jit { +// namespace fuser { + +// TensorDomain* 
TransformIter::replayBackward(Split* split, TensorDomain* td) { +// return split->in(); +// } + +// TensorDomain* TransformIter::replayBackward(Merge* merge, TensorDomain* td) { +// return merge->in(); +// } + +// TensorDomain* TransformIter::replayBackward(Expr* expr, TensorDomain* td) { +// TORCH_INTERNAL_ASSERT( +// expr->isExpr(), +// "Dispatch in transform iteration is expecting Exprs only."); +// switch (*(expr->getExprType())) { +// case (ExprType::Split): +// return replayBackward(static_cast(expr), td); +// case (ExprType::Merge): +// return replayBackward(static_cast(expr), td); +// default: +// TORCH_INTERNAL_ASSERT( +// false, "Could not detect expr type in replayBackward."); +// } +// } + +// std::vector TransformIter::getHistory(TensorDomain* td) { +// std::vector ops; +// TensorDomain* root = td; // backward running td +// Fusion* fusion = FusionGuard::getCurFusion(); + +// // Get my origin +// Expr* orig = fusion->origin(root); +// std::set visited_exprs; + +// // If I'm not back to the original td +// while (orig != nullptr) { +// if (visited_exprs.find(orig) != visited_exprs.end()) +// TORCH_INTERNAL_ASSERT( +// false, +// "TransformReplay::runBackward is not traversing a correct +// history."); +// ops.push_back(orig); +// visited_exprs.emplace(orig); +// TensorDomain* previous_td = nullptr; +// // Check inputs of this operation, make sure there isn't more than one TD +// // I can only record operations that only take this TD as an input. +// for (Val* inp : orig->inputs()) +// if (inp->getValType() == ValType::TensorDomain) { +// if (previous_td != nullptr) +// TORCH_INTERNAL_ASSERT( +// false, +// "TransformReplay::runBackward could not decifer transform +// history of a TensorDomain."); + +// // Traverse back +// root = static_cast(inp); +// orig = fusion->origin(root); +// } +// } +// return std::vector(ops.rbegin(), ops.rend()); +// } + +// TensorDomain* TransformIter::runBackward(TensorDomain* td) { +// std::vector ops = getHistory(td); + +// // We want to iterate backwards, reverse history. +// ops = std::vector(ops.rbegin(), ops.rend()); + +// TensorDomain* running_td = td; +// for (Expr* op : ops) +// running_td = replayBackward(op, running_td); + +// return running_td; +// } + +// TensorDomain* TransformIter::replay(Split* expr, TensorDomain* td) { +// return td->split( +// expr->axis(), static_cast(expr->factor())->value().value()); +// } + +// TensorDomain* TransformIter::replay(Merge* expr, TensorDomain* td) { +// return td->merge(expr->axis()); +// } + +// TensorDomain* TransformIter::replay(Expr* expr, TensorDomain* td) { +// TORCH_INTERNAL_ASSERT(expr->isExpr()); +// switch (*(expr->getExprType())) { +// case (ExprType::Split): +// return replay(static_cast(expr), td); +// case (ExprType::Merge): +// return replay(static_cast(expr), td); +// default: +// TORCH_INTERNAL_ASSERT(false, "Could not detect expr type in replay."); +// } +// } + +// TensorDomain* TransformIter::runReplay( +// TensorDomain* td, +// const std::vector& history) { +// for (Expr* op : history) +// td = TransformIter::replay(op, td); +// return td; +// } + +// namespace { + +// void validate_axis_map(int nDims, const std::vector& axis_map) { +// TORCH_INTERNAL_ASSERT( +// axis_map.size() == (unsigned int)nDims, +// "Invalid axis map in replay transform. 
NDims doesn't match."); + +// TORCH_INTERNAL_ASSERT( +// !std::any_of( +// axis_map.begin(), +// axis_map.end(), +// [nDims](int i) { return i < -1 || i >= nDims; }), +// "Invalid axis map in replay transform, map goes outside domains of +// provided TensorDomain."); +// } + +// void validate_history_entry(Expr* expr, int nDims) { +// TORCH_INTERNAL_ASSERT( +// expr->input(0)->getValType().value() == ValType::TensorDomain && +// static_cast(expr->input(0))->nDims() == +// (unsigned int)nDims, +// "Invalid history, or invalid axis_map in TransformIter."); +// } + +// struct Influence : public TransformIter { +// private: +// // BACKWARD INFLUENCE + +// TensorDomain* replayBackward(Split* split, TensorDomain* td) override { +// int axis = split->axis(); + +// TORCH_INTERNAL_ASSERT( +// (unsigned int)(axis + 1) < influence.size(), +// "Error during replay backwards, td/influence size mismatch."); +// influence[axis] = influence[axis] | influence[axis + 1]; +// influence.erase(influence.begin() + axis + 1); + +// return split->in(); +// } + +// TensorDomain* replayBackward(Merge* merge, TensorDomain* td) override { +// int axis = merge->axis(); +// TORCH_INTERNAL_ASSERT( +// (unsigned int)axis < influence.size(), +// "Error during replay backwards, td/influence size mismatch."); +// influence.insert(influence.begin() + axis + 1, influence[axis]); + +// return merge->in(); +// } + +// // FORWARD INFLUENCE + +// TensorDomain* replay(Split* split, TensorDomain* td) override { +// int axis = split->axis(); +// TORCH_INTERNAL_ASSERT( +// (unsigned int)axis < influence.size(), +// "Error during replay, td/influence size mismatch."); +// influence.insert(influence.begin() + axis + 1, influence[axis]); +// return nullptr; +// } + +// TensorDomain* replay(Merge* merge, TensorDomain* td) override { +// int axis = merge->axis(); +// TORCH_INTERNAL_ASSERT( +// axis >= 0 && (unsigned int)(axis + 1) < influence.size(), +// "Error during replay, td/influence size mismatch."); +// influence[axis] = influence[axis] | influence[axis + 1]; +// influence.erase(influence.begin() + axis + 1); +// return nullptr; +// } + +// // INTERFACE + +// std::vector influence; + +// Influence(std::vector td_influence) +// : influence(std::move(td_influence)) {} + +// using TransformIter::replayBackward; +// using TransformIter::runReplay; + +// public: +// static std::vector computeBackward( +// const std::vector& history, +// const std::vector& td_influence) { +// if (history.empty()) +// return td_influence; + +// Val* last_val = history[history.size() - 1]->output(0); +// TORCH_INTERNAL_ASSERT( +// last_val->getValType().value() == ValType::TensorDomain && +// static_cast(last_val)->nDims() == +// td_influence.size(), +// "Tried to compute influence, but recieved an influence vector that +// does not match the expected size."); + +// Influence inf(td_influence); +// std::vector ops(history.rbegin(), history.rend()); +// for (Expr* op : ops) +// inf.replayBackward(op, nullptr); +// return inf.influence; +// } + +// static std::vector computeForward( +// const std::vector& history, +// const std::vector& td_influence) { +// if (history.empty()) +// return td_influence; + +// TORCH_INTERNAL_ASSERT( +// history[0]->input(0)->getValType().value() == ValType::TensorDomain +// && +// static_cast(history[0]->input(0))->nDims() == +// td_influence.size(), +// "Tried to compute influence, but recieved an influence vector that +// does not match the expected size."); +// Influence inf(td_influence); +// inf.runReplay(nullptr, 
history); +// return inf.influence; +// } + +// }; // struct Influence + +// struct Replay : public TransformIter { +// /* +// * Replay functions, takes a TensorDomain and steps through the operations +// in +// * "record" based on influence axes. Will also update influence and +// propagate +// * it forward. +// */ +// TensorDomain* replay(Split* split, TensorDomain* td) override { +// int saxis = split->axis(); + +// TORCH_INTERNAL_ASSERT( +// saxis >= 0 && (unsigned int)saxis < axis_map.size(), +// "TransformReplay tried to modify an axis out of range, recieved ", +// saxis, +// " but this value should be >=0 and <", +// axis_map.size()); + +// // Axis relative to td +// int axis = axis_map[saxis]; + +// if (axis == -1) { +// // don't modify path, we need an extra axis as there would have been +// one +// // there, but we shouldn't modify it. +// axis_map.insert(axis_map.begin() + saxis + 1, -1); +// return td; +// } + +// // Move indices up as we now have an extra axis +// std::transform( +// axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { +// return i > axis ? i + 1 : i; +// }); + +// // Insert new axis in map +// axis_map.insert(axis_map.begin() + saxis + 1, axis + 1); + +// TORCH_INTERNAL_ASSERT( +// split->factor()->isConst(), +// "Cannot replay split as it's not based on a const value."); +// td = td->split(axis, split->factor()->value().value()); + +// return td; +// } + +// TensorDomain* replay(Merge* merge, TensorDomain* td) override { +// int maxis = merge->axis(); + +// TORCH_INTERNAL_ASSERT( +// maxis >= 0 && (unsigned int)(maxis + 1) < axis_map.size(), +// "TransformReplay tried to modify an axis out of range, recieved ", +// maxis, +// " but this value should be >= 0 and < axis_map.size()"); + +// // Get axis relative to what we actually have in td. +// int axis = axis_map[maxis]; +// int axis_p_1 = axis_map[maxis + 1]; +// // If either dim is not to be touch, set both not to be touched +// axis = axis_p_1 == -1 ? -1 : axis; +// axis_map[maxis] = axis; + +// // Remove axis from axis_map as in original transformations it didn't +// exist axis_map.erase(axis_map.begin() + maxis + 1); + +// // Don't modify: +// if (axis == -1) +// return td; + +// // Move indices down as we're removing an axis +// std::transform( +// axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { +// return i > axis ? i - 1 : i; +// }); + +// return td->merge(axis); +// } + +// std::vector axis_map; +// Replay(std::vector _axis_map) : axis_map(std::move(_axis_map)) {} + +// public: +// // Replays history provided on td, axis_map is the mapping from td axes to +// // those expected in history, if an axis shouldn't be transformed, it needs +// to +// // be marked as -1 in the axis_map +// static TensorDomain* replay( +// TensorDomain* td, +// const std::vector& history, +// const std::vector& axis_map) { +// if (history.empty()) +// return td; + +// Replay r(axis_map); +// return r.runReplay(td, history); +// } + +// }; // struct Replay + +// struct ReplaySelf : public TransformIter { +// /* +// * Replay functions, takes a TensorDomain and steps through its own history +// * and reapplies it based on influence axes. Will replay rfactor axes +// * correctly as well. 
+// */ +// TensorDomain* replay(Split* split, TensorDomain* td) override { +// int saxis = split->axis(); + +// TORCH_INTERNAL_ASSERT( +// saxis >= 0 && (unsigned int)saxis < axis_map.size(), +// "TransformReplay tried to modify an axis out of range, recieved ", +// saxis, +// " but this value should be >=0 and <", +// axis_map.size()); + +// // Axis relative to td +// int axis = axis_map[saxis]; + +// if (axis == -1) { +// // don't modify path, we need an extra axis as there would have been +// one +// // there, but we shouldn't modify it. +// axis_map.insert(axis_map.begin() + saxis + 1, -1); +// return td; +// } + +// // Move indices up as we now have an extra axis +// std::transform( +// axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { +// return i > axis ? i + 1 : i; +// }); + +// // Insert new axis in map +// axis_map.insert(axis_map.begin() + saxis + 1, axis + 1); + +// TORCH_INTERNAL_ASSERT( +// split->factor()->isConst(), +// "Cannot replay split as it's not based on a const value."); + +// // Create new domain reflecting split +// std::vector new_domain; +// for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) { +// if ((int)i == axis) { +// // We want to support cases where our root domain has changed sizes +// // this happens in lowering when we replace sizes with runtime look +// ups IterDomain* td_axis = td->axis(axis); IterDomain* saxis_1 = +// split->out()->axis(saxis); IterDomain* saxis_2 = +// split->out()->axis(saxis + 1); +// // manually replay split domains using td extent, otherwise matching +// // split axes params. +// TORCH_CHECK( +// td_axis->start()->isZeroInt(), +// "Splitting IterDomains with starting values that aren't 0, is not +// supported at this time."); + +// IterDomain* ido = new IterDomain( +// new Int(0), +// ceilDiv(td_axis->extent(), split->factor()), +// saxis_1->parallel_method(), +// saxis_1->isReduction(), +// saxis_1->isRFactorProduct(), +// saxis_1->isBroadcast()); +// new_domain.push_back(ido); + +// // inner loop IterDomain +// IterDomain* idi = new IterDomain( +// new Int(0), +// split->factor(), +// saxis_2->parallel_method(), +// saxis_2->isReduction(), +// saxis_2->isRFactorProduct(), +// saxis_1->isBroadcast()); +// new_domain.push_back(idi); +// } else { +// // Add in all other axes, these may not match the input td to the +// split. new_domain.push_back(td->axis(i)); +// } +// } + +// TensorDomain* replayed = new TensorDomain(new_domain); +// new Split(replayed, td, axis, split->factor()); +// return replayed; +// } + +// TensorDomain* replay(Merge* merge, TensorDomain* td) override { +// int maxis = merge->axis(); + +// TORCH_INTERNAL_ASSERT( +// maxis >= 0 && (unsigned int)(maxis + 1) < axis_map.size(), +// "TransformReplay tried to modify an axis out of range, recieved ", +// maxis, +// " but this value should be >= 0 and < axis_map.size()"); + +// // Get axis relative to what we actually have in td. +// int axis = axis_map[maxis]; +// int axis_p_1 = axis_map[maxis + 1]; +// // If either dim is not to be touch, set both not to be touched +// axis = axis_p_1 == -1 ? -1 : axis; +// axis_map[maxis] = axis; + +// // Remove axis from axis_map as in original transformations it didn't +// exist axis_map.erase(axis_map.begin() + maxis + 1); + +// // Don't modify: +// if (axis == -1) +// return td; + +// // Move indices down as we're removing an axis +// std::transform( +// axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { +// return i > axis ? 
i - 1 : i; +// }); + +// // Create new domain reflecting post-merge +// std::vector new_domain; +// for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) { +// if ((int)i == axis) { +// // We want to support cases where our root domain has changed sizes +// // this happens in lowering when we replace sizes with runtime look +// ups IterDomain* td_axis1 = td->axis(axis); IterDomain* td_axis2 = +// td->axis(axis_p_1); IterDomain* m_axis = merge->out()->axis(maxis); + +// TORCH_INTERNAL_ASSERT( +// td_axis1->start()->isZeroInt() && td_axis2->start()->isZeroInt(), +// "Splitting IterDomains with starting values that aren't 0, is not +// supported at this time."); + +// IterDomain* merged = new IterDomain( +// new Int(0), +// mul(td_axis1->extent(), td_axis2->extent()), +// m_axis->parallel_method(), +// m_axis->isReduction(), +// m_axis->isRFactorProduct(), +// m_axis->isBroadcast()); +// new_domain.push_back(merged); + +// } else if ((int)i != axis_p_1) { +// // Add in all other axes, these may not match the input td to the +// split. new_domain.push_back(td->axis(i)); +// } +// } + +// TensorDomain* replayed = new TensorDomain(new_domain); +// new Merge(replayed, td, axis); +// return replayed; +// } + +// std::vector axis_map; +// ReplaySelf(std::vector _axis_map) : axis_map(std::move(_axis_map)) {} + +// public: +// // Replays history provided on td, axis_map is the mapping from td axes to +// // those expected in history, if an axis shouldn't be transformed, it needs +// to +// // be marked as -1 in the axis_map +// static TensorDomain* replay( +// TensorDomain* td, +// const std::vector& history, +// const std::vector& axis_map) { +// ReplaySelf r(axis_map); +// return r.runReplay(TransformIter::getRoot(td), history); +// } + +// }; // struct ReplaySelf + +// struct TransformBackward : public TransformIter { +// private: +// // axis_map goes from the transform position to the position in our +// modified +// // td. +// TensorDomain* replayBackward(Split* split, TensorDomain* td) override { +// int saxis = split->axis(); + +// TORCH_INTERNAL_ASSERT( +// saxis >= 0 && (unsigned int)saxis < axis_map.size(), +// "TransformBackward tried to modify an axis out of range, recieved ", +// saxis, +// " but this value should be >= 0 and < axis_map.size()"); + +// // Get axis relative to what we actually have in td. +// int axis = axis_map[saxis]; +// int axis_p_1 = axis_map[saxis + 1]; +// // If either dim is not to be touch, set both not to be touched +// axis = axis_p_1 == -1 ? -1 : axis; +// axis_map[saxis] = axis; + +// // Remove axis from axis_map as in original transformations it didn't +// exist axis_map.erase(axis_map.begin() + saxis + 1); + +// // Don't modify: +// if (axis == -1) +// return td; + +// // Move indices down as previously we didn't have the split axis +// std::transform( +// axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { +// return i > axis ? 
+// struct TransformBackward : public TransformIter {
+//  private:
+//   // axis_map goes from the transform position to the position in our
+//   // modified td.
+//   TensorDomain* replayBackward(Split* split, TensorDomain* td) override {
+//     int saxis = split->axis();
+
+//     TORCH_INTERNAL_ASSERT(
+//         saxis >= 0 && (unsigned int)saxis < axis_map.size(),
+//         "TransformBackward tried to modify an axis out of range, received ",
+//         saxis,
+//         " but this value should be >= 0 and < axis_map.size()");
+
+//     // Get axis relative to what we actually have in td.
+//     int axis = axis_map[saxis];
+//     int axis_p_1 = axis_map[saxis + 1];
+//     // If either dim is not to be touched, mark both as not to be touched
+//     axis = axis_p_1 == -1 ? -1 : axis;
+//     axis_map[saxis] = axis;
+
+//     // Remove axis from axis_map, as in the original transformations it
+//     // didn't exist
+//     axis_map.erase(axis_map.begin() + saxis + 1);
+
+//     // Don't modify:
+//     if (axis == -1)
+//       return td;
+
+//     // Move indices down, as previously we didn't have the split axis
+//     std::transform(
+//         axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) {
+//           return i > axis ? i - 1 : i;
+//         });
+
+//     // Create new domain reflecting pre-split
+//     std::vector<IterDomain*> new_domain;
+//     for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) {
+//       if ((int)i == axis) {
+//         IterDomain* orig_axis = split->in()->axis(saxis);
+//         // Insert the pre-split axis; make sure isReduction matches what
+//         // is expected
+//         new_domain.push_back(new IterDomain(
+//             orig_axis->start(),
+//             orig_axis->extent(),
+//             orig_axis->parallel_method(),
+//             td->axis(axis)->isReduction(),
+//             td->axis(axis)->isRFactorProduct(),
+//             td->axis(axis)->isBroadcast()));
+//       } else if ((int)i != axis_p_1) {
+//         // Add in all other axes; these may not match the input td to the
+//         // split.
+//         new_domain.push_back(td->axis(i));
+//       }
+//     }
+
+//     TensorDomain* replayed_inp = new TensorDomain(new_domain);
+//     new Split(td, replayed_inp, axis, split->factor());
+//     return replayed_inp;
+//   }
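+// [Editorial sketch, not part of the original patch.] Backward replay of a
+// split undoes the forward bookkeeping: for a 4-D td whose axes 1 and 2 came
+// from a split at saxis = 1,
+//
+//   axis_map = {0, 1, 2, 3}  =>  axis = 1, axis_p_1 = 2
+//   erase entry saxis + 1:   axis_map = {0, 1, 3}
+//   shift entries > 1 down:  axis_map = {0, 1, 2}  // td loses one axis
+//
+// i.e. the two split outputs collapse back into the single pre-split axis.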
+//   TensorDomain* replayBackward(Merge* merge, TensorDomain* td) override {
+//     /*
+//      * Remember axis_map goes from merge information -> how it's stored in
+//      * td. When we're done we want axis_map to match the returned td,
+//      * before or not before the merge depending on should_modify.
+//      */
+
+//     int maxis = merge->axis();
+
+//     TORCH_INTERNAL_ASSERT(
+//         maxis >= 0 && (unsigned int)maxis < axis_map.size(),
+//         "TransformBackward tried to modify an axis out of range, received ",
+//         maxis,
+//         " but this value should be >=0 and <",
+//         axis_map.size());
+
+//     if (axis_map[maxis] == -1) {
+//       // Don't modify the path; we need an extra axis as there was
+//       // previously one there, but we shouldn't modify it.
+//       axis_map.insert(axis_map.begin() + maxis + 1, -1);
+//       return td;
+//     }
+
+//     // Recreate the merge, axis is relative to the td
+//     int axis = axis_map[maxis];
+//     // Move indices up as previously we had an extra axis
+//     std::transform(
+//         axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) {
+//           return i > axis ? i + 1 : i;
+//         });
+
+//     // Insert pre-merged axis back into map
+//     axis_map.insert(axis_map.begin() + maxis + 1, axis_map[maxis] + 1);
+
+//     // Create new domain reflecting pre-merge
+//     std::vector<IterDomain*> new_domain;
+//     for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) {
+//       if ((int)i == axis) {
+//         IterDomain* td_axis = td->axis(axis);
+//         IterDomain* maxis_1 = merge->in()->axis(maxis);
+//         IterDomain* maxis_2 = merge->in()->axis(maxis + 1);
+//         new_domain.push_back(new IterDomain(
+//             maxis_1->start(),
+//             maxis_1->extent(),
+//             ParallelType::Serial,
+//             td_axis->isReduction(),
+//             td_axis->isRFactorProduct(),
+//             td_axis->isBroadcast()));
+//         new_domain.push_back(new IterDomain(
+//             maxis_2->start(),
+//             maxis_2->extent(),
+//             ParallelType::Serial,
+//             td_axis->isReduction(),
+//             td_axis->isRFactorProduct(),
+//             td_axis->isBroadcast()));
+//       } else {
+//         // Add in all other axes; these may not match the input td to the
+//         // merge.
+//         new_domain.push_back(td->axis(i));
+//       }
+//     }
+
+//     TensorDomain* replayed_inp = new TensorDomain(new_domain);
+//     new Merge(td, replayed_inp, axis);
+//     return replayed_inp;
+//   }
+
+//   // Entry for backward influence propagation on td following record;
+//   // history should be present -> past as you go through the vector
+//   TensorDomain* replayBackward(
+//       TensorDomain* td,
+//       const std::vector<Expr*>& history) {
+//     TensorDomain* running_td = td;
+
+//     std::vector<Expr*> rev_history(history.rbegin(), history.rend());
+//     for (Expr* op : rev_history)
+//       running_td = TransformIter::replayBackward(op, running_td);
+//     return running_td;
+//   }
+
+//   std::vector<int> axis_map;
+
+//   TransformBackward(std::vector<int> _axis_map)
+//       : axis_map(std::move(_axis_map)){};
+
+//  public:
+//   static TensorDomain* replay(
+//       TensorDomain* td,
+//       const std::vector<Expr*>& history,
+//       const std::vector<int>& axis_map) {
+//     TransformBackward tb(axis_map);
+//     return tb.replayBackward(td, history);
+//   }
+// };
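+// [Editorial sketch, not part of the original patch.] The round-trip
+// contract of the helper above: replaying a domain's history backward with
+// an identity axis map produces a new "root" from which forward replay of
+// the same history should reproduce a domain equivalent to td.
+//
+//   std::vector<Expr*> history = TransformIter::getHistory(td);
+//   std::vector<int> axis_map(td->nDims());
+//   std::iota(axis_map.begin(), axis_map.end(), 0);
+//   TensorDomain* root = TransformBackward::replay(td, history, axis_map);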
+// struct RFactorRoot : public TransformIter {
+//   bool found_non_rfactor_op = false;
+
+//   TensorDomain* replay(Split* split, TensorDomain*) final {
+//     if (!split->in()->axis(split->axis())->isRFactorProduct())
+//       found_non_rfactor_op = true;
+//     return split->out();
+//   }
+
+//   TensorDomain* replay(Merge* merge, TensorDomain*) final {
+//     if (!merge->in()->axis(merge->axis())->isRFactorProduct())
+//       found_non_rfactor_op = true;
+//     return merge->out();
+//   }
+
+//   // Replay forward until we hit an operation that doesn't involve an
+//   // rfactor axis
+//   TensorDomain* runReplay(TensorDomain*, const std::vector<Expr*>& history)
+//       final {
+//     TORCH_INTERNAL_ASSERT(
+//         !history.empty(), "No history provided to find rfactor root domain.");
+
+//     auto last_rfactor_op = history.begin();
+//     auto running_op = history.begin();
+
+//     for (auto it = history.begin(); it != history.end(); it++) {
+//       TransformIter::replay(*it, nullptr);
+//       if (found_non_rfactor_op)
+//         break;
+//       running_op = it;
+//     }
+
+//     // We need to make sure the rfactor root is ordered correctly.
+//     bool found_valid_rfactor_root = false;
+
+//     Val* val;
+
+//     while (!found_valid_rfactor_root && last_rfactor_op != history.end()) {
+//       // Try next val
+//       val = (*last_rfactor_op++)->output(0);
+//       TORCH_INTERNAL_ASSERT(
+//           val->getValType().value() == ValType::TensorDomain,
+//           "Invalid history to find rfactor root.");
+
+//       TensorDomain* td = static_cast<TensorDomain*>(val);
+//       bool found_rfactor_dim = false;
+//       for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) {
+//         if (found_rfactor_dim) {
+//           if (!td->axis(i)->isRFactorProduct())
+//             break;
+//         } else {
+//           if (td->axis(i)->isRFactorProduct())
+//             found_rfactor_dim = true;
+//         }
+//         if (i == td->nDims() - 1)
+//           found_valid_rfactor_root = true;
+//       }
+//     }
+//     TORCH_INTERNAL_ASSERT(
+//         found_valid_rfactor_root, "Could not find a valid rfactor root.");
+//     return static_cast<TensorDomain*>(val);
+//   }
+
+//  public:
+//   static TensorDomain* get(TensorDomain* td) {
+//     auto history = TransformIter::getHistory(td);
+//     if (history.empty())
+//       return td;
+//     RFactorRoot rfr;
+//     return rfr.runReplay(nullptr, history);
+//   }
+// };
+
+// } // namespace
+
+// // API INTO TRANSFORM ITER
+
+// std::vector<bool> TransformIter::getRootInfluence(
+//     TensorDomain* td,
+//     const std::vector<bool>& td_influence) {
+//   return Influence::computeBackward(
+//       TransformIter::getHistory(td), td_influence);
+// }
+
+// std::vector<bool> TransformIter::replayBackwardInfluence(
+//     const std::vector<Expr*>& history,
+//     const std::vector<bool>& td_influence) {
+//   return Influence::computeBackward(history, td_influence);
+// }
+
+// std::vector<bool> TransformIter::replayInfluence(
+//     const std::vector<Expr*>& history,
+//     const std::vector<bool>& td_influence) {
+//   if (history.empty())
+//     return td_influence;
+
+//   return Influence::computeForward(history, td_influence);
+// }
+
+// TensorDomain* TransformIter::replay(
+//     TensorDomain* td,
+//     const std::vector<Expr*>& history,
+//     const std::vector<int>& axis_map) {
+//   if (history.empty())
+//     return td;
+//   if (std::none_of(
+//           axis_map.begin(), axis_map.end(), [](int i) { return i > -1; }))
+//     return td;
+
+//   validate_history_entry(history[0], axis_map.size());
+//   return Replay::replay(td, history, axis_map);
+// }
+
+// TensorDomain* TransformIter::replaySelf(
+//     TensorDomain* td,
+//     const std::vector<Expr*>& history,
+//     const std::vector<int>& axis_map) {
+//   if (std::none_of(
+//           axis_map.begin(), axis_map.end(), [](int i) { return i > -1; }))
+//     return TransformIter::getRoot(td);
+
+//   validate_axis_map(TransformIter::getRoot(td)->nDims(), axis_map);
+//   return ReplaySelf::replay(td, history, axis_map);
+// }
+
+// TensorDomain* TransformIter::replayBackward(
+//     TensorDomain* td,
+//     const std::vector<Expr*>& history,
+//     const std::vector<int>& axis_map) {
+//   if (history.empty())
+//     return td;
+//   if (std::none_of(
+//           axis_map.begin(), axis_map.end(), [](int i) { return i > -1; }))
+//     return td;
+
+//   TORCH_INTERNAL_ASSERT(
+//       history[history.size() - 1]->output(0)->getValType().value() ==
+//               ValType::TensorDomain &&
+//           static_cast<TensorDomain*>(history[history.size() - 1]->output(0))
+//                   ->nDims() == axis_map.size(),
+//       "Invalid history, or invalid axis_map in TransformIter.");
+
+//   return TransformBackward::replay(td, history, axis_map);
+// }
+
+// TensorDomain* TransformIter::getRFactorRoot(TensorDomain* td) {
+//   auto td_root = TransformIter::getRoot(td);
+//   if (std::none_of(
+//           td_root->domain().begin(),
+//           td_root->domain().end(),
+//           [](IterDomain* id) { return id->isRFactorProduct(); }))
+//     return td_root;
+
+//   return RFactorRoot::get(td);
+// }
+
+// } // namespace fuser
+// } // namespace jit +// } // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.h b/torch/csrc/jit/codegen/cuda/transform_iter.h index 23d456e13b2df..aec625077d65e 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.h +++ b/torch/csrc/jit/codegen/cuda/transform_iter.h @@ -20,79 +20,100 @@ namespace fuser { */ struct TORCH_CUDA_API TransformIter : public IterVisitor { protected: - virtual TensorDomain* replayBackward(Split*, TensorDomain*); - virtual TensorDomain* replayBackward(Merge*, TensorDomain*); + // virtual TensorDomain* replayBackward(Split*, TensorDomain*); + // virtual TensorDomain* replayBackward(Merge*, TensorDomain*); - // dispatch - TensorDomain* replayBackward(Expr*, TensorDomain*); + // // dispatch + // TensorDomain* replayBackward(Expr*, TensorDomain*); - // Iterates td's history starting with td, then origin(td), origin(origin(td)) - // etc. Returns root TensorDomain once it iterates through history. If - // generate_record=true It will record the history of td in record. Record is - // order operations root->td. - virtual TensorDomain* runBackward(TensorDomain*); + // // Iterates td's history starting with td, then origin(td), + // origin(origin(td)) + // // etc. Returns root TensorDomain once it iterates through history. If + // // generate_record=true It will record the history of td in record. + // Record is + // // order operations root->td. + // virtual TensorDomain* runBackward(TensorDomain*); - virtual TensorDomain* replay(Split*, TensorDomain*); - virtual TensorDomain* replay(Merge*, TensorDomain*); + // virtual TensorDomain* replay(Split*, TensorDomain*); + // virtual TensorDomain* replay(Merge*, TensorDomain*); - // dispatch - virtual TensorDomain* replay(Expr*, TensorDomain*); + // // dispatch + // virtual TensorDomain* replay(Expr*, TensorDomain*); - // Runs through operations in history and applies them to TD, runs exprs from - // begining to end - virtual TensorDomain* runReplay(TensorDomain*, const std::vector&); + // // Runs through operations in history and applies them to TD, runs exprs + // from + // // begining to end + // virtual TensorDomain* runReplay(TensorDomain*, const + // std::vector&); public: // Returns transformation exprs in forward order - static std::vector getHistory(TensorDomain*); + static std::vector getHistory(TensorDomain*) { + TORCH_INTERNAL_ASSERT(false, "NIY."); + } // TODO: make td const static TensorDomain* getRoot(TensorDomain* td) { - TransformIter ti; - return ti.runBackward(td); + { TORCH_INTERNAL_ASSERT(false, "NIY."); } + // TransformIter ti; + // return ti.runBackward(td); } // Takes influence vector of bools, tracks them back to propagate true to root // axes that were modified into td axes matching marked influence vector. static std::vector getRootInfluence( TensorDomain* td, - const std::vector& influence); + const std::vector& influence) { + TORCH_INTERNAL_ASSERT(false, "NIY."); + } static std::vector replayBackwardInfluence( const std::vector& history, - const std::vector& td_influence); + const std::vector& td_influence) { + TORCH_INTERNAL_ASSERT(false, "NIY."); + } // Runs through history, applying only on influence to track how modifications // would influence the original axes. static std::vector replayInfluence( const std::vector& history, - const std::vector& td_influence); + const std::vector& td_influence) { + TORCH_INTERNAL_ASSERT(false, "NIY."); + } // Goes through history and applies it to td, with the axis_map provided. 
// Axis_map entries of -1 mean those axes won't be modified static TensorDomain* replay( TensorDomain* td, const std::vector& history, - const std::vector& axis_map); + const std::vector& axis_map) { + TORCH_INTERNAL_ASSERT(false, "NIY."); + } // Takes td, and replays history backwards on it to create a new root tensor // domain using axis_map. Entries in axis_map == -1 will not be modified static TensorDomain* replayBackward( TensorDomain* td, const std::vector& history, - const std::vector& axis_map); + const std::vector& axis_map) { + TORCH_INTERNAL_ASSERT(false, "NIY."); + } static TensorDomain* replaySelf( TensorDomain* td, const std::vector& history, - const std::vector& axis_map); + const std::vector& axis_map) { + TORCH_INTERNAL_ASSERT(false, "NIY."); + } // getRFactorRoot does not reapply any transformations. It simply searches // through the history of td from its root domain and tries to find a point // where we stop transforming axes marked as rfactor. This works because all // rfactor transformations are pushed to the begining of td's history by the // RFactor transformation itself. - static TensorDomain* getRFactorRoot(TensorDomain* td); + static TensorDomain* getRFactorRoot(TensorDomain* td) { + TORCH_INTERNAL_ASSERT(false, "NIY."); + } }; } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 7680a3438e691..c57272754b0af 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -1,343 +1,363 @@ -#include -#include -#include -#include -#include - -#include - -namespace torch { -namespace jit { -namespace fuser { - -// Replay producer as consumer. -TensorDomain* TransformReplay::fullSelfReplay( - TensorDomain* self, - TensorDomain* self_copy) { - // Want producer root with no reductions, rfactor included - TensorDomain* self_root = self->rootDomain(); - - // Want full consumer root, even before rfactor - TensorDomain* self_copy_root = self_copy->rootDomain(); - - TORCH_INTERNAL_ASSERT( - self_root->nDims(), self_copy_root->nDims(), "Invalid self replay."); - - for (decltype(self_root->nDims()) i{0}; i < self_root->nDims(); i++) - TORCH_INTERNAL_ASSERT( - self_root->axis(i)->parallel_method() == - self_copy_root->axis(i)->parallel_method() && - self_root->axis(i)->isReduction() == - self_copy_root->axis(i)->isReduction() && - self_root->axis(i)->start() == self_copy_root->axis(i)->start(), - "Invalid self replay detected, root domain does not match."); - - std::vector axis_map(self_root->nDims()); - std::iota(axis_map.begin(), axis_map.end(), 0); - - // Finally replay producer as consumer on marked axes - - auto replayed = TransformIter::replay( - self_copy_root, TransformIter::getHistory(self), axis_map); - - return replayed; -} - -// Replay producer as consumer. 
-TensorDomain* TransformReplay::replayPasC( - TensorDomain* producer, - TensorDomain* consumer, - int consumer_compute_at_axis) { - if (consumer_compute_at_axis < 0) - consumer_compute_at_axis += (int)consumer->nDims() + 1; - TORCH_INTERNAL_ASSERT( - consumer_compute_at_axis >= 0 && - (unsigned int)consumer_compute_at_axis <= consumer->nDims(), - "Invalid axis in transform replayPasC."); - - // Consumer in rfactor cases is based off producer's rfactor root, not - // producer's root - TensorDomain* producer_rfactor_root = TransformIter::getRFactorRoot(producer); - - // Want full consumer root, even before rfactor - TensorDomain* consumer_root = TransformIter::getRoot(consumer); - - // We want to see which axes in the consumer root were modified to create axes - // < consumer_compute_at_axis - std::vector consumer_influence(consumer->nDims(), false); - for (int i = 0; i < consumer_compute_at_axis; i++) - consumer_influence[i] = true; - - // Check which axes in ref_root need to be modified to honor transformations - // to compute at axis - std::vector consumer_root_influence = - TransformIter::getRootInfluence(consumer, consumer_influence); - - // We have the influence on the consumer root, we need it on producer, we - // want to keep those axes that don't need to be modified by the replay - std::vector producer_rfactor_root_influence( - producer_rfactor_root->nDims(), false); - - // Map is based on producer - std::vector replay_axis_map(consumer_root->nDims(), -1); - // Setup producer_rfactor_root_influence vector on root for replay - size_t ip = 0, ic = 0; - - while (ip < producer_rfactor_root_influence.size() && - ic < consumer_root->nDims()) { - bool is_reduction = producer_rfactor_root->axis(ip)->isReduction(); - bool is_bcast = consumer_root->axis(ic)->isBroadcast(); - if (is_reduction) { - producer_rfactor_root_influence[ip++] = false; - } else if (is_bcast) { - replay_axis_map[ic++] = -1; - } else { - if (consumer_root_influence[ic]) { - replay_axis_map[ic] = ip; - } else { - replay_axis_map[ic] = -1; - } - producer_rfactor_root_influence[ip++] = consumer_root_influence[ic++]; - } - } - - for (decltype(producer_rfactor_root->nDims()) i{0}; - i < producer_rfactor_root->nDims(); - i++) - TORCH_INTERNAL_ASSERT( - !(producer_rfactor_root_influence[i] && - producer_rfactor_root->axis(i)->isRFactorProduct()), - "An illegal attempt to modify an rfactor axis detected."); - - // We should have hit the end of the consumer root domain - TORCH_INTERNAL_ASSERT( - ic == consumer_root->nDims() || - (ic < consumer_root->nDims() ? consumer_root->axis(ic)->isBroadcast() - : false), - "Error when trying to run replay, didn't reach end of consumer/target root."); - - TORCH_INTERNAL_ASSERT( - producer_rfactor_root_influence.size() == producer_rfactor_root->nDims(), - "Error detected during replay, expected matching sizes of influence map to root dimensions."); - - auto producer_root_influence = TransformIter::getRootInfluence( - producer_rfactor_root, producer_rfactor_root_influence); - - TensorDomain* producer_root = TransformIter::getRoot(producer_rfactor_root); - - std::vector producer_replay_map(producer_root->nDims()); - for (decltype(producer_replay_map.size()) i{0}; - i < producer_replay_map.size(); - i++) { - if (producer_root->axis(i)->isRFactorProduct()) { - producer_replay_map[i] = i; - } else { - producer_replay_map[i] = producer_root_influence[i] ? 
-1 : i; - } - } - - // Replay axes that won't be modified by transform replay - TensorDomain* producer_replay_root = TransformIter::replaySelf( - producer, TransformIter::getHistory(producer), producer_replay_map); - - // Record axes positions. - std::unordered_map new_position; - for (decltype(producer_replay_root->nDims()) i{0}; - i < producer_replay_root->nDims(); - i++) - new_position[producer_replay_root->axis(i)] = i; - - std::unordered_map root_axis_map; - // reorder producer_replay_root to respect replay_axis_map - for (decltype(replay_axis_map.size()) i{0}; i < replay_axis_map.size(); i++) { - if (replay_axis_map[i] == -1) - continue; - auto ax = producer_root->axis(replay_axis_map[i]); - TORCH_INTERNAL_ASSERT( - new_position.find(ax) != new_position.end(), - "Error hit during transform replay, could not find ", - ax, - " expected in root domain."); - root_axis_map[new_position[ax]] = replay_axis_map[i]; - } - - // root_axis_map is now mapping from producer_replay_root -> consumer_root - // Take producer_replay_root transform for all modified axes are in correct - // relative order, matching how it was in replay_axis_map - producer_replay_root = producer_replay_root->reorder(root_axis_map); - - // Finally replay producer as consumer on marked axes - TensorDomain* replayed = TransformIter::replay( - producer_replay_root, - TransformIter::getHistory(consumer), - replay_axis_map); - - TORCH_INTERNAL_ASSERT( - std::none_of( - replayed->domain().begin(), - replayed->domain().begin() + consumer_compute_at_axis, - [](IterDomain* id) { return id->isReduction(); }), - "Reduction axes found within consumer_compute_at_axis in replay of producer."); - - return replayed; -} - -// Replay consumer as producer. -TensorDomain* TransformReplay::replayCasP( - TensorDomain* consumer, - TensorDomain* producer, - int producer_compute_at_axis) { - if (producer_compute_at_axis < 0) - producer_compute_at_axis += (int)producer->nDims() + 1; - TORCH_INTERNAL_ASSERT( - producer_compute_at_axis >= 0 && - (unsigned int)producer_compute_at_axis <= producer->nDims(), - "Invalid axis in transform replayPasC."); - - // Want producer root with no reductions, rfactor included - TensorDomain* producer_rfactor_root = TransformIter::getRFactorRoot(producer); - TensorDomain* producer_root = TransformIter::getRoot(producer); - // Producer root still has reductions - - // Want full consumer root, even before rfactor - TensorDomain* consumer_root = TransformIter::getRoot(consumer); - - // We want to see which axes in the producer root were modified to create axes - // < producer_compute_at_axis - std::vector producer_influence(producer->nDims(), false); - for (int i = 0; i < producer_compute_at_axis; i++) - producer_influence[i] = true; - - // Check which axes in ref_root need to be modified to honor transformations - // to compute at axis - std::vector producer_root_influence = - TransformIter::getRootInfluence(producer, producer_influence); - - for (decltype(producer_root->nDims()) i{0}; i < producer_root->nDims(); i++) { - TORCH_INTERNAL_ASSERT( - !(producer_root_influence[i] && producer_root->axis(i)->isReduction()), - "Error during replay, likely due to an illegal bad computeAt."); - } - - std::vector producer_rfactor_root_influence = - TransformIter::replayInfluence( - TransformIter::getHistory(producer_rfactor_root), - producer_root_influence); - - // We have the influence on the producer root, we need it on consumer, we - // want to keep those axes that don't need to be modified by the replay - std::vector 
consumer_root_influence( - consumer->rootDomain()->nDims(), false); - - // Producer -> consumer axis map - std::vector replay_axis_map(producer_rfactor_root->nDims(), -1); - - // Setup consumer_root_influence vector on root for replay - decltype(consumer_root_influence.size()) ip = 0, ic = 0; - while (ic < consumer_root_influence.size() && - ip < producer_rfactor_root->nDims()) { - bool is_reduction = producer_rfactor_root->axis(ip)->isReduction(); - if (is_reduction) { - replay_axis_map[ip++] = -1; - continue; - } - if (producer_rfactor_root_influence[ip] && - !consumer_root->axis(ic)->isRFactorProduct()) { - replay_axis_map[ip] = ic; - } else { - replay_axis_map[ip] = -1; - } - consumer_root_influence[ic++] = producer_rfactor_root_influence[ip++]; - } - - // Unlike PasC if last axes in producer_rfactor_root is a reduction we won't - // be guarneteed that ip == producer_rfactor_root->nDims(), that's why we - // initialize replay_axis_map with -1 - - TORCH_INTERNAL_ASSERT( - consumer_root_influence.size() == consumer_root->nDims(), - "Error detected during replay, expected matching sizes of influence map to root dimensions."); - - std::vector consumer_replay_map(consumer_root->nDims()); - for (decltype(consumer_replay_map.size()) i{0}; - i < consumer_replay_map.size(); - i++) { - if (consumer_root->axis(i)->isRFactorProduct()) { - consumer_replay_map[i] = i; - } else { - consumer_replay_map[i] = consumer_root_influence[i] ? -1 : i; - } - } - - // Replay axes that won't be modified by transform replay - TensorDomain* consumer_replay_root = TransformIter::replaySelf( - consumer, TransformIter::getHistory(consumer), consumer_replay_map); - - // Record axes positions. - std::unordered_map new_position; - for (decltype(consumer_replay_root->nDims()) i{0}; - i < consumer_replay_root->nDims(); - i++) - new_position[consumer_replay_root->axis(i)] = i; - - std::unordered_map root_axis_map; - // reorder consumer_replay_root to respect replay_axis_map - for (decltype(replay_axis_map.size()) i{0}; i < replay_axis_map.size(); i++) { - if (replay_axis_map[i] == -1) - continue; - auto ax = consumer_root->axis(replay_axis_map[i]); - TORCH_INTERNAL_ASSERT( - new_position.find(ax) != new_position.end(), - "Error hit during transform replay, could not find ", - ax, - " expected in root domain."); - root_axis_map[new_position[ax]] = replay_axis_map[i]; - } - - auto replay_history = TransformIter::getHistory(producer); - auto rfactor_history = TransformIter::getHistory(producer_rfactor_root); - replay_history.erase( - replay_history.begin(), replay_history.begin() + rfactor_history.size()); - - consumer_replay_root = consumer_replay_root->reorder(root_axis_map); - // Finally replay consumer as producer on marked axes - - auto replayed = TransformIter::replay( - consumer_replay_root, replay_history, replay_axis_map); - - return replayed; -} - -// replay Producer as Consumer -TensorView* TransformReplay::replayPasC( - TensorView* producer, - TensorView* consumer, - int compute_at_axis) { - // If this is a reduction operation, we may call transform_replay on the same - // tensor view. When this happens, just return thet target view. 
- if (producer == consumer) - return producer; - - TensorDomain* td = - replayPasC(producer->domain(), consumer->domain(), compute_at_axis); - producer->setDomain(td); - return producer; -} - -TensorView* TransformReplay::replayCasP( - TensorView* consumer, - TensorView* producer, - int compute_at_axis) { - // If this is a reduction operation, we may call transform_replay on the same - // tensor view. When this happens, just return thet target view. - if (consumer == producer) - return consumer; - TensorDomain* td = - replayCasP(consumer->domain(), producer->domain(), compute_at_axis); - consumer->setDomain(td); - return consumer; -} - -} // namespace fuser -} // namespace jit -} // namespace torch +// #include +// #include +// #include +// #include +// #include + +// #include + +// namespace torch { +// namespace jit { +// namespace fuser { + +// // Replay producer as consumer. +// TensorDomain* TransformReplay::fullSelfReplay( +// TensorDomain* self, +// TensorDomain* self_copy) { +// // Want producer root with no reductions, rfactor included +// TensorDomain* self_root = self->rootDomain(); + +// // Want full consumer root, even before rfactor +// TensorDomain* self_copy_root = self_copy->rootDomain(); + +// TORCH_INTERNAL_ASSERT( +// self_root->nDims(), self_copy_root->nDims(), "Invalid self replay."); + +// for (decltype(self_root->nDims()) i{0}; i < self_root->nDims(); i++) +// TORCH_INTERNAL_ASSERT( +// self_root->axis(i)->parallel_method() == +// self_copy_root->axis(i)->parallel_method() && +// self_root->axis(i)->isReduction() == +// self_copy_root->axis(i)->isReduction() && +// self_root->axis(i)->start() == self_copy_root->axis(i)->start(), +// "Invalid self replay detected, root domain does not match."); + +// std::vector axis_map(self_root->nDims()); +// std::iota(axis_map.begin(), axis_map.end(), 0); + +// // Finally replay producer as consumer on marked axes + +// auto replayed = TransformIter::replay( +// self_copy_root, TransformIter::getHistory(self), axis_map); + +// return replayed; +// } + +// // Replay producer as consumer. 
+// TensorDomain* TransformReplay::replayPasC(
+//     TensorDomain* producer,
+//     TensorDomain* consumer,
+//     int consumer_compute_at_axis) {
+//   if (consumer_compute_at_axis < 0)
+//     consumer_compute_at_axis += (int)consumer->nDims() + 1;
+//   TORCH_INTERNAL_ASSERT(
+//       consumer_compute_at_axis >= 0 &&
+//           (unsigned int)consumer_compute_at_axis <= consumer->nDims(),
+//       "Invalid axis in transform replayPasC.");
+
+//   // Consumer in rfactor cases is based off producer's rfactor root, not
+//   // producer's root
+//   TensorDomain* producer_rfactor_root =
+//       TransformIter::getRFactorRoot(producer);
+
+//   // Want full consumer root, even before rfactor
+//   TensorDomain* consumer_root = TransformIter::getRoot(consumer);
+
+//   // We want to see which axes in the consumer root were modified to
+//   // create axes < consumer_compute_at_axis
+//   std::vector<bool> consumer_influence(consumer->nDims(), false);
+//   for (int i = 0; i < consumer_compute_at_axis; i++)
+//     consumer_influence[i] = true;
+
+//   // Check which axes in ref_root need to be modified to honor
+//   // transformations to compute at axis
+//   std::vector<bool> consumer_root_influence =
+//       TransformIter::getRootInfluence(consumer, consumer_influence);
+
+//   // We have the influence on the consumer root; we need it on producer.
+//   // We want to keep those axes that don't need to be modified by the
+//   // replay.
+//   std::vector<bool> producer_rfactor_root_influence(
+//       producer_rfactor_root->nDims(), false);
+
+//   // Map is based on producer
+//   std::vector<int> replay_axis_map(consumer_root->nDims(), -1);
+//   // Setup producer_rfactor_root_influence vector on root for replay
+//   size_t ip = 0, ic = 0;
+
+//   while (ip < producer_rfactor_root_influence.size() &&
+//          ic < consumer_root->nDims()) {
+//     bool is_reduction = producer_rfactor_root->axis(ip)->isReduction();
+//     bool is_bcast = consumer_root->axis(ic)->isBroadcast();
+//     if (is_reduction) {
+//       producer_rfactor_root_influence[ip++] = false;
+//     } else if (is_bcast) {
+//       replay_axis_map[ic++] = -1;
+//     } else {
+//       if (consumer_root_influence[ic]) {
+//         replay_axis_map[ic] = ip;
+//       } else {
+//         replay_axis_map[ic] = -1;
+//       }
+//       producer_rfactor_root_influence[ip++] = consumer_root_influence[ic++];
+//     }
+//   }
+
+//   for (decltype(producer_rfactor_root->nDims()) i{0};
+//        i < producer_rfactor_root->nDims();
+//        i++)
+//     TORCH_INTERNAL_ASSERT(
+//         !(producer_rfactor_root_influence[i] &&
+//           producer_rfactor_root->axis(i)->isRFactorProduct()),
+//         "An illegal attempt to modify an rfactor axis detected.");
+
+//   // We should have hit the end of the consumer root domain
+//   TORCH_INTERNAL_ASSERT(
+//       ic == consumer_root->nDims() ||
+//           (ic < consumer_root->nDims()
+//                ? consumer_root->axis(ic)->isBroadcast()
+//                : false),
+//       "Error when trying to run replay, didn't reach end of consumer/target root.");
+//   TORCH_INTERNAL_ASSERT(
+//       producer_rfactor_root_influence.size() == producer_rfactor_root->nDims(),
+//       "Error detected during replay, expected matching sizes of influence map to root dimensions.");
+
+//   auto producer_root_influence = TransformIter::getRootInfluence(
+//       producer_rfactor_root, producer_rfactor_root_influence);
+
+//   TensorDomain* producer_root =
+//       TransformIter::getRoot(producer_rfactor_root);
+
+//   std::vector<int> producer_replay_map(producer_root->nDims());
+//   for (decltype(producer_replay_map.size()) i{0};
+//        i < producer_replay_map.size();
+//        i++) {
+//     if (producer_root->axis(i)->isRFactorProduct()) {
+//       producer_replay_map[i] = i;
+//     } else {
+//       producer_replay_map[i] = producer_root_influence[i] ? -1 : i;
+//     }
+//   }
+
+//   // Replay axes that won't be modified by transform replay
+//   TensorDomain* producer_replay_root = TransformIter::replaySelf(
+//       producer, TransformIter::getHistory(producer), producer_replay_map);
+
+//   // Record axes positions.
+//   std::unordered_map<IterDomain*, int> new_position;
+//   for (decltype(producer_replay_root->nDims()) i{0};
+//        i < producer_replay_root->nDims();
+//        i++)
+//     new_position[producer_replay_root->axis(i)] = i;
+
+//   std::unordered_map<int, int> root_axis_map;
+//   // Reorder producer_replay_root to respect replay_axis_map
+//   for (decltype(replay_axis_map.size()) i{0}; i < replay_axis_map.size();
+//        i++) {
+//     if (replay_axis_map[i] == -1)
+//       continue;
+//     auto ax = producer_root->axis(replay_axis_map[i]);
+//     TORCH_INTERNAL_ASSERT(
+//         new_position.find(ax) != new_position.end(),
+//         "Error hit during transform replay, could not find ",
+//         ax,
+//         " expected in root domain.");
+//     root_axis_map[new_position[ax]] = replay_axis_map[i];
+//   }
+
+//   // root_axis_map now maps from producer_replay_root -> consumer_root.
+//   // Reorder producer_replay_root so all modified axes are in the correct
+//   // relative order, matching how they were in replay_axis_map.
+//   producer_replay_root = producer_replay_root->reorder(root_axis_map);
+
+//   // Finally replay producer as consumer on marked axes
+//   TensorDomain* replayed = TransformIter::replay(
+//       producer_replay_root,
+//       TransformIter::getHistory(consumer),
+//       replay_axis_map);
+
+//   TORCH_INTERNAL_ASSERT(
+//       std::none_of(
+//           replayed->domain().begin(),
+//           replayed->domain().begin() + consumer_compute_at_axis,
+//           [](IterDomain* id) { return id->isReduction(); }),
+//       "Reduction axes found within consumer_compute_at_axis in replay of producer.");
+
+//   return replayed;
+// }
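+// [Editorial sketch, not part of the original patch.] At the TensorView
+// level, the replayPasC above is what backs a computeAt like
+//
+//   TensorView* tv1 = ...;   // producer
+//   TensorView* tv2 = ...;   // consumer of tv1
+//   tv2->split(0, 128);
+//   tv1->computeAt(tv2, 1);  // replays tv2's split(0, 128) onto tv1
+//
+// The producer's domain is rewritten so its leading axes match the
+// consumer's transformed axes up to the compute-at point, while reduction
+// and broadcast axes are masked out of the replay via the -1 entries in
+// replay_axis_map.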
+// // Replay consumer as producer.
+// TensorDomain* TransformReplay::replayCasP(
+//     TensorDomain* consumer,
+//     TensorDomain* producer,
+//     int producer_compute_at_axis) {
+//   if (producer_compute_at_axis < 0)
+//     producer_compute_at_axis += (int)producer->nDims() + 1;
+//   TORCH_INTERNAL_ASSERT(
+//       producer_compute_at_axis >= 0 &&
+//           (unsigned int)producer_compute_at_axis <= producer->nDims(),
+//       "Invalid axis in transform replayCasP.");
+
+//   // Want producer root with no reductions, rfactor included
+//   TensorDomain* producer_rfactor_root =
+//       TransformIter::getRFactorRoot(producer);
+//   TensorDomain* producer_root = TransformIter::getRoot(producer);
+//   // Producer root still has reductions
+
+//   // Want full consumer root, even before rfactor
+//   TensorDomain* consumer_root = TransformIter::getRoot(consumer);
+
+//   // We want to see which axes in the producer root were modified to
+//   // create axes < producer_compute_at_axis
+//   std::vector<bool> producer_influence(producer->nDims(), false);
+//   for (int i = 0; i < producer_compute_at_axis; i++)
+//     producer_influence[i] = true;
+
+//   // Check which axes in ref_root need to be modified to honor
+//   // transformations to compute at axis
+//   std::vector<bool> producer_root_influence =
+//       TransformIter::getRootInfluence(producer, producer_influence);
+
+//   for (decltype(producer_root->nDims()) i{0}; i < producer_root->nDims();
+//        i++) {
+//     TORCH_INTERNAL_ASSERT(
+//         !(producer_root_influence[i] &&
+//           producer_root->axis(i)->isReduction()),
+//         "Error during replay, likely due to an illegal computeAt.");
+//   }
+
+//   std::vector<bool> producer_rfactor_root_influence =
+//       TransformIter::replayInfluence(
+//           TransformIter::getHistory(producer_rfactor_root),
+//           producer_root_influence);
+
+//   // We have the influence on the producer root; we need it on consumer.
+//   // We want to keep those axes that don't need to be modified by the
+//   // replay.
+//   std::vector<bool> consumer_root_influence(
+//       consumer->rootDomain()->nDims(), false);
+
+//   // Producer -> consumer axis map
+//   std::vector<int> replay_axis_map(producer_rfactor_root->nDims(), -1);
+
+//   // Setup consumer_root_influence vector on root for replay
+//   decltype(consumer_root_influence.size()) ip = 0, ic = 0;
+//   while (ic < consumer_root_influence.size() &&
+//          ip < producer_rfactor_root->nDims()) {
+//     bool is_reduction = producer_rfactor_root->axis(ip)->isReduction();
+//     if (is_reduction) {
+//       replay_axis_map[ip++] = -1;
+//       continue;
+//     }
+//     if (producer_rfactor_root_influence[ip] &&
+//         !consumer_root->axis(ic)->isRFactorProduct()) {
+//       replay_axis_map[ip] = ic;
+//     } else {
+//       replay_axis_map[ip] = -1;
+//     }
+//     consumer_root_influence[ic++] = producer_rfactor_root_influence[ip++];
+//   }
+
+//   // Unlike PasC, if the last axes in producer_rfactor_root are reductions
+//   // we won't be guaranteed that ip == producer_rfactor_root->nDims();
+//   // that's why we initialize replay_axis_map with -1.
+
+//   TORCH_INTERNAL_ASSERT(
+//       consumer_root_influence.size() == consumer_root->nDims(),
+//       "Error detected during replay, expected matching sizes of influence map to root dimensions.");
+
+//   std::vector<int> consumer_replay_map(consumer_root->nDims());
+//   for (decltype(consumer_replay_map.size()) i{0};
+//        i < consumer_replay_map.size();
+//        i++) {
+//     if (consumer_root->axis(i)->isRFactorProduct()) {
+//       consumer_replay_map[i] = i;
+//     } else {
+//       consumer_replay_map[i] = consumer_root_influence[i] ? -1 : i;
+//     }
+//   }
+
+//   // Replay axes that won't be modified by transform replay
+//   TensorDomain* consumer_replay_root = TransformIter::replaySelf(
+//       consumer, TransformIter::getHistory(consumer), consumer_replay_map);
+
+//   // Record axes positions.
+//   std::unordered_map<IterDomain*, int> new_position;
+//   for (decltype(consumer_replay_root->nDims()) i{0};
+//        i < consumer_replay_root->nDims();
+//        i++)
+//     new_position[consumer_replay_root->axis(i)] = i;
+
+//   std::unordered_map<int, int> root_axis_map;
+//   // Reorder consumer_replay_root to respect replay_axis_map
+//   for (decltype(replay_axis_map.size()) i{0}; i < replay_axis_map.size();
+//        i++) {
+//     if (replay_axis_map[i] == -1)
+//       continue;
+//     auto ax = consumer_root->axis(replay_axis_map[i]);
+//     TORCH_INTERNAL_ASSERT(
+//         new_position.find(ax) != new_position.end(),
+//         "Error hit during transform replay, could not find ",
+//         ax,
+//         " expected in root domain.");
+//     root_axis_map[new_position[ax]] = replay_axis_map[i];
+//   }
+
+//   auto replay_history = TransformIter::getHistory(producer);
+//   auto rfactor_history = TransformIter::getHistory(producer_rfactor_root);
+//   replay_history.erase(
+//       replay_history.begin(),
+//       replay_history.begin() + rfactor_history.size());
+
+//   consumer_replay_root = consumer_replay_root->reorder(root_axis_map);
+//   // Finally replay consumer as producer on marked axes
+
+//   auto replayed = TransformIter::replay(
+//       consumer_replay_root, replay_history, replay_axis_map);
+
+//   return replayed;
+// }
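+// [Editorial sketch, not part of the original patch.] replayCasP is the
+// mirror image of replayPasC: the consumer is rewritten to follow the
+// producer's transformations, e.g. after an rfactor, where the rfactored
+// view produces into the original reduction:
+//
+//   TensorView* tv1 = ...;                 // reduction TensorView
+//   tv1->split(-1, 32);
+//   TensorView* rf = tv1->rFactor({-2});   // rf now produces into tv1
+//   // tv1 (consumer of rf) can then be replayed as rf (producer):
+//   // TransformReplay::replayCasP(tv1, rf, axis);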
+// // replay Producer as Consumer
+// TensorView* TransformReplay::replayPasC(
+//     TensorView* producer,
+//     TensorView* consumer,
+//     int compute_at_axis) {
+//   // If this is a reduction operation, we may call transform_replay on the
+//   // same tensor view. When this happens, just return the target view.
+//   if (producer == consumer)
+//     return producer;
+
+//   TensorDomain* td =
+//       replayPasC(producer->domain(), consumer->domain(), compute_at_axis);
+//   producer->setDomain(td);
+//   return producer;
+// }
+
+// TensorView* TransformReplay::replayCasP(
+//     TensorView* consumer,
+//     TensorView* producer,
+//     int compute_at_axis) {
+//   // If this is a reduction operation, we may call transform_replay on the
+//   // same tensor view. When this happens, just return the target view.
+//   if (consumer == producer)
+//     return consumer;
+//   TensorDomain* td =
+//       replayCasP(consumer->domain(), producer->domain(), compute_at_axis);
+//   consumer->setDomain(td);
+//   return consumer;
+// }
+
+// } // namespace fuser
+// } // namespace jit
+// } // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h
index c5affe07d48a0..96c353d44ad27 100644
--- a/torch/csrc/jit/codegen/cuda/transform_replay.h
+++ b/torch/csrc/jit/codegen/cuda/transform_replay.h
@@ -124,31 +124,41 @@ struct TORCH_CUDA_API TransformReplay {
   // Self replay.
   static TensorDomain* fullSelfReplay(
       TensorDomain* self,
-      TensorDomain* self_copy);
+      TensorDomain* self_copy) {
+    TORCH_INTERNAL_ASSERT(false, "NIY.");
+  }
 
   // Replay producer as consumer.
   static TensorDomain* replayPasC(
       TensorDomain* producer,
       TensorDomain* consumer,
-      int compute_at_axis);
+      int compute_at_axis) {
+    TORCH_INTERNAL_ASSERT(false, "NIY.");
+  }
 
   // Replay producer as consumer.
   static TensorView* replayPasC(
       TensorView* producer,
       TensorView* consumer,
-      int compute_at_axis);
+      int compute_at_axis) {
+    TORCH_INTERNAL_ASSERT(false, "NIY.");
+  }
 
   // Replay producer as consumer.
static TensorDomain* replayCasP( TensorDomain* consumer, TensorDomain* producer, - int compute_at_axis); + int compute_at_axis) { + TORCH_INTERNAL_ASSERT(false, "NIY."); + } // Replay producer as consumer. static TensorView* replayCasP( TensorView* consumer, TensorView* producer, - int compute_at_axis); + int compute_at_axis) { + TORCH_INTERNAL_ASSERT(false, "NIY."); + } }; } // namespace fuser From d0989d09092ad60c50a0e796935dace9ce3718a9 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 24 May 2020 10:28:23 -0400 Subject: [PATCH 4/8] Refactor IterDomain, TensorDomain, TensorView. IterDomain holds history of merge/split now. --- test/cpp/jit/test_gpu.cpp | 175 ++++----- torch/csrc/jit/codegen/cuda/arith.cpp | 4 +- torch/csrc/jit/codegen/cuda/fusion.cpp | 6 +- .../jit/codegen/cuda/ir_interface_nodes.h | 19 +- .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 45 ++- torch/csrc/jit/codegen/cuda/ir_iostream.cpp | 5 +- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 350 +++++++++++------- torch/csrc/jit/codegen/cuda/lower2device.cpp | 24 +- torch/csrc/jit/codegen/cuda/parser.cpp | 2 +- .../jit/codegen/cuda/predicate_compute.cpp | 8 +- torch/csrc/jit/codegen/cuda/tensor_view.cpp | 32 +- 11 files changed, 390 insertions(+), 280 deletions(-) diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index f8e271527c140..eeb65c5d50910 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -380,7 +380,7 @@ void testGPU_FusionTVSplit() { static_cast(outer)->getBinaryOpType() == BinaryOpType::CeilDiv && static_cast(outer)->lhs()->sameAs( - tv->getRootDomain()->axis(2)->extent()) && + tv->getRootDomain()[2]->extent()) && static_cast(static_cast(outer)->rhs()) ->sameAs(new Int(2))); @@ -404,9 +404,9 @@ void testGPU_FusionTVMerge() { tv->nDims() == 2 && axisOp->getExprType() == ExprType::BinaryOp && static_cast(axisOp)->getBinaryOpType() == BinaryOpType::Mul && static_cast(axisOp)->lhs() == - tv->getRootDomain()->axis(1)->extent() && + tv->getRootDomain()[1]->extent() && static_cast(axisOp)->rhs() == - tv->getRootDomain()->axis(2)->extent()); + tv->getRootDomain()[2]->extent()); } void testGPU_FusionTVReorder() { @@ -1903,91 +1903,94 @@ void testGPU_FusionCastOps() { // We want split/merge/reorder all tested both on and off rfactor domains, also // want compute at into the rfactor domain, and into its consumer void testGPU_FusionRFactorReplay() { - Fusion fusion; - FusionGuard fg(&fusion); - - // Set up your input tensor views - TensorView* tv0 = makeDummyTensor(2); - - // Register your inputs - fusion.addInput(tv0); - - // Do math with it, it returns a `Val*` but can be static_casted back to - // TensorView - TensorView* tv1 = static_cast(sum(tv0, {1})); - // tv1[I0, R1] - tv1->split(0, 32); - // tv1[I0o, I0i{32}, R1] - tv1->split(0, 16); - // tv1[I0oo, I0oi{16}, I0i{32}, R1] - tv1->split(-1, 8); - // tv1[I0oo, I0oi{16}, I0i{32}, R1o, R1i{8}] - tv1->split(-2, 4); - // tv1[I0oo, I0oi{16}, I0i{32}, R1oo, R1oi{4}, R1i{8}] - - tv1->reorder({{0, -2}, {2, -1}, {-3, 0}, {-1, 1}}); - // tv1[R1oo, R1i{8}, I0oi{16}, R1oi{4}, I0oo, I0i{32}] - - tv1->merge(0); - tv1->merge(-2); - - // tv1[R1oo*R1i{8}, I0oi{16}, R1oi{4}, I0oo*I0i{32}] - TensorDomain* new_domain = TransformRFactor::runReplay(tv1->domain(), {0}); - TensorDomain* new_domain2 = TransformRFactor::runReplay2(tv1->domain(), {0}); - // new_domain[R(R1oo*R1i{8})rf, I0oi{16}, ir1oi{4}rf, I0oo*I0i{32}] - // new_domain2[ I0oi{16}, , I0oo*I0i{32}, R1oi{4}] - - // Move rfactor axis to end, keep iter rfactor axis - auto 
reordered_new_domain = new_domain->reorder({{0, -1}, {2, 2}}); - // reordered_new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, R(R1oo*R1i{8})rf] - - TensorDomain* casp = - TransformReplay::replayCasP(new_domain2, reordered_new_domain, 2); - // new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, R(R1oo*R1i{8})rf] - // casp[I0oi{16}, I0oo*I0i{32}, R1oi{4}] - - casp = casp->split(1, 2); - // casp [I0oi{16}, (I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4}] - // new_domain[I0oi{16}, I0oo*I0i{32} , ir1oi{4}rf, - // R(R1oo*R1i{8})rf] - TensorDomain* pasc = TransformReplay::replayPasC(new_domain, casp, 2); - // pasc [I0oi{16}, (I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4}rf, + // Fusion fusion; + // FusionGuard fg(&fusion); + + // // Set up your input tensor views + // TensorView* tv0 = makeDummyTensor(2); + + // // Register your inputs + // fusion.addInput(tv0); + + // // Do math with it, it returns a `Val*` but can be static_casted back to + // // TensorView + // TensorView* tv1 = static_cast(sum(tv0, {1})); + // // tv1[I0, R1] + // tv1->split(0, 32); + // // tv1[I0o, I0i{32}, R1] + // tv1->split(0, 16); + // // tv1[I0oo, I0oi{16}, I0i{32}, R1] + // tv1->split(-1, 8); + // // tv1[I0oo, I0oi{16}, I0i{32}, R1o, R1i{8}] + // tv1->split(-2, 4); + // // tv1[I0oo, I0oi{16}, I0i{32}, R1oo, R1oi{4}, R1i{8}] + + // tv1->reorder({{0, -2}, {2, -1}, {-3, 0}, {-1, 1}}); + // // tv1[R1oo, R1i{8}, I0oi{16}, R1oi{4}, I0oo, I0i{32}] + + // tv1->merge(0); + // tv1->merge(-2); + + // // tv1[R1oo*R1i{8}, I0oi{16}, R1oi{4}, I0oo*I0i{32}] + // TensorDomain* new_domain = TransformRFactor::runReplay(tv1->domain(), {0}); + // TensorDomain* new_domain2 = TransformRFactor::runReplay2(tv1->domain(), + // {0}); + // // new_domain[R(R1oo*R1i{8})rf, I0oi{16}, ir1oi{4}rf, I0oo*I0i{32}] + // // new_domain2[ I0oi{16}, , I0oo*I0i{32}, + // R1oi{4}] + + // // Move rfactor axis to end, keep iter rfactor axis + // auto reordered_new_domain = new_domain->reorder({{0, -1}, {2, 2}}); + // // reordered_new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, // R(R1oo*R1i{8})rf] - TORCH_CHECK( - new_domain->nDims() - 1 == new_domain2->nDims(), - casp->nDims() == new_domain2->nDims() + 1, - pasc->nDims() == new_domain->nDims() + 1, - "Error in rfactor, number of dimensions is not correct."); - - TORCH_CHECK( - !casp->sameAs(new_domain2) && !pasc->sameAs(new_domain) && - !new_domain->sameAs(new_domain2) && - !tv1->domain()->sameAs(new_domain) && - !tv1->domain()->sameAs(new_domain2), - "Error in rfactor, number of dimensions is not correct."); - - auto dom = new_domain->rootDomain()->domain(); - TORCH_CHECK( - !new_domain->rootDomain()->axis(0)->isReduction() && - std::any_of( - dom.begin(), - dom.end(), - [](IterDomain* id) { return id->isReduction(); }) && - std::any_of( - dom.begin(), - dom.end(), - [](IterDomain* id) { return id->isRFactorProduct(); }), - "Error in rFactor, there seems to be something wrong in root domain."); - - auto dom2 = new_domain2->rootDomain()->domain(); - TORCH_CHECK( - !new_domain2->rootDomain()->axis(0)->isReduction() && - std::any_of( - dom2.begin(), - dom2.end(), - [](IterDomain* id) { return id->isReduction(); }), - "Error in rFactor, there seems to be something wrong in root domain."); + // TensorDomain* casp = + // TransformReplay::replayCasP(new_domain2, reordered_new_domain, 2); + // // new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, R(R1oo*R1i{8})rf] + // // casp[I0oi{16}, I0oo*I0i{32}, R1oi{4}] + + // casp = casp->split(1, 2); + // // casp [I0oi{16}, (I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4}] + // // new_domain[I0oi{16}, 
I0oo*I0i{32} , ir1oi{4}rf, + // // R(R1oo*R1i{8})rf] + // TensorDomain* pasc = TransformReplay::replayPasC(new_domain, casp, 2); + // // pasc [I0oi{16}, (I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4}rf, + // // R(R1oo*R1i{8})rf] + + // TORCH_CHECK( + // new_domain->nDims() - 1 == new_domain2->nDims(), + // casp->nDims() == new_domain2->nDims() + 1, + // pasc->nDims() == new_domain->nDims() + 1, + // "Error in rfactor, number of dimensions is not correct."); + + // TORCH_CHECK( + // !casp->sameAs(new_domain2) && !pasc->sameAs(new_domain) && + // !new_domain->sameAs(new_domain2) && + // !tv1->domain()->sameAs(new_domain) && + // !tv1->domain()->sameAs(new_domain2), + // "Error in rfactor, number of dimensions is not correct."); + + // auto dom = new_domain->rootDomain()->domain(); + // TORCH_CHECK( + // !new_domain->rootDomain()->axis(0)->isReduction() && + // std::any_of( + // dom.begin(), + // dom.end(), + // [](IterDomain* id) { return id->isReduction(); }) && + // std::any_of( + // dom.begin(), + // dom.end(), + // [](IterDomain* id) { return id->isRFactorProduct(); }), + // "Error in rFactor, there seems to be something wrong in root domain."); + + // auto dom2 = new_domain2->rootDomain()->domain(); + // TORCH_CHECK( + // !new_domain2->rootDomain()->axis(0)->isReduction() && + // std::any_of( + // dom2.begin(), + // dom2.end(), + // [](IterDomain* id) { return id->isReduction(); }), + // "Error in rFactor, there seems to be something wrong in root domain."); } // Start off simple, block on the outer dim diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index dc21dbdda1ab1..c85b803a8ac4e 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -235,7 +235,7 @@ TORCH_CUDA_API Val* andOp(Val* v1, Val* v2) { namespace { // TODO: How do we adjust this so we can reduce to a single scalar value? TensorView* newForReduction(TensorView* tv, std::vector axes) { - auto orig_domain = tv->getRootDomain()->noReductions(); + auto orig_domain = TensorDomain::noReductions(tv->getRootDomain()); std::set axes_set(axes.begin(), axes.end()); std::vector new_domain; @@ -281,7 +281,7 @@ Val* reductionOp( TensorView* tv = static_cast(v1); TORCH_CHECK( - tv->getRootDomain() == tv->domain(), + TensorDomain::sameAs(tv->getRootDomain(), tv->domain()->domain()), "Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/computeAt."); std::vector uint_axes; diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp index 750c8e32e07d9..6623392ec40a0 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion.cpp @@ -139,9 +139,9 @@ void Fusion::addOutput(Val* const output) { assertInFusion(output, "Cannot register output "); if (output->getValType().value() == ValType::TensorView) { TensorView* tv = static_cast(output); - if (tv->getRootDomain() - ->hasBroadcast()) // Go to the root as we can merge bcast and - // non-bcast dims, making a non-bcast dim. + if (TensorDomain::hasBroadcast(tv->getRootDomain())) + // Go to the root as we can merge bcast and + // non-bcast dims, making a non-bcast dim. TORCH_CHECK( // Should we warn instead? 
false, output, diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h index 80a9517803cc2..4c167e4dd18af 100644 --- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h @@ -209,8 +209,8 @@ struct TORCH_CUDA_API TensorView : public Val { return compute_at_view_; } - // domain() accessors - std::vector::size_type nDims() const; + size_t nDims() const; + IterDomain* axis(int pos) const; // Return compute at axis relative to this domain @@ -231,7 +231,7 @@ struct TORCH_CUDA_API TensorView : public Val { return compute_at_view_->getComputeAtAxis(getComputeAtRelPos(pos)); } - TensorDomain* getRootDomain() const; + const std::vector& getRootDomain() const; // Compute this TensorView relative to another tensor at axis TensorView* computeAt(TensorView* consumer, int axis); @@ -244,10 +244,15 @@ struct TORCH_CUDA_API TensorView : public Val { // Split "axis" into 2 axes where the inner axes is size of "factor" // and outer axis is size axis.size() / factor - TensorView* split(int axis, int factor); + TensorView* split(int axis, unsigned int factor); + + // Merge axis_o and axis_i into 1 IterDomain + TensorView* merge(int axis_o, int axis_i); - // Merge "axis" and "axis+1" into 1 dimension - TensorView* merge(int axis); + // Merge axis and axis+1 into 1 IterDomain + TensorView* merge(int axis) { + return merge(axis, axis + 1); + } // Reorder axes according to old2new[old_pos] = new_pos TensorView* reorder(const std::unordered_map& old2new); @@ -268,7 +273,7 @@ struct TORCH_CUDA_API TensorView : public Val { } friend TORCH_CUDA_API TransformReplay; - friend TORCH_CUDA_API TransformIter; + // friend TORCH_CUDA_API TransformIter; friend TORCH_CUDA_API OptOutMutator; friend TORCH_CUDA_API GPULower; friend TORCH_CUDA_API LoopNestGenerator; diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 3c87025fe479d..28845f2e5bfc0 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -242,6 +242,11 @@ struct TORCH_CUDA_API IterDomain : public Val { isBroadcast()); } + static IterDomain* merge(IterDomain* outer, IterDomain* inner); + static std::pair split( + IterDomain* in, + unsigned int factor); + bool isReduction() const noexcept { return is_reduction_domain_; } @@ -361,43 +366,65 @@ struct TORCH_CUDA_API TensorDomain : public Val { bool sameAs(const TensorDomain* const other) const; + static bool sameAs( + const std::vector& lhs, + const std::vector& rhs); + const std::vector& domain() const noexcept { return domain_; } bool hasReduction() const; bool hasBroadcast() const; - bool hasRFactor() const; const std::vector& noReductions() const noexcept { return noReductionDomain_; } - const std::vector& noBroadcast() const noexcept { + const std::vector& noBroadcasts() const noexcept { return noBCastDomain_; } + const std::vector& rootDomain() const noexcept { + return root_domain_; + }; + + void resetDomains() { + noReductionDomain_ = noReductions(domain_); + noBCastDomain_ = noBroadcasts(domain_); + } + // i here is int, as we want to accept negative value and ::size_type can be a // uint. 
IterDomain* axis(int i) const; // Split "axis" into 2 axes where the inner axes is size of "factor" // and outer axis is size axis.size() / factor - TensorDomain* split(int axis, int factor); + void split(int axis, unsigned int factor); - // Merge "axis" and "axis+1" into 1 dimension - TensorDomain* merge(int axis); + // Merge axis_o and axis_i. axis_i is the fast changing dimension. Resulting + // axis is by default placed at original position axis_o + void merge(int axis_o, int axis_i); // Reorder axes according to map[old_pos] = new_pos - TensorDomain* reorder(const std::unordered_map& old2new); + void reorder(const std::unordered_map& old2new); + + static std::vector orderedAs( + const std::vector& td, + const std::unordered_map& old2new); + + static std::vector noReductions(const std::vector&); + static std::vector noBroadcasts(const std::vector&); + + static bool hasBroadcast(const std::vector&); + static bool hasReduction(const std::vector&); // pair is in order where second is the consumer of first std::pair rFactor(const std::vector& axes); - TensorDomain* rootDomain(); - private: - const std::vector domain_; + const std::vector root_domain_; + std::vector domain_; std::vector noBCastDomain_; std::vector noReductionDomain_; }; diff --git a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp index aab818d386537..483c04c7f0d81 100644 --- a/torch/csrc/jit/codegen/cuda/ir_iostream.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_iostream.cpp @@ -37,9 +37,8 @@ void IRPrinter::printHeader(Fusion* fusion, const std::string& kernel_name_) { switch (val->getValType().value()) { case (ValType::TensorView): os << "Tensor<" << val->getDataType().value() << ", " - << static_cast(val) - ->getRootDomain() - ->noReductions() + << TensorDomain::noReductions( + static_cast(val)->getRootDomain()) .size() << "> T" << val->name(); break; diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 6f16dce6a52ae..a0d17bc25883f 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -249,6 +249,70 @@ bool IterDomain::sameAs(const IterDomain* const other) const { return is_same; } +IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) { + TORCH_CHECK( + outer->start()->isZeroInt() && inner->start()->isZeroInt(), + "Merging IterDomains with starting values that aren't 0 is not supported at this time."); + TORCH_CHECK( + outer->isReduction() == inner->isReduction(), + "Merging IterDomains requires that their iteration types match."); + TORCH_CHECK( + outer->parallel_method() == inner->parallel_method(), + "Merging IterDomains requires that their parallel types match."); + + Val* merged_id_size = mul(outer->extent(), inner->extent()); + IterDomain* merged_id = new IterDomain( + new Int(0), + static_cast(merged_id_size), + outer->parallel_method(), + outer->isReduction(), + outer->isRFactorProduct() || inner->isRFactorProduct(), + outer->isBroadcast() && inner->isBroadcast()); + + new Merge(merged_id, outer, inner); + + return merged_id; +} + +std::pair IterDomain::split( + IterDomain* in, + unsigned int factor) { + TORCH_CHECK( + in->start()->isZeroInt(), + "Splitting IterDomains with starting values that aren't 0 is not supported at this time."); + + if (in->parallel_method() != ParallelType::Serial) + TORCH_CHECK( + false, + "Splitting an axis of non-Serial iteration is not supported at this time." 
+ " Parallelization strategy must be set after calling split."); + + Int* fact = new Int(factor); + // outer loop size + Val* vo = ceilDiv(in->extent(), fact); + Int* so = static_cast(vo); + + // outer loop IterDomain + IterDomain* ido = new IterDomain( + new Int(0), + so, + in->parallel_method(), + in->isReduction(), + in->isRFactorProduct(), + in->isBroadcast()); + + // inner loop IterDomain + IterDomain* idi = new IterDomain( + new Int(0), + fact, + in->parallel_method(), + in->isReduction(), + in->isRFactorProduct(), + in->isBroadcast()); + new Split(ido, idi, in, fact); + return {ido, idi}; +} + Val* IterDomain::extent() const { if (isThread()) { if (extent_->getValType() == ValType::Scalar) @@ -262,13 +326,9 @@ Val* IterDomain::extent() const { } TensorDomain::TensorDomain(std::vector _domain) - : Val(ValType::TensorDomain), domain_(std::move(_domain)) { - for (IterDomain* id : domain_) { - if (!id->isReduction()) - noReductionDomain_.push_back(id); - if (!id->isBroadcast()) - noBCastDomain_.push_back(id); - } + : Val(ValType::TensorDomain), root_domain_(std::move(_domain)) { + domain_ = std::vector(root_domain_.begin(), root_domain_.end()); + resetDomains(); } bool TensorDomain::sameAs(const TensorDomain* const other) const { @@ -282,6 +342,19 @@ bool TensorDomain::sameAs(const TensorDomain* const other) const { return true; } +bool TensorDomain::sameAs( + const std::vector& lhs, + const std::vector& rhs) { + if (lhs.size() != rhs.size()) + return false; + size_t i = 0; + for (auto null_ : lhs) { + if (!lhs[i]->sameAs(rhs[i])) + return false; + } + return true; +} + bool TensorDomain::hasReduction() const { return noReductionDomain_.size() != domain_.size(); } @@ -312,7 +385,7 @@ IterDomain* TensorDomain::axis(int i) const { // Split "axis" into 2 axes where the inner axes is size of "factor" // and outer axis is size axis.extent() / factor -TensorDomain* TensorDomain::split(int axis_, int factor) { +void TensorDomain::split(int axis_, unsigned int factor) { if (axis_ < 0) axis_ += nDims(); @@ -321,112 +394,63 @@ TensorDomain* TensorDomain::split(int axis_, int factor) { "Tried to split on axis outside TensorDomain's range."); IterDomain* id = axis(axis_); - - TORCH_CHECK( - id->start()->isZeroInt(), - "Splitting IterDomains with starting values that aren't 0, is not supported at this time."); - - if (id->parallel_method() != ParallelType::Serial) - TORCH_CHECK( - false, - "Splitting an axis of non-Serial iteration is not supported at this time." 
- " Parallelization strategy must be set after calling split."); - - std::vector new_domain; - - Int* fact = new Int(factor); - for (decltype(nDims()) i = 0; i < nDims(); i++) { - if (i != (unsigned int)axis_) - new_domain.push_back(axis(i)); - else { - // outer loop size - Val* vo = ceilDiv(id->extent(), fact); - Int* so = static_cast(vo); - - // outer loop IterDomain - IterDomain* ido = new IterDomain( - new Int(0), - so, - id->parallel_method(), - id->isReduction(), - id->isRFactorProduct(), - id->isBroadcast()); - new_domain.push_back(ido); - - // inner loop IterDomain - IterDomain* idi = new IterDomain( - new Int(0), - fact, - id->parallel_method(), - id->isReduction(), - id->isRFactorProduct(), - id->isBroadcast()); - new_domain.push_back(idi); - } - } - - TensorDomain* split_td = new TensorDomain(new_domain); - { TORCH_INTERNAL_ASSERT(false, "NIY."); } - // new Split(split_td, this, axis_, fact); // For record keeping - return split_td; + auto split_ids = IterDomain::split(id, factor); + domain_.erase(domain_.begin() + axis_); + domain_.insert(domain_.begin() + axis_, split_ids.second); + domain_.insert(domain_.begin() + axis_, split_ids.first); + resetDomains(); } // Merge "axis" and "axis+1" into 1 dimension -TensorDomain* TensorDomain::merge(int axis_) { - if (axis_ < 0) - axis_ += nDims(); - - TORCH_CHECK( - axis_ >= 0 && (unsigned int)(axis_ + 1) < nDims(), - "Trying to merge axis_ outside of TensorView's range."); +void TensorDomain::merge(int axis_o, int axis_i) { + if (axis_o < 0) + axis_o += nDims(); - IterDomain* first = axis(axis_); - IterDomain* second = axis(axis_ + 1); + if (axis_i < 0) + axis_i += nDims(); TORCH_CHECK( - first->start()->isZeroInt() && second->start()->isZeroInt(), - "Merging IterDomains with starting values that aren't 0, is not supported at this time."); - TORCH_CHECK( - first->isReduction() == second->isReduction(), - "Merging domains requires that they're either both a reduction axis_, or both an iteration axis_."); + axis_o >= 0 && (unsigned int)(axis_o + 1) < nDims() && axis_i >= 0 && + (unsigned int)(axis_i + 1) < nDims(), + "Invalid merge detected, either one or both axes are outside of TensorView's range."); + TORCH_CHECK( - first->parallel_method() == second->parallel_method(), - "Axes must have matching parallel types."); + axis_o != axis_i, + "Invalid merge detected, axes provided are the same axis."); - Val* merged_id_size = mul(first->extent(), second->extent()); - IterDomain* merged_id = new IterDomain( - new Int(0), - static_cast(merged_id_size), - first->parallel_method(), - first->isReduction(), - first->isRFactorProduct() || second->isRFactorProduct(), - first->isBroadcast() && second->isBroadcast()); - - std::vector new_domain; - for (decltype(nDims()) i = 0; i < nDims(); i++) { - if (i < (unsigned int)axis_ || i > (unsigned int)(axis_ + 1)) - new_domain.push_back(axis(i)); - else if (i == (unsigned int)axis_) { - new_domain.push_back(merged_id); - } + if (axis_o > axis_i) { + auto tmp = axis_i; + axis_i = axis_o; + axis_o = tmp; } - TensorDomain* merged_td = new TensorDomain(new_domain); - { TORCH_INTERNAL_ASSERT(false, "NIY."); } - // new Merge(merged_td, this, axis_); // For record keeping - return merged_td; + + IterDomain* first = axis(axis_o); + IterDomain* second = axis(axis_i); + + IterDomain* merged_id = IterDomain::merge(first, second); + + domain_.erase(domain_.begin() + axis_i); + domain_.erase(domain_.begin() + axis_o); + domain_.insert(domain_.begin() + axis_o, merged_id); + resetDomains(); } // Reorder axes according 
to map[old_pos] = new_pos -TensorDomain* TensorDomain::reorder( +void TensorDomain::reorder(const std::unordered_map& old2new_) { + domain_ = orderedAs(domain_, old2new_); + resetDomains(); +} + +std::vector TensorDomain::orderedAs( + const std::vector& dom, const std::unordered_map& old2new_) { - // START VALIDATION CHECKS // Eventhough these checks are already in TensorView, we want to redo them as // we can enter this function from other places, not through TensorView // adjust based on negative values (any negative values gets nDims added to // it) std::unordered_map old2new; - auto ndims = nDims(); + auto ndims = dom.size(); std::transform( old2new_.begin(), old2new_.end(), @@ -449,9 +473,7 @@ TensorDomain* TensorDomain::reorder( TORCH_CHECK( !out_of_range, - "Reorder axes are not within the number of dimensions of this domain, ", - this, - "."); + "Reorder axes are not within the number of dimensions of the provided domain."); // Going to use sets, to see if any duplicate values are in the map. @@ -483,12 +505,14 @@ TensorDomain* TensorDomain::reorder( std::vector new2old(ndims, -1); - // Go through each old and new position, make sure they're within 0-ndims + // Go through each old and new position, make sure they're within [0, ndims) for (std::pair elem : old2new) { int old_pos = elem.first; int new_pos = elem.second; - assert(old_pos >= 0 && old_pos < ndims && new_pos >= 0 && new_pos < ndims); + TORCH_INTERNAL_ASSERT( + old_pos >= 0 && old_pos < ndims && new_pos >= 0 && new_pos < ndims, + "Error occurred in reorder, somehow axes are not in expected range."); if (new2old[new_pos] != -1) TORCH_CHECK(false, "Reorder found duplicate destination positions."); @@ -496,13 +520,16 @@ TensorDomain* TensorDomain::reorder( new2old[new_pos] = old_pos; } + // old_positions that already have a new position std::set old_positions(new2old.begin(), new2old.end()); old_positions.erase(-1); + // Make sure we have all of them, and no duplicates were found if (old_positions.size() != old2new.size()) TORCH_INTERNAL_ASSERT( false, "Reorder found duplicate destination positions."); + // All available new positions std::set all_positions; for (decltype(ndims) i{0}; i < ndims; i++) all_positions.insert(i); @@ -530,56 +557,99 @@ TensorDomain* TensorDomain::reorder( new2old.begin(), new2old.end(), std::back_inserter(reordered_domain), - [this](int i) -> IterDomain* { return this->axis(i); }); + [dom](int i) -> IterDomain* { return dom[i]; }); - TensorDomain* reordered_td = new TensorDomain(reordered_domain); - return reordered_td; + return reordered_domain; } -// pair is in order where second is the consumer of first -std::pair TensorDomain::rFactor( - const std::vector& axes_) { - std::vector axes(axes_.size()); +std::vector TensorDomain::noReductions( + const std::vector& td) { + size_t size_out = 0; + for (auto id : td) + if (!id->isReduction()) + size_out++; + std::vector noReductionDomain(size_out); - auto ndims = nDims(); - std::transform(axes_.begin(), axes_.end(), axes.begin(), [ndims](int i) { - return i < 0 ?
i + ndims : i; - }); + int it = 0; + for (auto id : td) + if (!id->isReduction()) + noReductionDomain[it++] = id; - TORCH_CHECK( - std::none_of( - axes.begin(), - axes.end(), - [ndims](int i) { return i < 0 || (unsigned int)i >= ndims; }), - "RFactor axes less than 0 or >= ndims."); + return noReductionDomain; +} - TORCH_CHECK( - !hasRFactor(), "Cannot call rfactor on the same tensor domain twice."); - - std::set axes_set(axes.begin(), axes.end()); - - bool rfactor_found = false; - bool reduction_found = false; - for (decltype(nDims()) i{0}; i < nDims(); i++) { - if (axis(i)->isReduction()) { - if (axes_set.find(i) != axes_set.end()) - rfactor_found = true; - else - reduction_found = true; - } - } +std::vector TensorDomain::noBroadcasts( + const std::vector& td) { + size_t size_out = 0; + for (auto id : td) + if (!id->isBroadcast()) + size_out++; + std::vector noBroadcastDomain(size_out); - TORCH_CHECK( - rfactor_found && reduction_found, - "Invalid rfactor found, rfactor must be provided at least one reduction axis, but not all reduction axes."); + int it = 0; + for (auto id : td) + if (!id->isBroadcast()) + noBroadcastDomain[it++] = id; - return std::pair{ - TransformRFactor::runReplay(this, axes), - TransformRFactor::runReplay2(this, axes)}; + return noBroadcastDomain; +} + +bool TensorDomain::hasBroadcast(const std::vector& td) { + for (auto id : td) + if (id->isBroadcast()) + return true; + return false; +} +bool TensorDomain::hasReduction(const std::vector& td) { + for (auto id : td) + if (id->isReduction()) + return true; + return false; } -TensorDomain* TensorDomain::rootDomain() { - return TransformIter::getRoot(this); +// pair is in order where second is the consumer of first +std::pair TensorDomain::rFactor( + const std::vector& axes_) { + // std::vector axes(axes_.size()); + + // auto ndims = nDims(); + // std::transform(axes_.begin(), axes_.end(), axes.begin(), [ndims](int i) { + // return i < 0 ? 
i + ndims : i; + // }); + + // TORCH_CHECK( + // std::none_of( + // axes.begin(), + // axes.end(), + // [ndims](int i) { return i < 0 || (unsigned int)i >= ndims; }), + // "RFactor axes less than 0 or >= ndims."); + + // TORCH_CHECK( + // !hasRFactor(), "Cannot call rfactor on the same tensor domain twice."); + + // std::set axes_set(axes.begin(), axes.end()); + + // bool rfactor_found = false; + // bool reduction_found = false; + // for (decltype(nDims()) i{0}; i < nDims(); i++) { + // if (axis(i)->isReduction()) { + // if (axes_set.find(i) != axes_set.end()) + // rfactor_found = true; + // else + // reduction_found = true; + // } + // } + + // TORCH_CHECK( + // rfactor_found && reduction_found, + // "Invalid rfactor found, rfactor must be provided at least one reduction + // axis, but not all reduction axes."); + + // return std::pair{ + // TransformRFactor::runReplay(this, axes), + // TransformRFactor::runReplay2(this, axes)}; + + TORCH_INTERNAL_ASSERT(false, "NIY."); } Split::Split( diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index e33bdae37996d..5c156e63376f8 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -269,14 +269,14 @@ void GPULower::replaceSizes() { for (TensorView* tv : orig_inp_out) { // Replace the domain with one based on Ti.size[j] std::vector new_domain_iters; - TensorDomain* root_td = tv->getRootDomain(); + const std::vector& root_td = tv->getRootDomain(); - for (decltype(root_td->nDims()) i{0}; i < root_td->nDims(); i++) { + for (decltype(root_td.size()) i{0}; i < root_td.size(); i++) { // Output sizes could have reduction axes, which isn't what gets output. - if (root_td->axis(i)->isReduction()) + if (root_td[i]->isReduction()) continue; - Val* orig_size = root_td->axis(i)->extent(); + Val* orig_size = root_td[i]->extent(); std::stringstream ss; ss << "T" << tv->name() << ".size[" << i << "]"; @@ -295,20 +295,20 @@ void GPULower::replaceSizes() { // Set domains to be based on symbolic sizes (i.e. 
Ti.size[...]) for (TensorView* tv : all_tvs) { std::vector new_domain_iters; - TensorDomain* root_td = tv->getRootDomain(); + const std::vector& root_td = tv->getRootDomain(); - for (decltype(root_td->nDims()) i{0}; i < root_td->nDims(); i++) { - Val* new_size = root_td->axis(i)->extent(); + for (decltype(root_td.size()) i{0}; i < root_td.size(); i++) { + Val* new_size = root_td[i]->extent(); if (size_map.find(new_size) != size_map.end()) new_size = size_map[new_size]; new_domain_iters.push_back(new IterDomain( - root_td->axis(i)->start(), + root_td[i]->start(), new_size, - root_td->axis(i)->parallel_method(), - root_td->axis(i)->isReduction(), - root_td->axis(i)->isRFactorProduct(), - root_td->axis(i)->isBroadcast())); + root_td[i]->parallel_method(), + root_td[i]->isReduction(), + root_td[i]->isRFactorProduct(), + root_td[i]->isBroadcast())); } TensorDomain* old_domain = tv->domain(); diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp index 891766e7267fb..9f071ae4c5f31 100644 --- a/torch/csrc/jit/codegen/cuda/parser.cpp +++ b/torch/csrc/jit/codegen/cuda/parser.cpp @@ -106,7 +106,7 @@ class IrParser { // Merge all dimensions because we're only supporting pointwise while (out->nDims() > 1) - out->merge(0); + out->merge(0, 1); // Split into 128 which will be bockDim.x out->split(0, nthreads); // Split by another 4 which will be our unroll factor diff --git a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp index da13ec2276789..d12b14cfdc5b6 100644 --- a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp @@ -18,21 +18,21 @@ bool PredicateCompute::hasPredicates(const TensorIndex* ti) { std::vector PredicateCompute::computePredicates(const TensorIndex* ti) { const TensorView* tv = ti->view(); - TensorDomain* root = tv->getRootDomain(); + const std::vector& root = tv->getRootDomain(); std::vector preds; if (FusionGuard::getCurFusion()->origin(tv->domain()) == nullptr && tv->nDims() == ti->nDims()) return preds; - TORCH_INTERNAL_ASSERT(root->nDims() == ti->nDims()); + TORCH_INTERNAL_ASSERT(root.size() == ti->nDims()); for (decltype(ti->nDims()) i{0}; i < ti->nDims(); i++) // I believe the second part of this check is redundant, but it doesn't // hurt. 
if (FusionGuard::getCurFusion()->origin(ti->index(i)) != nullptr && - !root->axis(i)->isBroadcast()) { - Val* pred = lt(ti->index(i), root->axis(i)->extent()); + !root[i]->isBroadcast()) { + Val* pred = lt(ti->index(i), root[i]->extent()); TORCH_INTERNAL_ASSERT( pred->getValType().value() == ValType::Scalar && pred->getDataType().value() == DataType::Bool); diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index f6a6a06fcaab5..d3116f7674f57 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -50,8 +50,8 @@ bool TensorView::hasBroadcast() const { return domain()->hasBroadcast(); } -TensorDomain* TensorView::getRootDomain() const { - return TransformIter::getRoot(this->domain()); +const std::vector& TensorView::getRootDomain() const { + return domain()->rootDomain(); }; std::vector::size_type TensorView::nDims() const { @@ -349,7 +349,7 @@ TensorView* TensorView::computeAt(TensorView* consumer, int axis) { return this; } -TensorView* TensorView::split(int axis, int factor) { +TensorView* TensorView::split(int axis, unsigned int factor) { if (axis < 0) axis += domain()->nDims(); @@ -362,30 +362,36 @@ TensorView* TensorView::split(int axis, int factor) { " thisComputeAtAxis = ", getThisComputeAtAxis()); - setDomain(domain()->split(axis, factor)); + domain()->split(axis, factor); return this; } // Merge "axis" and "axis+1" into 1 dimension -TensorView* TensorView::merge(int axis) { - if (axis < 0) - axis += domain()->nDims(); +TensorView* TensorView::merge(int axis_o, int axis_i) { + if (axis_o < 0) + axis_o += domain()->nDims(); + + if (axis_i < 0) + axis_i += domain()->nDims(); if (getComputeAtView() != nullptr) - if (axis + 1 < (int)getThisComputeAtAxis()) + if (axis_o + 1 < (int)getThisComputeAtAxis() || + axis_i + 1 < (int)getThisComputeAtAxis()) TORCH_CHECK( false, - "Cannot merge axis within compute at range. Axis = ", - axis, - " thisComputeAtAxis = ", + "Cannot merge axis within compute at range. Either axis ", + axis_o, + " or ", + axis_i, + " is within thisComputeAtAxis = ", getThisComputeAtAxis()); - setDomain(domain()->merge(axis)); + domain()->merge(axis_o, axis_i); return this; } TensorView* TensorView::reorder(const std::unordered_map& old2new_) { - setDomain(domain()->reorder(old2new_)); + domain()->reorder(old2new_); return this; } From 68f251af45e9edc7199a24a4de1b892942feba85 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sun, 24 May 2020 15:01:56 -0400 Subject: [PATCH 5/8] Fix reorder test, add clone to TensorDomain.
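Context for this fix: with the in-place scheduling API from the patches above, reorder/split/merge now mutate the TensorView's existing TensorDomain instead of returning a fresh one, which is why the reorder test below needs a clone()d reference and sameAs() comparisons rather than raw IterDomain pointer equality. A hedged sketch of the resulting TensorDomain behavior (the concrete extents, and the assumption that IterDomain's trailing constructor arguments default to a serial, non-reduction axis, are for illustration only):

// Sketch only, not part of the patch. Assumes an active Fusion so IR nodes
// can be registered.
Fusion fusion;
FusionGuard fg(&fusion);

// Three serial axes; [16, 32, 8] chosen arbitrarily for illustration.
TensorDomain* td = new TensorDomain(
    {new IterDomain(new Int(0), new Int(16)),
     new IterDomain(new Int(0), new Int(32)),
     new IterDomain(new Int(0), new Int(8))});

td->merge(0, 1);               // in place: [16*32, 8]; axis_i is the fast changing dim
td->split(0, 128);             // in place: [ceilDiv(16*32, 128), 128, 8]; inner extent = factor
td->reorder({{0, 2}, {2, 0}}); // in place: [8, 128, ceilDiv(16*32, 128)]; map is old_pos -> new_pos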
--- test/cpp/jit/test_gpu.cpp | 14 +++++++------- torch/csrc/jit/codegen/cuda/ir_internal_nodes.h | 2 ++ torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 12 ++++++++++-- torch/csrc/jit/codegen/cuda/tensor_view.cpp | 2 +- 4 files changed, 20 insertions(+), 10 deletions(-) diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index eeb65c5d50910..e9013284fa06f 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -425,25 +425,25 @@ void testGPU_FusionTVReorder() { TensorView* ref = dummyTensor->clone(); TensorView* tv = dummyTensor->clone(); - TensorView* s_leftl = tv->reorder(shift_left); + TensorView* s_left1 = tv->reorder(shift_left); for (int i = 0; i < (int)tv->nDims(); i++) - TORCH_CHECK(ref->axis(i) == s_leftl->axis(i - 1)); + TORCH_CHECK(ref->axis(i)->sameAs(s_left1->axis(i - 1))); tv = dummyTensor->clone(); TensorView* s_left2 = tv->reorder(shift_left); for (int i = 0; i < (int)tv->nDims(); i++) - TORCH_CHECK(ref->axis(i) == s_left2->axis(i - 1)); + TORCH_CHECK(ref->axis(i)->sameAs(s_left2->axis(i - 1))); tv = dummyTensor->clone(); TensorView* s_right = tv->reorder(shift_right); for (int i = 0; i < (int)tv->nDims(); i++) - TORCH_CHECK(ref->axis(i - 1) == s_right->axis(i)); + TORCH_CHECK(ref->axis(i - 1)->sameAs(s_right->axis(i))); tv = dummyTensor->clone(); TensorView* rswap = tv->reorder(swap); - TORCH_CHECK(ref->axis(0) == rswap->axis(2)); - TORCH_CHECK(ref->axis(2) == rswap->axis(0)); - TORCH_CHECK(ref->axis(1) == rswap->axis(1)); + TORCH_CHECK(ref->axis(0)->sameAs(rswap->axis(2))); + TORCH_CHECK(ref->axis(2)->sameAs(rswap->axis(0))); + TORCH_CHECK(ref->axis(1)->sameAs(rswap->axis(1))); } void testGPU_FusionEquality() { diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 28845f2e5bfc0..8dcbde218b7f3 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -364,6 +364,8 @@ struct TORCH_CUDA_API TensorDomain : public Val { return domain_.size(); } + TensorDomain* clone() const; + bool sameAs(const TensorDomain* const other) const; static bool sameAs( diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index a0d17bc25883f..84016aece3086 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -342,6 +342,14 @@ bool TensorDomain::sameAs(const TensorDomain* const other) const { return true; } +TensorDomain* TensorDomain::clone() const { + std::vector domain(domain_.size()); + size_t i = 0; + for (auto dom : domain_) + domain[i++] = dom->clone(); + return new TensorDomain(domain); +} + bool TensorDomain::sameAs( const std::vector& lhs, const std::vector& rhs) { @@ -410,8 +418,8 @@ void TensorDomain::merge(int axis_o, int axis_i) { axis_i += nDims(); TORCH_CHECK( - axis_o >= 0 && (unsigned int)(axis_o + 1) < nDims() && axis_i >= 0 && - (unsigned int)(axis_i + 1) < nDims(), + axis_o >= 0 && (unsigned int)axis_o < nDims() && axis_i >= 0 && + (unsigned int)axis_i < nDims(), "Invalid merge detected, either one or both axes are outside of TensorView's range."); TORCH_CHECK( diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index d3116f7674f57..9e69b5ba4ca2f 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -35,7 +35,7 @@ TensorView::TensorView(const std::shared_ptr& tensor_type) } TensorView* TensorView::clone() const { - TensorView* 
new_view = new TensorView(domain_, getDataType().value()); + TensorView* new_view = new TensorView(domain()->clone(), getDataType().value()); new_view->setComputeAt(compute_at_view_, (int)relative_compute_at_axis_); new_view->memory_type_ = getMemoryType(); From 8660f7b27344509f878e14e9a1c7896746ff173b Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Tue, 26 May 2020 13:01:14 -0400 Subject: [PATCH 6/8] Re-write replayPasC and replayCasP. Hook computeAt back up, validate on pointwise ops. --- .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 5 + torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 34 + torch/csrc/jit/codegen/cuda/iter_visitor.cpp | 115 ++- torch/csrc/jit/codegen/cuda/iter_visitor.h | 6 +- torch/csrc/jit/codegen/cuda/lower2device.cpp | 5 +- torch/csrc/jit/codegen/cuda/tensor_view.cpp | 94 +- .../csrc/jit/codegen/cuda/transform_iter.cpp | 812 +----------------- torch/csrc/jit/codegen/cuda/transform_iter.h | 177 ++++ .../jit/codegen/cuda/transform_replay.cpp | 658 +++++++++----- .../csrc/jit/codegen/cuda/transform_replay.h | 17 +- 10 files changed, 830 insertions(+), 1093 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h index 8dcbde218b7f3..94d508b9493ab 100644 --- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h +++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h @@ -359,6 +359,9 @@ struct TORCH_CUDA_API TensorDomain : public Val { TensorDomain& operator=(TensorDomain&& other) = delete; TensorDomain(std::vector _domain); + TensorDomain( + std::vector _root_domain, + std::vector _domain); std::vector::size_type nDims() const { return domain_.size(); @@ -400,6 +403,8 @@ struct TORCH_CUDA_API TensorDomain : public Val { // uint. IterDomain* axis(int i) const; + size_t posOf(IterDomain* id) const; + // Split "axis" into 2 axes where the inner axes is size of "factor" // and outer axis is size axis.size() / factor void split(int axis, unsigned int factor); diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 84016aece3086..dcb5421a7cee6 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -331,6 +331,30 @@ TensorDomain::TensorDomain(std::vector _domain) resetDomains(); } +TensorDomain::TensorDomain( + std::vector _root_domain, + std::vector _domain) + : Val(ValType::TensorDomain), + root_domain_(std::move(_root_domain)), + domain_(std::move(_domain)) { + std::vector domain_vals(domain_.begin(), domain_.end()); + auto inps = IterVisitor::getInputsTo(domain_vals); + + // Validate that the root domain consists of all inputs to _domain + // Uncertain if this will hold for RFactor + + std::set root_vals(root_domain_.begin(), root_domain_.end()); + std::for_each(inps.begin(), inps.end(), [root_vals](Val* inp) { + TORCH_INTERNAL_ASSERT( + root_vals.find(inp) != root_vals.end(), + "Invalid tensor domain, ", + inp, + " is an input of domain, but it is not found in the root domain."); + }); + + resetDomains(); +} + bool TensorDomain::sameAs(const TensorDomain* const other) const { if (nDims() != other->nDims()) return false; @@ -391,6 +415,16 @@ IterDomain* TensorDomain::axis(int i) const { return domain_[i]; } +size_t TensorDomain::posOf(IterDomain* id) const { + size_t i = 0; + while (i < domain_.size()) { + if (domain_[i] == id) + return i; + i++; + } + TORCH_CHECK(false, "Provided id is not part of this domain."); +} + // Split "axis" into 2 axes where the inner axes is size of "factor" // and outer axis is size 
axis.extent() / factor void TensorDomain::split(int axis_, unsigned int factor) { diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp index f4b40ec5064a0..1bd1d7e0c3fa5 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp @@ -103,9 +103,10 @@ void IterVisitor::traverse_( TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); if (from_outputs_only) { - auto term_outs = DependencyCheck::getTerminatingOutputs(fusion); - if (!term_outs.empty()) - traverseFrom(fusion, term_outs, traverse_all_paths); + auto term_outs = IterVisitor::getTerminatingOutputs(fusion); + std::vector term_val_outs(term_outs.begin(), term_outs.end()); + if (!term_val_outs.empty()) + traverseFrom(fusion, term_val_outs, traverse_all_paths); return; } @@ -133,6 +134,74 @@ void IterVisitor::traverseAllPaths( traverse_(fusion, from_outputs_only, breadth_first, true); } +namespace { + +// Expr sort will take a fusion and return a topologically sorted list of +// expressions. +struct Exprs : public IterVisitor { + private: + std::vector exprs; + + void handle(Expr* expr) override { + exprs.push_back(expr); + } + + public: + static std::vector getExprs( + Fusion* fusion, + const std::vector& from) { + Exprs ex; + ex.traverseFrom(fusion, from, false); + return ex.exprs; + } +}; + +// Inputs will take a list of vals and traverse their history, returning all +// vals encountered that have no origin (i.e. the inputs they depend on). +struct Inputs : public IterVisitor { + private: + std::set inputs; + + void handle(Val* val) { + if (val->getOrigin() == nullptr) + inputs.emplace(val); + } + + public: + static std::set getInputs(const std::vector& of) { + if (of.empty()) + return std::set(); + Inputs inps; + inps.traverseFrom(of[0]->fusion(), of); + return inps.inputs; + } +}; +} // namespace + +std::set IterVisitor::getTerminatingOutputs(Fusion* const fusion) { + FusionGuard fg(fusion); + + std::set used_vals; + for (auto expr : Exprs::getExprs( + fusion, + std::vector( + fusion->outputs().begin(), fusion->outputs().end()))) { + for (auto inp : expr->inputs()) + used_vals.emplace(inp); + } + + std::set terminating_outputs; + for (auto out : fusion->outputs()) + if (used_vals.find(out) == used_vals.end()) + terminating_outputs.emplace(out); + + return terminating_outputs; +} + +std::set IterVisitor::getInputsTo(const std::vector& vals) { + return Inputs::getInputs(vals); +} + /* DEPENDENCY CHECKING */ namespace { @@ -212,26 +281,6 @@ struct DependencyChains : public IterVisitor { } }; -// Expr sort will take a fusion and return a topologically sorted list of -// expressions.
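An aside on the getTerminatingOutputs contract added above: an output that is consumed by another output's defining expression is not terminating, so seeding traversal from terminating outputs alone still reaches every expression. A hedged sketch (the names, the arith helpers add/mul, and the Float scalar are used here informally, as they appear elsewhere in this codebase):

// Sketch only: a fusion where one registered output feeds another.
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv0 = new TensorView(
    new TensorDomain({new IterDomain(new Int(0), new Int(8))}),
    DataType::Float);
fusion.addInput(tv0);

Val* tv1 = add(tv0, new Float(1.f)); // registered output, but consumed below
Val* tv2 = mul(tv1, new Float(2.f)); // registered output, consumed by nothing
fusion.addOutput(tv1);
fusion.addOutput(tv2);

// tv1 is an input of the expression defining tv2, so it lands in used_vals;
// only tv2 is terminating, and traversal seeded from it still visits tv1's
// origin on the way back to tv0.
auto term = IterVisitor::getTerminatingOutputs(&fusion); // == {tv2}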
-struct Exprs : public IterVisitor { - private: - std::vector exprs; - - void handle(Expr* expr) override { - exprs.push_back(expr); - } - - public: - static std::vector getExprs( - Fusion* fusion, - const std::vector& from) { - Exprs ex; - ex.traverseFrom(fusion, from, false); - return ex.exprs; - } -}; - } // namespace bool DependencyCheck::isDependencyOf(Val* dependency, Val* of) { @@ -255,26 +304,6 @@ std::deque> DependencyCheck::getAllDependencyChainsTo( return DependencyChains::getDependencyChainsTo(dependency); } -std::vector DependencyCheck::getTerminatingOutputs(Fusion* const fusion) { - FusionGuard fg(fusion); - - std::set used_vals; - for (auto expr : Exprs::getExprs( - fusion, - std::vector( - fusion->outputs().begin(), fusion->outputs().end()))) { - for (auto inp : expr->inputs()) - used_vals.emplace(inp); - } - - std::vector terminating_outputs; - for (auto out : fusion->outputs()) - if (used_vals.find(out) == used_vals.end()) - terminating_outputs.push_back(out); - - return terminating_outputs; -} - } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.h b/torch/csrc/jit/codegen/cuda/iter_visitor.h index 04935bfb87714..24390721bc222 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.h +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.h @@ -106,6 +106,10 @@ struct TORCH_CUDA_API IterVisitor : public OptOutDispatch { Fusion* const fusion, bool from_outputs_only = false, bool breadth_first = false); + + static std::set getTerminatingOutputs(Fusion* const); + + static std::set getInputsTo(const std::vector& vals); }; struct TORCH_CUDA_API DependencyCheck { @@ -128,8 +132,6 @@ struct TORCH_CUDA_API DependencyCheck { // paths. deque[i].back() are leaf nodes, and deque[i][0] is "dependency". // Returns an empty deque if there are no uses of dependency found. static std::deque> getAllDependencyChainsTo(Val* dependency); - - static std::vector getTerminatingOutputs(Fusion* const); }; } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 5c156e63376f8..18e0d4b8b7532 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -319,8 +319,9 @@ void GPULower::replaceSizes() { // IterDomain/TensorDomain are vals. 
std::vector axis_map(new_domain->nDims()); std::iota(axis_map.begin(), axis_map.end(), 0); - new_domain = TransformIter::replaySelf( - new_domain, TransformIter::getHistory(old_domain), axis_map); + TORCH_INTERNAL_ASSERT(false, "NIY."); + // new_domain = TransformIter::replaySelf( + // new_domain, TransformIter::getHistory(old_domain), axis_map); TORCH_INTERNAL_ASSERT( old_domain->nDims() == new_domain->nDims(), diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 9e69b5ba4ca2f..0c707b8d4d454 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -35,7 +35,8 @@ TensorView::TensorView(const std::shared_ptr& tensor_type) } TensorView* TensorView::clone() const { - TensorView* new_view = new TensorView(domain()->clone(), getDataType().value()); + TensorView* new_view = + new TensorView(domain()->clone(), getDataType().value()); new_view->setComputeAt(compute_at_view_, (int)relative_compute_at_axis_); new_view->memory_type_ = getMemoryType(); @@ -165,20 +166,24 @@ void TensorView::setThisComputeAtAxis() { } // Actually applies transformation -void TensorView::computeAt_impl(TensorView* consumer, int axis) { +void TensorView::computeAt_impl( + TensorView* consumer, + int consumer_compute_at_axis) { // Reset view otherwise will conflict with replay. clearComputeAt(); // replay this as consumer / producer as consumer - TransformReplay::replayPasC(this, consumer, axis); - setComputeAt(consumer, axis); + TransformReplay::replayPasC(this, consumer, consumer_compute_at_axis); + setComputeAt(consumer, consumer_compute_at_axis); } // Actually applies transformation -void TensorView::forwardComputeAt_impl(TensorView* producer, int axis) { +void TensorView::forwardComputeAt_impl( + TensorView* producer, + int producer_compute_at_axis) { // Reset view otherwise will conflict with replay. producer->clearComputeAt(); - TransformReplay::replayCasP(this, producer, axis); - producer->setComputeAt(this, axis); + TransformReplay::replayCasP(this, producer, producer_compute_at_axis); + producer->setComputeAt(this, producer_compute_at_axis); } namespace { @@ -235,13 +240,16 @@ TensorView* TensorView::computeAt(TensorView* consumer, int axis) { axis >= 0 && (unsigned int)axis < consumer->nDims() + 1, "Compute at called on an axis outside valid range."); - // If not direct relationship follow dependency chain. + // If not a direct relationship, follow the dependency chain from consumer + // to producer. auto dep_chains = DependencyCheck::getAllDependencyChains(this, consumer); std::deque dep_chain; if (!dep_chains.empty()) dep_chain = dep_chains.front(); + // Make sure there is a dependency chain; if not, it's an invalid computeAt. + // We could do indirect computeAts, but it's not supported at this time. TORCH_CHECK( !dep_chain.empty(), "Compute At expects ", @@ -250,11 +258,15 @@ TensorView* TensorView::computeAt(TensorView* consumer, int axis) { consumer, ", however it is not."); + // Validate dependency chain returned as expected TORCH_INTERNAL_ASSERT( dep_chain.back() == consumer && dep_chain[0] == this, "Error computing dependency chain."); - // Replay from consumer to producer + // Start the replay going from consumer, through the dependency chain to + // producer. After this section, producer should look like consumer, and there + // should be a computeAt chain going from producer to consumer. Proper + // computeAts are set up, though they will be overwritten in a later stage.
while (dep_chain.size() > 1) { Val* consumer_val = dep_chain.back(); dep_chain.pop_back(); @@ -267,21 +279,28 @@ TensorView* TensorView::computeAt(TensorView* consumer, int axis) { TensorView* running_consumer = static_cast(consumer_val); TensorView* running_producer = static_cast(producer_val); - running_producer->computeAt_impl(running_consumer, axis); + // Axis is relative to consumer, however as we propagate computeAt, it may + // move. This is why we have TensorView->getThisComputeAtAxis() which + // returns where in a TensorView the computeAt (relative to consumer) + // lines up. Mismatch is due to broadcast. + int compute_at_axis = axis; + if (running_consumer != consumer) + compute_at_axis = running_consumer->getThisComputeAtAxis(); + running_producer->computeAt_impl(running_consumer, compute_at_axis); } /* * Compute At has now worked from consumer to producer, transforming producer * to match computeAt selected in consumer We now need to work from producer * up to its consumers (including indirect consumption) so their use also - * matches. If we can find a TV that contains all uses of producer, we can - * terminate this propagation there. If not, we need to propagate all the way - * to outputs. - * - * First we'll look for that terminating point. + * matches. If we can find a TV that contains all uses of producer (common + * consumer), we can terminate this propagation there. If not, we need to + * propagate all the way to outputs. */ - // Grab all uses of producer + // Start looking for a common consumer of producer + + // Grab all uses of producer in fusion auto val_all_consumer_chains = DependencyCheck::getAllDependencyChainsTo(this); @@ -291,15 +310,18 @@ TensorView* TensorView::computeAt(TensorView* consumer, int axis) { all_consumer_chains.push_back( tv_iterable>(val_dep_chain)); + // Use set arithmetic to find a common consumer, starting with the first use chain of producer std::set common_consumers( all_consumer_chains.front().begin(), all_consumer_chains.front().end()); + // Run through all use chains of producer, and intersect them for (auto dep_chain : all_consumer_chains) common_consumers = set_intersection( common_consumers, std::set(dep_chain.begin(), dep_chain.end())); - // Remove all TVs between producer and consumer + // Remove all TVs between producer and consumer as we don't want a common + // consumer placed logically before the consumer provided in computeAt for (const auto& dep_chain : dep_chains) { auto tv_chain = tv_iterable>(dep_chain); for (auto tv : tv_chain) { @@ -308,7 +330,7 @@ TensorView* TensorView::computeAt(TensorView* consumer, int axis) { } } - // Grab the first (topologically) common consumer + // If there is a common consumer, grab the first one (topologically) TensorView* common_consumer = nullptr; if (!common_consumers.empty()) { for (TensorView* tv : all_consumer_chains.front()) @@ -318,11 +340,26 @@ TensorView* TensorView::computeAt(TensorView* consumer, int axis) { } } - // Forward compute at through all consumers until common_consumer if there is - // one + // Forward propagate the transformation through all use chains until + // common_consumer if there is one, otherwise until we hit all output TVs std::set output_set; - std::vector ordered_outputs; + // computeAt axes in outputs don't necessarily match up; make sure to keep + // the relative computeAt position in each output + std::vector> ordered_outputs; for (auto dep_chain : all_consumer_chains) { + // All dep chains start with this.
+ TORCH_INTERNAL_ASSERT( + dep_chain.front() == this, + "Invalid dependency chain found during computeAt, ", + dep_chain.front(), + " should be ", + this); + TORCH_INTERNAL_ASSERT( + this->hasComputeAt(), + "Error detected during computeAt, ", + this, + ", should have a computeAt set at this point even though we will overwrite it."); + int running_producer_compute_at = this->getThisComputeAtAxis(); while (dep_chain.size() > 1) { TensorView* running_producer = dep_chain.front(); dep_chain.pop_front(); @@ -330,12 +367,19 @@ TensorView* TensorView::computeAt(TensorView* consumer, int axis) { if (running_producer == common_consumer) break; + // Axis is relative to consumer, and may not necessarily apply to all + // intermediate steps. Fortunately producer is guaranteed to have a valid + // computeAt set, so we can use the compute at axis relative to producer. + running_consumer->forwardComputeAt_impl( + running_producer, running_producer_compute_at); + running_producer_compute_at = running_producer->getThisComputeAtAxis(); + int consumer_compute_at = running_producer->getRelativeComputeAtAxis(); - running_consumer->forwardComputeAt_impl(running_producer, axis); if (dep_chain.size() == 1) { // last one if (output_set.find(running_consumer) == output_set.end()) { output_set.emplace(running_consumer); - ordered_outputs.push_back(running_consumer); + ordered_outputs.push_back(std::pair( + running_consumer, consumer_compute_at)); } } } @@ -344,7 +388,9 @@ TensorView* TensorView::computeAt(TensorView* consumer, int axis) { if (!ordered_outputs.empty()) for (auto it = ordered_outputs.begin(); it + 1 != ordered_outputs.end(); it++) - (*it)->computeAt_impl((*(it + 1)), axis); + (*it).first->computeAt_impl( + (*(it + 1)).first, + (*(it + 1)).second); // use recorded position, not axis.
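The set_intersection helper called in the common-consumer search above is not shown in this hunk; a minimal standalone sketch of the behavior it is assumed to provide, with a worked example in the comments:

// Sketch only. E.g. use chains {P, A, C} and {P, B, C, D} intersect to
// {P, C}; after removing the TVs between producer and consumer, C remains
// as the common consumer that terminates forward propagation.
template <typename T>
std::set<T> set_intersection(const std::set<T>& a, const std::set<T>& b) {
  std::set<T> intersection;
  for (const T& elem : a)
    if (b.find(elem) != b.end())
      intersection.insert(elem);
  return intersection;
}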
return this; } diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.cpp b/torch/csrc/jit/codegen/cuda/transform_iter.cpp index 6d00a71bfed32..94229ed41c73c 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_iter.cpp @@ -3,810 +3,10 @@ // #include // #include -// namespace torch { -// namespace jit { -// namespace fuser { +#include -// TensorDomain* TransformIter::replayBackward(Split* split, TensorDomain* td) { -// return split->in(); -// } - -// TensorDomain* TransformIter::replayBackward(Merge* merge, TensorDomain* td) { -// return merge->in(); -// } - -// TensorDomain* TransformIter::replayBackward(Expr* expr, TensorDomain* td) { -// TORCH_INTERNAL_ASSERT( -// expr->isExpr(), -// "Dispatch in transform iteration is expecting Exprs only."); -// switch (*(expr->getExprType())) { -// case (ExprType::Split): -// return replayBackward(static_cast(expr), td); -// case (ExprType::Merge): -// return replayBackward(static_cast(expr), td); -// default: -// TORCH_INTERNAL_ASSERT( -// false, "Could not detect expr type in replayBackward."); -// } -// } - -// std::vector TransformIter::getHistory(TensorDomain* td) { -// std::vector ops; -// TensorDomain* root = td; // backward running td -// Fusion* fusion = FusionGuard::getCurFusion(); - -// // Get my origin -// Expr* orig = fusion->origin(root); -// std::set visited_exprs; - -// // If I'm not back to the original td -// while (orig != nullptr) { -// if (visited_exprs.find(orig) != visited_exprs.end()) -// TORCH_INTERNAL_ASSERT( -// false, -// "TransformReplay::runBackward is not traversing a correct -// history."); -// ops.push_back(orig); -// visited_exprs.emplace(orig); -// TensorDomain* previous_td = nullptr; -// // Check inputs of this operation, make sure there isn't more than one TD -// // I can only record operations that only take this TD as an input. -// for (Val* inp : orig->inputs()) -// if (inp->getValType() == ValType::TensorDomain) { -// if (previous_td != nullptr) -// TORCH_INTERNAL_ASSERT( -// false, -// "TransformReplay::runBackward could not decifer transform -// history of a TensorDomain."); - -// // Traverse back -// root = static_cast(inp); -// orig = fusion->origin(root); -// } -// } -// return std::vector(ops.rbegin(), ops.rend()); -// } - -// TensorDomain* TransformIter::runBackward(TensorDomain* td) { -// std::vector ops = getHistory(td); - -// // We want to iterate backwards, reverse history. 
-// ops = std::vector(ops.rbegin(), ops.rend()); - -// TensorDomain* running_td = td; -// for (Expr* op : ops) -// running_td = replayBackward(op, running_td); - -// return running_td; -// } - -// TensorDomain* TransformIter::replay(Split* expr, TensorDomain* td) { -// return td->split( -// expr->axis(), static_cast(expr->factor())->value().value()); -// } - -// TensorDomain* TransformIter::replay(Merge* expr, TensorDomain* td) { -// return td->merge(expr->axis()); -// } - -// TensorDomain* TransformIter::replay(Expr* expr, TensorDomain* td) { -// TORCH_INTERNAL_ASSERT(expr->isExpr()); -// switch (*(expr->getExprType())) { -// case (ExprType::Split): -// return replay(static_cast(expr), td); -// case (ExprType::Merge): -// return replay(static_cast(expr), td); -// default: -// TORCH_INTERNAL_ASSERT(false, "Could not detect expr type in replay."); -// } -// } - -// TensorDomain* TransformIter::runReplay( -// TensorDomain* td, -// const std::vector& history) { -// for (Expr* op : history) -// td = TransformIter::replay(op, td); -// return td; -// } - -// namespace { - -// void validate_axis_map(int nDims, const std::vector& axis_map) { -// TORCH_INTERNAL_ASSERT( -// axis_map.size() == (unsigned int)nDims, -// "Invalid axis map in replay transform. NDims doesn't match."); - -// TORCH_INTERNAL_ASSERT( -// !std::any_of( -// axis_map.begin(), -// axis_map.end(), -// [nDims](int i) { return i < -1 || i >= nDims; }), -// "Invalid axis map in replay transform, map goes outside domains of -// provided TensorDomain."); -// } - -// void validate_history_entry(Expr* expr, int nDims) { -// TORCH_INTERNAL_ASSERT( -// expr->input(0)->getValType().value() == ValType::TensorDomain && -// static_cast(expr->input(0))->nDims() == -// (unsigned int)nDims, -// "Invalid history, or invalid axis_map in TransformIter."); -// } - -// struct Influence : public TransformIter { -// private: -// // BACKWARD INFLUENCE - -// TensorDomain* replayBackward(Split* split, TensorDomain* td) override { -// int axis = split->axis(); - -// TORCH_INTERNAL_ASSERT( -// (unsigned int)(axis + 1) < influence.size(), -// "Error during replay backwards, td/influence size mismatch."); -// influence[axis] = influence[axis] | influence[axis + 1]; -// influence.erase(influence.begin() + axis + 1); - -// return split->in(); -// } - -// TensorDomain* replayBackward(Merge* merge, TensorDomain* td) override { -// int axis = merge->axis(); -// TORCH_INTERNAL_ASSERT( -// (unsigned int)axis < influence.size(), -// "Error during replay backwards, td/influence size mismatch."); -// influence.insert(influence.begin() + axis + 1, influence[axis]); - -// return merge->in(); -// } - -// // FORWARD INFLUENCE - -// TensorDomain* replay(Split* split, TensorDomain* td) override { -// int axis = split->axis(); -// TORCH_INTERNAL_ASSERT( -// (unsigned int)axis < influence.size(), -// "Error during replay, td/influence size mismatch."); -// influence.insert(influence.begin() + axis + 1, influence[axis]); -// return nullptr; -// } - -// TensorDomain* replay(Merge* merge, TensorDomain* td) override { -// int axis = merge->axis(); -// TORCH_INTERNAL_ASSERT( -// axis >= 0 && (unsigned int)(axis + 1) < influence.size(), -// "Error during replay, td/influence size mismatch."); -// influence[axis] = influence[axis] | influence[axis + 1]; -// influence.erase(influence.begin() + axis + 1); -// return nullptr; -// } - -// // INTERFACE - -// std::vector influence; - -// Influence(std::vector td_influence) -// : influence(std::move(td_influence)) {} - -// using 
TransformIter::replayBackward; -// using TransformIter::runReplay; - -// public: -// static std::vector computeBackward( -// const std::vector& history, -// const std::vector& td_influence) { -// if (history.empty()) -// return td_influence; - -// Val* last_val = history[history.size() - 1]->output(0); -// TORCH_INTERNAL_ASSERT( -// last_val->getValType().value() == ValType::TensorDomain && -// static_cast(last_val)->nDims() == -// td_influence.size(), -// "Tried to compute influence, but recieved an influence vector that -// does not match the expected size."); - -// Influence inf(td_influence); -// std::vector ops(history.rbegin(), history.rend()); -// for (Expr* op : ops) -// inf.replayBackward(op, nullptr); -// return inf.influence; -// } - -// static std::vector computeForward( -// const std::vector& history, -// const std::vector& td_influence) { -// if (history.empty()) -// return td_influence; - -// TORCH_INTERNAL_ASSERT( -// history[0]->input(0)->getValType().value() == ValType::TensorDomain -// && -// static_cast(history[0]->input(0))->nDims() == -// td_influence.size(), -// "Tried to compute influence, but recieved an influence vector that -// does not match the expected size."); -// Influence inf(td_influence); -// inf.runReplay(nullptr, history); -// return inf.influence; -// } - -// }; // struct Influence - -// struct Replay : public TransformIter { -// /* -// * Replay functions, takes a TensorDomain and steps through the operations -// in -// * "record" based on influence axes. Will also update influence and -// propagate -// * it forward. -// */ -// TensorDomain* replay(Split* split, TensorDomain* td) override { -// int saxis = split->axis(); - -// TORCH_INTERNAL_ASSERT( -// saxis >= 0 && (unsigned int)saxis < axis_map.size(), -// "TransformReplay tried to modify an axis out of range, recieved ", -// saxis, -// " but this value should be >=0 and <", -// axis_map.size()); - -// // Axis relative to td -// int axis = axis_map[saxis]; - -// if (axis == -1) { -// // don't modify path, we need an extra axis as there would have been -// one -// // there, but we shouldn't modify it. -// axis_map.insert(axis_map.begin() + saxis + 1, -1); -// return td; -// } - -// // Move indices up as we now have an extra axis -// std::transform( -// axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { -// return i > axis ? i + 1 : i; -// }); - -// // Insert new axis in map -// axis_map.insert(axis_map.begin() + saxis + 1, axis + 1); - -// TORCH_INTERNAL_ASSERT( -// split->factor()->isConst(), -// "Cannot replay split as it's not based on a const value."); -// td = td->split(axis, split->factor()->value().value()); - -// return td; -// } - -// TensorDomain* replay(Merge* merge, TensorDomain* td) override { -// int maxis = merge->axis(); - -// TORCH_INTERNAL_ASSERT( -// maxis >= 0 && (unsigned int)(maxis + 1) < axis_map.size(), -// "TransformReplay tried to modify an axis out of range, recieved ", -// maxis, -// " but this value should be >= 0 and < axis_map.size()"); - -// // Get axis relative to what we actually have in td. -// int axis = axis_map[maxis]; -// int axis_p_1 = axis_map[maxis + 1]; -// // If either dim is not to be touch, set both not to be touched -// axis = axis_p_1 == -1 ? 
-1 : axis; -// axis_map[maxis] = axis; - -// // Remove axis from axis_map as in original transformations it didn't -// exist axis_map.erase(axis_map.begin() + maxis + 1); - -// // Don't modify: -// if (axis == -1) -// return td; - -// // Move indices down as we're removing an axis -// std::transform( -// axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { -// return i > axis ? i - 1 : i; -// }); - -// return td->merge(axis); -// } - -// std::vector axis_map; -// Replay(std::vector _axis_map) : axis_map(std::move(_axis_map)) {} - -// public: -// // Replays history provided on td, axis_map is the mapping from td axes to -// // those expected in history, if an axis shouldn't be transformed, it needs -// to -// // be marked as -1 in the axis_map -// static TensorDomain* replay( -// TensorDomain* td, -// const std::vector& history, -// const std::vector& axis_map) { -// if (history.empty()) -// return td; - -// Replay r(axis_map); -// return r.runReplay(td, history); -// } - -// }; // struct Replay - -// struct ReplaySelf : public TransformIter { -// /* -// * Replay functions, takes a TensorDomain and steps through its own history -// * and reapplies it based on influence axes. Will replay rfactor axes -// * correctly as well. -// */ -// TensorDomain* replay(Split* split, TensorDomain* td) override { -// int saxis = split->axis(); - -// TORCH_INTERNAL_ASSERT( -// saxis >= 0 && (unsigned int)saxis < axis_map.size(), -// "TransformReplay tried to modify an axis out of range, recieved ", -// saxis, -// " but this value should be >=0 and <", -// axis_map.size()); - -// // Axis relative to td -// int axis = axis_map[saxis]; - -// if (axis == -1) { -// // don't modify path, we need an extra axis as there would have been -// one -// // there, but we shouldn't modify it. -// axis_map.insert(axis_map.begin() + saxis + 1, -1); -// return td; -// } - -// // Move indices up as we now have an extra axis -// std::transform( -// axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { -// return i > axis ? i + 1 : i; -// }); - -// // Insert new axis in map -// axis_map.insert(axis_map.begin() + saxis + 1, axis + 1); - -// TORCH_INTERNAL_ASSERT( -// split->factor()->isConst(), -// "Cannot replay split as it's not based on a const value."); - -// // Create new domain reflecting split -// std::vector new_domain; -// for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) { -// if ((int)i == axis) { -// // We want to support cases where our root domain has changed sizes -// // this happens in lowering when we replace sizes with runtime look -// ups IterDomain* td_axis = td->axis(axis); IterDomain* saxis_1 = -// split->out()->axis(saxis); IterDomain* saxis_2 = -// split->out()->axis(saxis + 1); -// // manually replay split domains using td extent, otherwise matching -// // split axes params. 
-// TORCH_CHECK( -// td_axis->start()->isZeroInt(), -// "Splitting IterDomains with starting values that aren't 0, is not -// supported at this time."); - -// IterDomain* ido = new IterDomain( -// new Int(0), -// ceilDiv(td_axis->extent(), split->factor()), -// saxis_1->parallel_method(), -// saxis_1->isReduction(), -// saxis_1->isRFactorProduct(), -// saxis_1->isBroadcast()); -// new_domain.push_back(ido); - -// // inner loop IterDomain -// IterDomain* idi = new IterDomain( -// new Int(0), -// split->factor(), -// saxis_2->parallel_method(), -// saxis_2->isReduction(), -// saxis_2->isRFactorProduct(), -// saxis_1->isBroadcast()); -// new_domain.push_back(idi); -// } else { -// // Add in all other axes, these may not match the input td to the -// split. new_domain.push_back(td->axis(i)); -// } -// } - -// TensorDomain* replayed = new TensorDomain(new_domain); -// new Split(replayed, td, axis, split->factor()); -// return replayed; -// } - -// TensorDomain* replay(Merge* merge, TensorDomain* td) override { -// int maxis = merge->axis(); - -// TORCH_INTERNAL_ASSERT( -// maxis >= 0 && (unsigned int)(maxis + 1) < axis_map.size(), -// "TransformReplay tried to modify an axis out of range, recieved ", -// maxis, -// " but this value should be >= 0 and < axis_map.size()"); - -// // Get axis relative to what we actually have in td. -// int axis = axis_map[maxis]; -// int axis_p_1 = axis_map[maxis + 1]; -// // If either dim is not to be touch, set both not to be touched -// axis = axis_p_1 == -1 ? -1 : axis; -// axis_map[maxis] = axis; - -// // Remove axis from axis_map as in original transformations it didn't -// exist axis_map.erase(axis_map.begin() + maxis + 1); - -// // Don't modify: -// if (axis == -1) -// return td; - -// // Move indices down as we're removing an axis -// std::transform( -// axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { -// return i > axis ? i - 1 : i; -// }); - -// // Create new domain reflecting post-merge -// std::vector new_domain; -// for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) { -// if ((int)i == axis) { -// // We want to support cases where our root domain has changed sizes -// // this happens in lowering when we replace sizes with runtime look -// ups IterDomain* td_axis1 = td->axis(axis); IterDomain* td_axis2 = -// td->axis(axis_p_1); IterDomain* m_axis = merge->out()->axis(maxis); - -// TORCH_INTERNAL_ASSERT( -// td_axis1->start()->isZeroInt() && td_axis2->start()->isZeroInt(), -// "Splitting IterDomains with starting values that aren't 0, is not -// supported at this time."); - -// IterDomain* merged = new IterDomain( -// new Int(0), -// mul(td_axis1->extent(), td_axis2->extent()), -// m_axis->parallel_method(), -// m_axis->isReduction(), -// m_axis->isRFactorProduct(), -// m_axis->isBroadcast()); -// new_domain.push_back(merged); - -// } else if ((int)i != axis_p_1) { -// // Add in all other axes, these may not match the input td to the -// split. 
new_domain.push_back(td->axis(i)); -// } -// } - -// TensorDomain* replayed = new TensorDomain(new_domain); -// new Merge(replayed, td, axis); -// return replayed; -// } - -// std::vector axis_map; -// ReplaySelf(std::vector _axis_map) : axis_map(std::move(_axis_map)) {} - -// public: -// // Replays history provided on td, axis_map is the mapping from td axes to -// // those expected in history, if an axis shouldn't be transformed, it needs -// to -// // be marked as -1 in the axis_map -// static TensorDomain* replay( -// TensorDomain* td, -// const std::vector& history, -// const std::vector& axis_map) { -// ReplaySelf r(axis_map); -// return r.runReplay(TransformIter::getRoot(td), history); -// } - -// }; // struct ReplaySelf - -// struct TransformBackward : public TransformIter { -// private: -// // axis_map goes from the transform position to the position in our -// modified -// // td. -// TensorDomain* replayBackward(Split* split, TensorDomain* td) override { -// int saxis = split->axis(); - -// TORCH_INTERNAL_ASSERT( -// saxis >= 0 && (unsigned int)saxis < axis_map.size(), -// "TransformBackward tried to modify an axis out of range, recieved ", -// saxis, -// " but this value should be >= 0 and < axis_map.size()"); - -// // Get axis relative to what we actually have in td. -// int axis = axis_map[saxis]; -// int axis_p_1 = axis_map[saxis + 1]; -// // If either dim is not to be touch, set both not to be touched -// axis = axis_p_1 == -1 ? -1 : axis; -// axis_map[saxis] = axis; - -// // Remove axis from axis_map as in original transformations it didn't -// exist axis_map.erase(axis_map.begin() + saxis + 1); - -// // Don't modify: -// if (axis == -1) -// return td; - -// // Move indices down as previously we didn't have the split axis -// std::transform( -// axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { -// return i > axis ? i - 1 : i; -// }); - -// // Create new domain reflecting pre-split -// std::vector new_domain; -// for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) { -// if ((int)i == axis) { -// IterDomain* orig_axis = split->in()->axis(saxis); -// // Insert pre-split axis, make sure isReduction matches what is -// expected new_domain.push_back(new IterDomain( -// orig_axis->start(), -// orig_axis->extent(), -// orig_axis->parallel_method(), -// td->axis(axis)->isReduction(), -// td->axis(axis)->isRFactorProduct(), -// td->axis(axis)->isBroadcast())); -// } else if ((int)i != axis_p_1) { -// // Add in all other axes, these may not match the input td to the -// split. new_domain.push_back(td->axis(i)); -// } -// } - -// TensorDomain* replayed_inp = new TensorDomain(new_domain); -// new Split(td, replayed_inp, axis, split->factor()); -// return replayed_inp; -// } - -// TensorDomain* replayBackward(Merge* merge, TensorDomain* td) override { -// /* -// * Remember axis_map goes from merge information -> how it's stored in td -// * When we're done we want axis_map to match the returned td before or -// not -// * before the merge depending on should_modify. -// */ - -// int maxis = merge->axis(); - -// TORCH_INTERNAL_ASSERT( -// maxis >= 0 && (unsigned int)maxis < axis_map.size(), -// "TransformBackward tried to modify an axis out of range, recieved ", -// maxis, -// " but this value should be >=0 and <", -// axis_map.size()); - -// if (axis_map[maxis] == -1) { -// // don't modify path, we need an extra axis as there was previously one -// // there, but we shouldn't modify it. 
-// axis_map.insert(axis_map.begin() + maxis + 1, -1); -// return td; -// } - -// // Recreate the merge, axis is relative to the td -// int axis = axis_map[maxis]; -// // Move indices up as previously we had an extra axis -// std::transform( -// axis_map.begin(), axis_map.end(), axis_map.begin(), [axis](int i) { -// return i > axis ? i + 1 : i; -// }); - -// // Insert pre-merged axis back into map -// axis_map.insert(axis_map.begin() + maxis + 1, axis_map[maxis] + 1); - -// // Create new domain reflecting pre-merge -// std::vector new_domain; -// for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) { -// if ((int)i == axis) { -// IterDomain* td_axis = td->axis(axis); -// IterDomain* maxis_1 = merge->in()->axis(maxis); -// IterDomain* maxis_2 = merge->in()->axis(maxis + 1); -// new_domain.push_back(new IterDomain( -// maxis_1->start(), -// maxis_1->extent(), -// ParallelType::Serial, -// td_axis->isReduction(), -// td_axis->isRFactorProduct(), -// td_axis->isBroadcast())); -// new_domain.push_back(new IterDomain( -// maxis_2->start(), -// maxis_2->extent(), -// ParallelType::Serial, -// td_axis->isReduction(), -// td_axis->isRFactorProduct(), -// td_axis->isBroadcast())); -// } else { -// // Add in all other axes, these may not match the input td to the -// split. new_domain.push_back(td->axis(i)); -// } -// } - -// TensorDomain* replayed_inp = new TensorDomain(new_domain); -// new Merge(td, replayed_inp, axis); -// return replayed_inp; -// } - -// // Entry for backward influence propagation on td following record, history -// // should be present -> past as you go through the vector -// TensorDomain* replayBackward( -// TensorDomain* td, -// const std::vector& history) { -// TensorDomain* running_td = td; - -// std::vector rev_history(history.rbegin(), history.rend()); -// for (Expr* op : rev_history) -// running_td = TransformIter::replayBackward(op, running_td); -// return running_td; -// } - -// std::vector axis_map; - -// TransformBackward(std::vector _axis_map) -// : axis_map(std::move(_axis_map)){}; - -// public: -// static TensorDomain* replay( -// TensorDomain* td, -// const std::vector& history, -// const std::vector& axis_map) { -// TransformBackward tb(axis_map); -// return tb.replayBackward(td, history); -// } -// }; - -// struct RFactorRoot : public TransformIter { -// bool found_non_rfactor_op = false; - -// TensorDomain* replay(Split* split, TensorDomain*) final { -// if (!split->in()->axis(split->axis())->isRFactorProduct()) -// found_non_rfactor_op = true; -// return split->out(); -// } - -// TensorDomain* replay(Merge* merge, TensorDomain*) final { -// if (!merge->in()->axis(merge->axis())->isRFactorProduct()) -// found_non_rfactor_op = true; -// return merge->out(); -// } - -// // Replay forward until we hit an operation that doesn't involve an rfactor -// // axis -// TensorDomain* runReplay(TensorDomain*, const std::vector& history) -// final { -// TORCH_INTERNAL_ASSERT( -// !history.empty(), "No history provided to find rfactor root -// domain."); - -// auto last_rfactor_op = history.begin(); -// auto running_op = history.begin(); - -// for (auto it = history.begin(); it != history.end(); it++) { -// TransformIter::replay(*it, nullptr); -// if (found_non_rfactor_op) -// break; -// running_op = it; -// } - -// // We need to make sure the rfactor root is ordered correctly. 
-// bool found_valid_rfactor_root = false; - -// Val* val; - -// while (!found_valid_rfactor_root && last_rfactor_op != history.end()) { -// // Try next val -// val = (*last_rfactor_op++)->output(0); -// TORCH_INTERNAL_ASSERT( -// val->getValType().value() == ValType::TensorDomain, -// "Invalid history to find rfactor root."); - -// TensorDomain* td = static_cast(val); -// bool found_rfactor_dim = false; -// for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) { -// if (found_rfactor_dim) { -// if (!td->axis(i)->isRFactorProduct()) -// break; -// } else { -// if (td->axis(i)->isRFactorProduct()) -// found_rfactor_dim = true; -// } -// if (i == td->nDims() - 1) -// found_valid_rfactor_root = true; -// } -// } -// TORCH_INTERNAL_ASSERT( -// found_valid_rfactor_root, "Could not find a valid rfactor root."); -// return static_cast(val); -// } - -// public: -// static TensorDomain* get(TensorDomain* td) { -// auto history = TransformIter::getHistory(td); -// if (history.empty()) -// return td; -// RFactorRoot rfr; -// return rfr.runReplay(nullptr, history); -// } -// }; - -// } // namespace - -// // API INTO TRANSFORM ITER - -// std::vector TransformIter::getRootInfluence( -// TensorDomain* td, -// const std::vector& td_influence) { -// return Influence::computeBackward( -// TransformIter::getHistory(td), td_influence); -// } - -// std::vector TransformIter::replayBackwardInfluence( -// const std::vector& history, -// const std::vector& td_influence) { -// return Influence::computeBackward(history, td_influence); -// } - -// std::vector TransformIter::replayInfluence( -// const std::vector& history, -// const std::vector& td_influence) { -// if (history.empty()) -// return td_influence; - -// return Influence::computeForward(history, td_influence); -// } - -// TensorDomain* TransformIter::replay( -// TensorDomain* td, -// const std::vector& history, -// const std::vector& axis_map) { -// if (history.empty()) -// return td; -// if (std::none_of( -// axis_map.begin(), axis_map.end(), [](int i) { return i > -1; })) -// return td; - -// validate_history_entry(history[0], axis_map.size()); -// return Replay::replay(td, history, axis_map); -// } - -// TensorDomain* TransformIter::replaySelf( -// TensorDomain* td, -// const std::vector& history, -// const std::vector& axis_map) { -// if (std::none_of( -// axis_map.begin(), axis_map.end(), [](int i) { return i > -1; })) -// return TransformIter::getRoot(td); - -// validate_axis_map(TransformIter::getRoot(td)->nDims(), axis_map); -// return ReplaySelf::replay(td, history, axis_map); -// } - -// TensorDomain* TransformIter::replayBackward( -// TensorDomain* td, -// const std::vector& history, -// const std::vector& axis_map) { -// if (history.empty()) -// return td; -// if (std::none_of( -// axis_map.begin(), axis_map.end(), [](int i) { return i > -1; })) -// return td; - -// TORCH_INTERNAL_ASSERT( -// history[history.size() - 1]->output(0)->getValType().value() == -// ValType::TensorDomain && -// static_cast(history[history.size() - 1]->output(0)) -// ->nDims() == axis_map.size(), -// "Invalid history, or invalid axis_map in TransformIter."); - -// return TransformBackward::replay(td, history, axis_map); -// } - -// TensorDomain* TransformIter::getRFactorRoot(TensorDomain* td) { -// auto td_root = TransformIter::getRoot(td); -// if (std::none_of( -// td_root->domain().begin(), -// td_root->domain().end(), -// [](IterDomain* id) { return id->isRFactorProduct(); })) -// return td_root; - -// return RFactorRoot::get(td); -// } - -// } // namespace fuser 
-// } // namespace jit
-// } // namespace torch
\ No newline at end of file
+namespace torch {
+namespace jit {
+namespace fuser {} // namespace fuser
+} // namespace jit
+} // namespace torch
\ No newline at end of file
diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.h b/torch/csrc/jit/codegen/cuda/transform_iter.h
index aec625077d65e..f6812bc96ae0f 100644
--- a/torch/csrc/jit/codegen/cuda/transform_iter.h
+++ b/torch/csrc/jit/codegen/cuda/transform_iter.h
@@ -11,6 +11,183 @@ namespace torch {
 namespace jit {
 namespace fuser {
 
+namespace {
+
+struct id_int_lt {
+  bool operator()(
+      const std::pair& first,
+      const std::pair& second) const {
+    return first.second < second.second;
+  }
+};
+
+} // namespace
+
+struct TORCH_CUDA_API ReplayTransformations : public IterVisitor {
+ private:
+  std::vector target_domain_;
+  std::unordered_map id_map_;
+  std::unordered_map leaf_ids_;
+  std::vector leaf_vec_;
+  size_t counter = 0;
+
+  using IterVisitor::handle;
+
+  void handle(Expr* e) override {
+    switch (e->getExprType().value()) {
+      case (ExprType::Split):
+      case (ExprType::Merge):
+        break;
+      default:
+        TORCH_INTERNAL_ASSERT(
+            false, "Invalid expr type found in transform traversal.");
+    }
+    IterVisitor::handle(e);
+  }
+
+  // TODO: HANDLE RFACTOR DOMAINS
+  virtual void handle(Split* s) override {
+    auto id_in = s->in();
+    auto it = id_map_.find(id_in);
+    TORCH_INTERNAL_ASSERT(
+        it != id_map_.end(),
+        "Transform traversal failed, dependencies not met.");
+    auto mapped = (*it).second;
+    TORCH_INTERNAL_ASSERT(
+        s->factor()->isConst(),
+        "Transform traversal does not support splitting on non-const values.");
+    auto outs = IterDomain::split(mapped, s->factor()->value().value());
+    TORCH_INTERNAL_ASSERT(
+        leaf_ids_.find(mapped) != leaf_ids_.end(),
+        "Transform traversal failed, modified a node but it was not a leaf node.");
+    leaf_ids_.erase(mapped);
+    leaf_ids_[outs.first] = counter++;
+    leaf_ids_[outs.second] = counter++;
+    id_map_[s->outer()] = outs.first;
+    id_map_[s->inner()] = outs.second;
+  }
+
+  virtual void handle(Merge* m) override {
+    auto id_outer = m->outer();
+    auto id_inner = m->inner();
+    auto it_outer = id_map_.find(id_outer);
+    auto it_inner = id_map_.find(id_inner);
+    TORCH_INTERNAL_ASSERT(
+        it_outer != id_map_.end() && it_inner != id_map_.end(),
+        "Transform traversal failed, dependencies not met.");
+
+    auto id_outer_mapped = (*it_outer).second;
+    auto id_inner_mapped = (*it_inner).second;
+
+    auto out = IterDomain::merge(id_outer_mapped, id_inner_mapped);
+
+    TORCH_INTERNAL_ASSERT(
+        leaf_ids_.find(id_outer_mapped) != leaf_ids_.end() &&
+            leaf_ids_.find(id_inner_mapped) != leaf_ids_.end(),
+        "Transform traversal failed, modified ",
+        id_outer_mapped,
+        " and ",
+        id_inner_mapped,
+        " however one or both are not leaf nodes.");
+
+    leaf_ids_.erase(id_outer_mapped);
+    leaf_ids_.erase(id_inner_mapped);
+    leaf_ids_[out] = counter++;
+    id_map_[m->out()] = out;
+  }
+
+  // Replays the transformations that generated target_domain_ on the
+  // IterDomains registered in id_map_
+  void runReplay() {
+    if (target_domain_.empty() || id_map_.empty())
+      return;
+
+    // Convert target_domain_ to a vector of Vals to start the traversal
+    std::vector traversal_vals(
+        target_domain_.begin(), target_domain_.end());
+    traverseFrom(traversal_vals[0]->fusion(), traversal_vals);
+
+    TORCH_INTERNAL_ASSERT(
+        leaf_ids_.size() >= target_domain_.size(),
+        "Transform traversal failed, did not find enough output IterDomains.");
+
+    // Validate replay
+    for (auto out : target_domain_) {
+      auto it_replayed = id_map_.find(out);
+      TORCH_INTERNAL_ASSERT(
+          it_replayed != id_map_.end(),
+          "Transform traversal failed, could not find expected output.");
+      auto id_replayed = (*it_replayed).second;
+      auto it_leaf = leaf_ids_.find(id_replayed);
+      TORCH_INTERNAL_ASSERT(
+          it_leaf != leaf_ids_.end(),
+          "Transform traversal failed, expected matched output to be a leaf of the replay, but it was not.");
+    }
+  }
+
+ public:
+  ReplayTransformations(
+      std::vector _target_domain,
+      std::unordered_map _id_map)
+      : target_domain_(std::move(_target_domain)), id_map_(std::move(_id_map)) {
+    // Make sure id_map has all the inputs needed to replay target_domain
+    auto inps = IterVisitor::getInputsTo(
+        std::vector(target_domain_.begin(), target_domain_.end()));
+
+    std::for_each(inps.begin(), inps.end(), [this](Val* val) {
+      TORCH_INTERNAL_ASSERT(
+          val->getValType().value() == ValType::IterDomain,
+          "Expected IterDomain only for Replay Transformations, but found ",
+          val);
+      IterDomain* id = static_cast(val);
+      TORCH_INTERNAL_ASSERT(
+          this->id_map_.find(id) != this->id_map_.end(),
+          "Could not find required input: ",
+          id,
+          " in provided id_map.");
+    });
+
+    // Set all the leaf nodes for tracking, all ids start as a leaf and will be
+    // updated based on the transformations
+    for (auto entry : id_map_)
+      leaf_ids_[entry.second] = counter++;
+
+    runReplay();
+
+    // Populate leaf_vec_ in a deterministic manner. This is deterministic
+    // because the size_t in leaf_ids_ is assigned based on operation order.
+    std::set, id_int_lt> ordered_set;
+    for (auto entry : leaf_ids_)
+      ordered_set.emplace(entry);
+
+    leaf_vec_.clear();
+    leaf_vec_.resize(ordered_set.size());
+    std::transform(
+        ordered_set.begin(),
+        ordered_set.end(),
+        leaf_vec_.begin(),
+        [](std::pair entry) { return entry.first; });
+  }
+
+  // Returns the map from the provided target domain IDs to their
+  // corresponding replayed IDs
+  const std::unordered_map& getReplay() const
+      noexcept {
+    return id_map_;
+  }
+
+  // Returns the map from replayed leaf IterDomains to the order in which they
+  // were created during the replay
+  const std::unordered_map& getUnorderedLeafIDs() const
+      noexcept {
+    return leaf_ids_;
+  }
+
+  // Returns all terminating IDs that resulted from the replay. Leaf IDs are
+  // deterministic from run to run, but otherwise in no specific order.
+  const std::vector& getLeafIDs() const noexcept {
+    return leaf_vec_;
+  }
+};
+
 /*
  * TransformIter iterates on the split/merge graph of TensorDomain
  *
diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp
index c57272754b0af..b083b1e1d5614 100644
--- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp
+++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp
@@ -1,14 +1,14 @@
-// #include
-// #include
-// #include
-// #include
-// #include
+#include
+#include
+#include
+#include
+#include
 
-// #include
+#include
 
-// namespace torch {
-// namespace jit {
-// namespace fuser {
+namespace torch {
+namespace jit {
+namespace fuser {
 
 // // Replay producer as consumer.
 // TensorDomain* TransformReplay::fullSelfReplay(
@@ -43,166 +43,417 @@
 // return replayed;
 // }
 
-// // Replay producer as consumer.
-// TensorDomain* TransformReplay::replayPasC( -// TensorDomain* producer, -// TensorDomain* consumer, -// int consumer_compute_at_axis) { -// if (consumer_compute_at_axis < 0) -// consumer_compute_at_axis += (int)consumer->nDims() + 1; -// TORCH_INTERNAL_ASSERT( -// consumer_compute_at_axis >= 0 && -// (unsigned int)consumer_compute_at_axis <= consumer->nDims(), -// "Invalid axis in transform replayPasC."); - -// // Consumer in rfactor cases is based off producer's rfactor root, not -// // producer's root -// TensorDomain* producer_rfactor_root = -// TransformIter::getRFactorRoot(producer); - -// // Want full consumer root, even before rfactor -// TensorDomain* consumer_root = TransformIter::getRoot(consumer); - -// // We want to see which axes in the consumer root were modified to create -// axes -// // < consumer_compute_at_axis -// std::vector consumer_influence(consumer->nDims(), false); -// for (int i = 0; i < consumer_compute_at_axis; i++) -// consumer_influence[i] = true; - -// // Check which axes in ref_root need to be modified to honor -// transformations -// // to compute at axis -// std::vector consumer_root_influence = -// TransformIter::getRootInfluence(consumer, consumer_influence); - -// // We have the influence on the consumer root, we need it on producer, we -// // want to keep those axes that don't need to be modified by the replay -// std::vector producer_rfactor_root_influence( -// producer_rfactor_root->nDims(), false); - -// // Map is based on producer -// std::vector replay_axis_map(consumer_root->nDims(), -1); -// // Setup producer_rfactor_root_influence vector on root for replay -// size_t ip = 0, ic = 0; - -// while (ip < producer_rfactor_root_influence.size() && -// ic < consumer_root->nDims()) { -// bool is_reduction = producer_rfactor_root->axis(ip)->isReduction(); -// bool is_bcast = consumer_root->axis(ic)->isBroadcast(); -// if (is_reduction) { -// producer_rfactor_root_influence[ip++] = false; -// } else if (is_bcast) { -// replay_axis_map[ic++] = -1; -// } else { -// if (consumer_root_influence[ic]) { -// replay_axis_map[ic] = ip; -// } else { -// replay_axis_map[ic] = -1; -// } -// producer_rfactor_root_influence[ip++] = consumer_root_influence[ic++]; -// } -// } - -// for (decltype(producer_rfactor_root->nDims()) i{0}; -// i < producer_rfactor_root->nDims(); -// i++) -// TORCH_INTERNAL_ASSERT( -// !(producer_rfactor_root_influence[i] && -// producer_rfactor_root->axis(i)->isRFactorProduct()), -// "An illegal attempt to modify an rfactor axis detected."); - -// // We should have hit the end of the consumer root domain -// TORCH_INTERNAL_ASSERT( -// ic == consumer_root->nDims() || -// (ic < consumer_root->nDims() ? 
-// consumer_root->axis(ic)->isBroadcast() -// : false), -// "Error when trying to run replay, didn't reach end of consumer/target -// root."); - -// TORCH_INTERNAL_ASSERT( -// producer_rfactor_root_influence.size() == -// producer_rfactor_root->nDims(), "Error detected during replay, expected -// matching sizes of influence map to root dimensions."); - -// auto producer_root_influence = TransformIter::getRootInfluence( -// producer_rfactor_root, producer_rfactor_root_influence); - -// TensorDomain* producer_root = -// TransformIter::getRoot(producer_rfactor_root); - -// std::vector producer_replay_map(producer_root->nDims()); -// for (decltype(producer_replay_map.size()) i{0}; -// i < producer_replay_map.size(); -// i++) { -// if (producer_root->axis(i)->isRFactorProduct()) { -// producer_replay_map[i] = i; -// } else { -// producer_replay_map[i] = producer_root_influence[i] ? -1 : i; -// } -// } - -// // Replay axes that won't be modified by transform replay -// TensorDomain* producer_replay_root = TransformIter::replaySelf( -// producer, TransformIter::getHistory(producer), producer_replay_map); - -// // Record axes positions. -// std::unordered_map new_position; -// for (decltype(producer_replay_root->nDims()) i{0}; -// i < producer_replay_root->nDims(); -// i++) -// new_position[producer_replay_root->axis(i)] = i; - -// std::unordered_map root_axis_map; -// // reorder producer_replay_root to respect replay_axis_map -// for (decltype(replay_axis_map.size()) i{0}; i < replay_axis_map.size(); -// i++) { -// if (replay_axis_map[i] == -1) -// continue; -// auto ax = producer_root->axis(replay_axis_map[i]); -// TORCH_INTERNAL_ASSERT( -// new_position.find(ax) != new_position.end(), -// "Error hit during transform replay, could not find ", -// ax, -// " expected in root domain."); -// root_axis_map[new_position[ax]] = replay_axis_map[i]; -// } - -// // root_axis_map is now mapping from producer_replay_root -> consumer_root -// // Take producer_replay_root transform for all modified axes are in correct -// // relative order, matching how it was in replay_axis_map -// producer_replay_root = producer_replay_root->reorder(root_axis_map); - -// // Finally replay producer as consumer on marked axes -// TensorDomain* replayed = TransformIter::replay( -// producer_replay_root, -// TransformIter::getHistory(consumer), -// replay_axis_map); - -// TORCH_INTERNAL_ASSERT( -// std::none_of( -// replayed->domain().begin(), -// replayed->domain().begin() + consumer_compute_at_axis, -// [](IterDomain* id) { return id->isReduction(); }), -// "Reduction axes found within consumer_compute_at_axis in replay of -// producer."); - -// return replayed; -// } - -// // Replay consumer as producer. -// TensorDomain* TransformReplay::replayCasP( -// TensorDomain* consumer, -// TensorDomain* producer, -// int producer_compute_at_axis) { -// if (producer_compute_at_axis < 0) -// producer_compute_at_axis += (int)producer->nDims() + 1; -// TORCH_INTERNAL_ASSERT( -// producer_compute_at_axis >= 0 && -// (unsigned int)producer_compute_at_axis <= producer->nDims(), -// "Invalid axis in transform replayPasC."); - +// Replay producer as consumer. 
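+//
+// A minimal usage sketch (illustrative only; the tensors and the split shown
+// here are hypothetical, not part of this change):
+//
+//   // Consumer T1[I0, I1] was scheduled with T1->split(1, 4), giving
+//   // T1[I0, I1o, I1i{4}]. Replay producer T0[I0, I1] like the consumer
+//   // up to (but excluding) consumer axis 2:
+//   TensorDomain* replayed =
+//       TransformReplay::replayPasC(T0->domain(), T1->domain(), 2);
+//   // replayed is T0[I0, I1o, I1i{4}]: consumer axes 0 and 1 are mapped to
+//   // producer counterparts via ReplayTransformations, and I1i comes along
+//   // as a leaf of the replayed split.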
+TensorDomain* TransformReplay::replayPasC(
+    TensorDomain* producer,
+    TensorDomain* consumer,
+    int consumer_compute_at_axis) {
+  if (consumer_compute_at_axis < 0)
+    consumer_compute_at_axis += (int)consumer->nDims() + 1;
+  TORCH_INTERNAL_ASSERT(
+      consumer_compute_at_axis >= 0 &&
+          (unsigned int)consumer_compute_at_axis <= consumer->nDims(),
+      "Invalid axis in transform replayPasC.");
+
+  // consumer ids we need to match in producer
+  std::vector consumer_ids;
+  {
+    size_t itc = 0;
+    while (itc < consumer_compute_at_axis) {
+      if (consumer->axis(itc)->isBroadcast()) {
+        itc++;
+      } else {
+        consumer_ids.emplace_back(consumer->axis(itc++));
+      }
+    }
+  }
+
+  // Figure out all inputs required to generate the compute_at dimensions
+  std::set consumer_root_ids = IterVisitor::getInputsTo(
+      std::vector(consumer_ids.begin(), consumer_ids.end()));
+
+  // Map of consumer_root_ids to related producer_ids
+  std::unordered_map root_axis_map;
+
+  // Grab root domains of producer and consumer
+  std::vector consumer_root = consumer->rootDomain();
+  std::vector producer_root = producer->rootDomain();
+  // Track which root axes in producer we send to replay
+  std::set producer_mapped_roots;
+  // Map related axes from producer and consumer roots. Make sure we go to the
+  // end of both.
+  {
+    size_t itc = 0, itp = 0;
+    while (itc < consumer_root.size() || itp < producer_root.size()) {
+      if (itc < consumer_root.size() && consumer_root[itc]->isBroadcast()) {
+        itc++;
+        continue;
+      }
+      if (itp < producer_root.size() && producer_root[itp]->isReduction()) {
+        itp++;
+        continue;
+      }
+      TORCH_INTERNAL_ASSERT(
+          itc < consumer_root.size() && itp < producer_root.size(),
+          "Error during replay, wanted to keep going, but ran out of root dimensions.");
+
+      if (consumer_root_ids.find(consumer_root[itc]) !=
+          consumer_root_ids.end()) {
+        root_axis_map[consumer_root[itc]] = producer_root[itp];
+        producer_mapped_roots.emplace(producer_root[itp]);
+      }
+      itc++;
+      itp++;
+    }
+  }
+
+  // Replay producer dimensions.
+  ReplayTransformations replay_producer(consumer_ids, root_axis_map);
+
+  // replay_producer now contains mappings from consumer axes to their replayed
+  // counterparts in producer (including intermediate IDs, not just those in
+  // consumer_root or consumer). replay_producer also has all the leaf
+  // IterDomains from the replay.
+
+  // Find all axes that were not modified during the replay.
+  std::vector unmodified_producer_axes;
+  // Run through all producer ids
+  for (auto producer_id : producer->domain()) {
+    // Grab the inputs to this id
+    auto inps = IterVisitor::getInputsTo({producer_id});
+
+    // If all inputs to this id are leaf nodes, then this id wasn't impacted
+    bool modified = false;
+    for (auto inp : inps) {
+      if (inp->getValType().value() != ValType::IterDomain)
+        continue;
+      IterDomain* inp_id = static_cast(inp);
+      // if ( we sent this root id to replay && it was modified )
+      if (producer_mapped_roots.find(inp_id) != producer_mapped_roots.end() &&
+          replay_producer.getUnorderedLeafIDs().find(inp_id) ==
+              replay_producer.getUnorderedLeafIDs().end()) {
+        modified = true;
+        break;
+      }
+    }
+
+    // If not impacted, let's put it back in the replayed domain.
+    if (!modified)
+      unmodified_producer_axes.emplace_back(producer_id);
+  }
+
+  // (1) replay_producer.getReplay holds mappings from axes in consumer ->
+  //     generated axes in producer
+  // (2) replay_producer.getLeafIDs holds a deterministic ordering of the axes
+  //     in (1), plus all other leaf axes created while generating them
+  // (3) unmodified_producer_axes holds axes that didn't have to be modified to
+  //     generate (1)
+
+  // Accumulate new domain in this vector:
+  std::vector new_IDs;
+
+  // Add axes in (1)
+  std::set used_IDs;
+  for (auto c_id : consumer_ids) {
+    auto it = replay_producer.getReplay().find(c_id);
+    TORCH_INTERNAL_ASSERT(
+        it != replay_producer.getReplay().end(),
+        "Could not find axis, ",
+        c_id,
+        ", requested in replay.");
+    new_IDs.push_back((*it).second);
+    used_IDs.emplace((*it).second);
+  }
+
+  // Add axes in (2)
+  for (auto leaf : replay_producer.getLeafIDs())
+    if (used_IDs.find(leaf) == used_IDs.end())
+      new_IDs.push_back(leaf);
+
+  // Add axes in (3)
+  for (auto unmodified : unmodified_producer_axes)
+    if (used_IDs.find(unmodified) == used_IDs.end())
+      new_IDs.push_back(unmodified);
+
+  TensorDomain* replayed = new TensorDomain(producer_root, new_IDs);
+
+  return replayed;
+
+  // KEEPING BELOW FOR NOW FOR REFERENCE WHEN DOING RFACTOR!
+  // // Consumer in rfactor cases is based off producer's rfactor root, not
+  // // producer's root
+  // TensorDomain* producer_rfactor_root =
+  //     TransformIter::getRFactorRoot(producer);
+
+  // // Want full consumer root, even before rfactor
+  // TensorDomain* consumer_root = TransformIter::getRoot(consumer);
+
+  // // We want to see which axes in the consumer root were modified to create
+  // // axes < consumer_compute_at_axis
+  // std::vector consumer_influence(consumer->nDims(), false);
+  // for (int i = 0; i < consumer_compute_at_axis; i++)
+  //   consumer_influence[i] = true;
+
+  // // Check which axes in ref_root need to be modified to honor
+  // // transformations to compute at axis
+  // std::vector consumer_root_influence =
+  //     TransformIter::getRootInfluence(consumer, consumer_influence);
+
+  // // We have the influence on the consumer root, we need it on producer, we
+  // // want to keep those axes that don't need to be modified by the replay
+  // std::vector producer_rfactor_root_influence(
+  //     producer_rfactor_root->nDims(), false);
+
+  // // Map is based on producer
+  // std::vector replay_axis_map(consumer_root->nDims(), -1);
+  // // Setup producer_rfactor_root_influence vector on root for replay
+  // size_t ip = 0, ic = 0;
+
+  // while (ip < producer_rfactor_root_influence.size() &&
+  //        ic < consumer_root->nDims()) {
+  //   bool is_reduction = producer_rfactor_root->axis(ip)->isReduction();
+  //   bool is_bcast = consumer_root->axis(ic)->isBroadcast();
+  //   if (is_reduction) {
+  //     producer_rfactor_root_influence[ip++] = false;
+  //   } else if (is_bcast) {
+  //     replay_axis_map[ic++] = -1;
+  //   } else {
+  //     if (consumer_root_influence[ic]) {
+  //       replay_axis_map[ic] = ip;
+  //     } else {
+  //       replay_axis_map[ic] = -1;
+  //     }
+  //     producer_rfactor_root_influence[ip++] = consumer_root_influence[ic++];
+  //   }
+  // }
+
+  // for (decltype(producer_rfactor_root->nDims()) i{0};
+  //      i < producer_rfactor_root->nDims();
+  //      i++)
+  //   TORCH_INTERNAL_ASSERT(
+  //       !(producer_rfactor_root_influence[i] &&
+  //         producer_rfactor_root->axis(i)->isRFactorProduct()),
+  //       "An illegal attempt to modify an rfactor axis detected.");
+
+  // // We should have hit the end of the consumer root domain
+  // TORCH_INTERNAL_ASSERT(
+  //     ic
== consumer_root->nDims() || + // (ic < consumer_root->nDims() ? + // consumer_root->axis(ic)->isBroadcast() + // : false), + // "Error when trying to run replay, didn't reach end of consumer/target + // root."); + + // TORCH_INTERNAL_ASSERT( + // producer_rfactor_root_influence.size() == + // producer_rfactor_root->nDims(), "Error detected during replay, expected + // matching sizes of influence map to root dimensions."); + + // auto producer_root_influence = TransformIter::getRootInfluence( + // producer_rfactor_root, producer_rfactor_root_influence); + + // TensorDomain* producer_root = + // TransformIter::getRoot(producer_rfactor_root); + + // std::vector producer_replay_map(producer_root->nDims()); + // for (decltype(producer_replay_map.size()) i{0}; + // i < producer_replay_map.size(); + // i++) { + // if (producer_root->axis(i)->isRFactorProduct()) { + // producer_replay_map[i] = i; + // } else { + // producer_replay_map[i] = producer_root_influence[i] ? -1 : i; + // } + // } + + // // Replay axes that won't be modified by transform replay + // TensorDomain* producer_replay_root = TransformIter::replaySelf( + // producer, TransformIter::getHistory(producer), producer_replay_map); + + // // Record axes positions. + // std::unordered_map new_position; + // for (decltype(producer_replay_root->nDims()) i{0}; + // i < producer_replay_root->nDims(); + // i++) + // new_position[producer_replay_root->axis(i)] = i; + + // std::unordered_map root_axis_map; + // // reorder producer_replay_root to respect replay_axis_map + // for (decltype(replay_axis_map.size()) i{0}; i < replay_axis_map.size(); + // i++) { + // if (replay_axis_map[i] == -1) + // continue; + // auto ax = producer_root->axis(replay_axis_map[i]); + // TORCH_INTERNAL_ASSERT( + // new_position.find(ax) != new_position.end(), + // "Error hit during transform replay, could not find ", + // ax, + // " expected in root domain."); + // root_axis_map[new_position[ax]] = replay_axis_map[i]; + // } + + // // root_axis_map is now mapping from producer_replay_root -> consumer_root + // // Take producer_replay_root transform for all modified axes are in correct + // // relative order, matching how it was in replay_axis_map + // producer_replay_root = producer_replay_root->reorder(root_axis_map); + + // // Finally replay producer as consumer on marked axes + // TensorDomain* replayed = TransformIter::replay( + // producer_replay_root, + // TransformIter::getHistory(consumer), + // replay_axis_map); + + // TORCH_INTERNAL_ASSERT( + // std::none_of( + // replayed->domain().begin(), + // replayed->domain().begin() + consumer_compute_at_axis, + // [](IterDomain* id) { return id->isReduction(); }), + // "Reduction axes found within consumer_compute_at_axis in replay of + // producer."); + + // return replayed; +} + +// Replay consumer as producer. 
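+//
+// A minimal usage sketch (illustrative only; tensors and schedule are
+// hypothetical):
+//
+//   // Producer T0[I0, I1] was scheduled with T0->split(1, 4), giving
+//   // T0[I0, I1o, I1i{4}]. Replay consumer T1[I0, I1] like the producer
+//   // up to producer axis 2:
+//   TensorDomain* replayed =
+//       TransformReplay::replayCasP(T1->domain(), T0->domain(), 2);
+//   // replayed is T1[I0, I1o, I1i{4}], mirroring the producer's split.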
+TensorDomain* TransformReplay::replayCasP(
+    TensorDomain* consumer,
+    TensorDomain* producer,
+    int producer_compute_at_axis) {
+  if (producer_compute_at_axis < 0)
+    producer_compute_at_axis += (int)producer->nDims() + 1;
+  TORCH_INTERNAL_ASSERT(
+      producer_compute_at_axis >= 0 &&
+          (unsigned int)producer_compute_at_axis <= producer->nDims(),
+      "Invalid axis in transform replayCasP.");
+
+  // producer ids we need to match in consumer
+  std::vector producer_ids;
+  {
+    size_t itp = 0;
+    while (itp < producer_compute_at_axis) {
+      if (producer->axis(itp)->isReduction()) {
+        itp++;
+      } else {
+        producer_ids.emplace_back(producer->axis(itp++));
+      }
+    }
+  }
+
+  // Figure out all inputs required to generate the compute_at dimensions
+  std::set producer_root_ids = IterVisitor::getInputsTo(
+      std::vector(producer_ids.begin(), producer_ids.end()));
+
+  // Map of producer_root_ids to related consumer_ids
+  std::unordered_map root_axis_map;
+
+  // Grab root domains of producer and consumer
+  std::vector consumer_root = consumer->rootDomain();
+  std::vector producer_root = producer->rootDomain();
+
+  // Track which root axes in consumer we send to replay
+  std::set consumer_mapped_roots;
+  // Map related axes from producer and consumer roots. Make sure we go to the
+  // end of both.
+  {
+    size_t itc = 0, itp = 0;
+    while (itc < consumer_root.size() || itp < producer_root.size()) {
+      if (itc < consumer_root.size() && consumer_root[itc]->isBroadcast()) {
+        itc++;
+        continue;
+      }
+      if (itp < producer_root.size() && producer_root[itp]->isReduction()) {
+        itp++;
+        continue;
+      }
+      TORCH_INTERNAL_ASSERT(
+          itc < consumer_root.size() && itp < producer_root.size(),
+          "Error during replay, wanted to keep going, but ran out of root dimensions.");
+
+      if (producer_root_ids.find(producer_root[itp]) !=
+          producer_root_ids.end()) {
+        root_axis_map[producer_root[itp]] = consumer_root[itc];
+        consumer_mapped_roots.emplace(consumer_root[itc]);
+      }
+      itc++;
+      itp++;
+    }
+  }
+
+  // Replay consumer dimensions.
+  ReplayTransformations replay_consumer(producer_ids, root_axis_map);
+
+  // replay_consumer now contains mappings from producer axes to their replayed
+  // counterparts in consumer (including intermediate IDs, not just those in
+  // producer_root or producer). replay_consumer also has all the leaf
+  // IterDomains from the replay.
+
+  // Find all axes that were not modified during the replay.
+  std::vector unmodified_consumer_axes;
+  // Run through all consumer ids
+  for (auto consumer_id : consumer->domain()) {
+    // Grab the inputs to this id
+    auto inps = IterVisitor::getInputsTo({consumer_id});
+
+    // If all inputs to this id are leaf nodes, then this id wasn't impacted
+    bool modified = false;
+    for (auto inp : inps) {
+      if (inp->getValType().value() != ValType::IterDomain)
+        continue;
+
+      IterDomain* inp_id = static_cast(inp);
+      // if ( we sent this root id to replay && it was modified )
+      if (consumer_mapped_roots.find(inp_id) != consumer_mapped_roots.end() &&
+          replay_consumer.getUnorderedLeafIDs().find(inp_id) ==
+              replay_consumer.getUnorderedLeafIDs().end()) {
+        modified = true;
+        break;
+      }
+    }
+
+    // If not impacted, let's put it back in the replayed domain.
+    if (!modified)
+      unmodified_consumer_axes.emplace_back(consumer_id);
+  }
+
+  // (1) replay_consumer.getReplay holds mappings from axes in producer ->
+  //     generated axes in consumer
+  // (2) replay_consumer.getLeafIDs holds a deterministic ordering of the axes
+  //     in (1), plus all other leaf axes created while generating them
+  // (3) unmodified_consumer_axes holds axes that didn't have to be modified to
+  //     generate (1)
+
+  // Accumulate new domain in this vector:
+  std::vector new_IDs;
+
+  // Add axes in (1)
+  std::set used_IDs;
+  for (auto p_id : producer_ids) {
+    auto it = replay_consumer.getReplay().find(p_id);
+    TORCH_INTERNAL_ASSERT(
+        it != replay_consumer.getReplay().end(),
+        "Could not find axis, ",
+        p_id,
+        ", requested in replay.");
+    new_IDs.push_back((*it).second);
+    used_IDs.emplace((*it).second);
+  }
+
+  // Add axes in (2)
+  for (auto leaf : replay_consumer.getLeafIDs())
+    if (used_IDs.find(leaf) == used_IDs.end())
+      new_IDs.push_back(leaf);
+
+  // Add axes in (3)
+  for (auto unmodified : unmodified_consumer_axes)
+    if (used_IDs.find(unmodified) == used_IDs.end())
+      new_IDs.push_back(unmodified);
+
+  TensorDomain* replayed = new TensorDomain(consumer_root, new_IDs);
+
+  return replayed;
+}
+
+// KEEPING BELOW FOR NOW FOR REFERENCE WHEN DOING RFACTOR!
 // // Want producer root with no reductions, rfactor included
 // TensorDomain* producer_rfactor_root =
 // TransformIter::getRFactorRoot(producer); TensorDomain* producer_root =
@@ -326,38 +577,37 @@
 // return replayed;
 // }
 
-// // replay Producer as Consumer
-// TensorView* TransformReplay::replayPasC(
-//     TensorView* producer,
-//     TensorView* consumer,
-//     int compute_at_axis) {
-//   // If this is a reduction operation, we may call transform_replay on the
-//   same
-//   // tensor view. When this happens, just return thet target view.
-//   if (producer == consumer)
-//     return producer;
-
-//   TensorDomain* td =
-//       replayPasC(producer->domain(), consumer->domain(), compute_at_axis);
-//   producer->setDomain(td);
-//   return producer;
-// }
-
-// TensorView* TransformReplay::replayCasP(
-//     TensorView* consumer,
-//     TensorView* producer,
-//     int compute_at_axis) {
-//   // If this is a reduction operation, we may call transform_replay on the
-//   same
-//   // tensor view. When this happens, just return thet target view.
-//   if (consumer == producer)
-//     return consumer;
-//   TensorDomain* td =
-//       replayCasP(consumer->domain(), producer->domain(), compute_at_axis);
-//   consumer->setDomain(td);
-//   return consumer;
-// }
-
-// } // namespace fuser
-// } // namespace jit
-// } // namespace torch
+// Replay producer as consumer.
+TensorView* TransformReplay::replayPasC(
+    TensorView* producer,
+    TensorView* consumer,
+    int compute_at_axis) {
+  // If this is a reduction operation, we may call transform_replay on the same
+  // tensor view. When this happens, just return the target view.
+  if (producer == consumer)
+    return producer;
+
+  TensorDomain* td =
+      replayPasC(producer->domain(), consumer->domain(), compute_at_axis);
+  producer->setDomain(td);
+  return producer;
+}
+
+TensorView* TransformReplay::replayCasP(
+    TensorView* consumer,
+    TensorView* producer,
+    int compute_at_axis) {
+  // If this is a reduction operation, we may call transform_replay on the same
+  // tensor view. When this happens, just return the target view.
+ if (consumer == producer) + return consumer; + TensorDomain* td = + replayCasP(consumer->domain(), producer->domain(), compute_at_axis); + consumer->setDomain(td); + return consumer; +} + +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 96c353d44ad27..66a4b7b6ed774 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -132,33 +133,25 @@ struct TORCH_CUDA_API TransformReplay { static TensorDomain* replayPasC( TensorDomain* producer, TensorDomain* consumer, - int compute_at_axis) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } + int consumer_compute_at_axis); // Replay producer as consumer. static TensorView* replayPasC( TensorView* producer, TensorView* consumer, - int compute_at_axis) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } + int consumer_compute_at_axis); // Replay producer as consumer. static TensorDomain* replayCasP( TensorDomain* consumer, TensorDomain* producer, - int compute_at_axis) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } + int producer_compute_at_axis); // Replay producer as consumer. static TensorView* replayCasP( TensorView* consumer, TensorView* producer, - int compute_at_axis) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } + int producer_compute_at_axis); }; } // namespace fuser From c112c14c89162b552f17b78547025cd2fa7e7ecf Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 27 May 2020 13:32:08 -0400 Subject: [PATCH 7/8] Add Backward(Iter)Visitor. Get pointwise operations working again. --- torch/csrc/jit/codegen/cuda/fusion.cpp | 14 +- torch/csrc/jit/codegen/cuda/fusion.h | 20 +- torch/csrc/jit/codegen/cuda/index_compute.cpp | 790 +++++++++--------- torch/csrc/jit/codegen/cuda/index_compute.h | 64 +- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 7 +- torch/csrc/jit/codegen/cuda/iter_visitor.cpp | 160 +++- torch/csrc/jit/codegen/cuda/iter_visitor.h | 80 +- torch/csrc/jit/codegen/cuda/lower2device.cpp | 7 +- torch/csrc/jit/codegen/cuda/transform_iter.h | 34 +- .../jit/codegen/cuda/transform_replay.cpp | 59 +- .../csrc/jit/codegen/cuda/transform_replay.h | 12 +- .../jit/codegen/cuda/transform_rfactor.cpp | 436 +++++----- .../csrc/jit/codegen/cuda/transform_rfactor.h | 8 +- 13 files changed, 967 insertions(+), 724 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp index 6623392ec40a0..142f83a0b8e2a 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion.cpp @@ -39,7 +39,7 @@ void InputsOf::handle(Val* v) { inputs.emplace(v); } -std::set InputsOf::output(Fusion* fusion, Val* output_) { +std::unordered_set InputsOf::output(Fusion* fusion, Val* output_) { TORCH_CHECK( fusion->hasOutput(output_), "Asked for the inputs of ", @@ -177,12 +177,12 @@ std::vector Fusion::exprs(bool from_outputs_only, bool breadth_first) { return ExprSort::getExprs(this, from_outputs_only, breadth_first); } -std::set Fusion::inputsOf(Val* val) { +std::unordered_set Fusion::inputsOf(Val* val) { return InputsOf::output(this, val); } void Fusion::validateInputs() { - std::set all_inputs; + std::unordered_set all_inputs; for (Val* out : outputs()) { auto outs_inputs = inputsOf(out); std::set_union( @@ -293,7 +293,7 @@ bool Fusion::used(Val* val) const { (uses_.find(val)->second.size() > 0); } -const std::set& Fusion::vals() const noexcept { +const 
std::unordered_set& Fusion::vals() const noexcept { return val_set_; } @@ -301,17 +301,17 @@ const std::deque& Fusion::deterministic_vals() const noexcept { return val_deque_; } -const std::set& Fusion::unordered_exprs() const noexcept { +const std::unordered_set& Fusion::unordered_exprs() const noexcept { return expr_set_; } -std::set Fusion::uses(Val* val) const { +std::unordered_set Fusion::uses(Val* val) const { assertInFusion(val, "Cannot detect where val was used, "); if (uses_.find(val) != uses_.end()) { auto ret = uses_.find(val)->second; return ret; } - return std::set(); + return std::unordered_set(); } Expr* Fusion::origin(Val* val) const { diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h index 3201a95a57503..26a022ce43854 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.h +++ b/torch/csrc/jit/codegen/cuda/fusion.h @@ -6,8 +6,8 @@ #include #include -#include #include +#include #include namespace torch { @@ -82,12 +82,12 @@ struct ExprSort : public IterVisitor { struct InputsOf : public IterVisitor { private: - std::set inputs; + std::unordered_set inputs; void handle(Val* v) final; public: - static std::set output(Fusion* fusion, Val* output_); + static std::unordered_set output(Fusion* fusion, Val* output_); }; /* @@ -150,7 +150,7 @@ struct TORCH_CUDA_API Fusion : public IRInputOutput { bool from_outputs_only = false, bool breadth_first = false); - std::set inputsOf(Val* val); + std::unordered_set inputsOf(Val* val); // Assert that all leaves found from outputs are registered as an input. void validateInputs(); @@ -179,15 +179,15 @@ struct TORCH_CUDA_API Fusion : public IRInputOutput { bool used(Val* val) const; // Return the set of Vals registered with this fusion - const std::set& vals() const noexcept; + const std::unordered_set& vals() const noexcept; // Return in insertion order const std::deque& deterministic_vals() const noexcept; // Return the set of Exprs registered with this fusion - const std::set& unordered_exprs() const noexcept; + const std::unordered_set& unordered_exprs() const noexcept; // Return all Exprs that use val - std::set uses(Val* val) const; + std::unordered_set uses(Val* val) const; // Return the Expr that produces val Expr* origin(Val* val) const; @@ -203,9 +203,9 @@ struct TORCH_CUDA_API Fusion : public IRInputOutput { private: // Sets of all Vals/Exprs registered with this fusion - std::set val_set_; + std::unordered_set val_set_; std::deque val_deque_; - std::set expr_set_; + std::unordered_set expr_set_; // Return an int that monotonically increases for each val/expr, some are // explicitly incremented by type. @@ -225,7 +225,7 @@ struct TORCH_CUDA_API Fusion : public IRInputOutput { // Dependency tracking for Vals. Where did it come from? Where is it used? 
std::unordered_map origin_; - std::unordered_map> uses_; + std::unordered_map> uses_; }; } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 4db45509439af..098dce541eec6 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -1,388 +1,402 @@ -// #include -// #include -// #include -// #include - -// namespace torch { -// namespace jit { -// namespace fuser { - -// TensorDomain* IndexCompute::replayBackward(Split* split, TensorDomain*) { -// int ax = split->axis(); -// TORCH_INTERNAL_ASSERT( -// ax >= 0 && (unsigned int)(ax + 1) < indices.size(), -// "Hit an invalid Split transformation during IndexCompute, axis is not -// within bounds."); -// indices[ax] = add(mul(indices[ax], split->factor()), indices[ax + 1]); -// indices.erase(indices.begin() + ax + 1); -// return split->in(); -// } - -// TensorDomain* IndexCompute::replayBackward(Merge* merge, TensorDomain*) { -// int ax = merge->axis(); -// TORCH_INTERNAL_ASSERT( -// ax >= 0 && (unsigned int)ax < indices.size(), -// "Hit an invalid MERGE transformation during IndexCompute, axis is not -// within bounds."); - -// Val* I = merge->in()->axis(ax + 1)->extent(); -// Val* ind = indices[ax]; -// indices[ax] = div(ind, I); -// indices.insert(indices.begin() + ax + 1, mod(ind, I)); -// return merge->in(); -// } - -// TensorDomain* IndexCompute::runBackward(std::vector history) { -// TensorDomain* running_td = nullptr; -// for (auto it = history.rbegin(); it != history.rend(); it++) -// running_td = TransformIter::replayBackward(*it, running_td); - -// return running_td; -// } - -// IndexCompute::IndexCompute(TensorDomain* td, std::vector _indices) -// : indices(std::move(_indices)) { -// bool exclude_reduction = td->nDims() > indices.size(); - -// TORCH_INTERNAL_ASSERT( -// td->noReductions().size() == indices.size() || -// td->nDims() == indices.size(), -// "For IndexCompute the number of axes should match the number of -// dimensions" " in the TensorDomain."); - -// // If we need to ignore the reduction dimensions because a tensor is -// // being consumed, not produced, then insert dummy dimensions in the -// // indices for bookkeeping while replaying split/merge operations. -// if (exclude_reduction) -// for (decltype(td->nDims()) i{0}; i < td->nDims(); i++) -// if (td->axis(i)->isReduction()) -// indices.insert(indices.begin() + i, new Int(-1)); - -// TORCH_INTERNAL_ASSERT( -// indices.size() == td->nDims(), -// "Attempted to modify indices for IndexCompute, but didn't work."); - -// // Run the split/merge operations backwards. This will -// // Modify std::vector indices so it can be used to index -// // the root TensorDomain which should now match the physical axes. -// TensorDomain* root = TransformIter::getRoot(td); -// auto history = TransformIter::getHistory(td); -// if (exclude_reduction && td->hasRFactor()) { -// root = TransformIter::getRFactorRoot(td); -// auto rfactor_history = TransformIter::getHistory(root); -// history.erase(history.begin(), history.begin() + rfactor_history.size()); -// } - -// runBackward(history); - -// TORCH_INTERNAL_ASSERT( -// root->nDims() == indices.size(), -// "Error during IndexCompute. 
The number of indices generated" -// " after running the transformations backwards should match" -// " the number of dimensions of the root TensorDomain."); -// } - -// std::vector IndexCompute::get( -// TensorDomain* td, -// const std::vector& _indices) { -// IndexCompute ic(td, _indices); -// return ic.indices; -// } - -// TensorIndex* Index::getGlobalProducerIndex( -// TensorView* producer, -// TensorView* consumer, -// const std::vector& loops) { -// // This replay will ignore reduction dimensions on the producer -// auto pind = -// TransformReplay::replayPasC(producer->domain(), consumer->domain(), -// -1); - -// TORCH_INTERNAL_ASSERT( -// loops.size() == consumer->nDims(), -// "Dimensionality error in code generator while computing tensor -// indexes."); - -// std::vector loops_adjusted; -// size_t it_c = 0, it_p = 0; -// while (it_c < consumer->nDims() && it_p < pind->noReductions().size()) { -// if (consumer->axis(it_c)->isBroadcast() && -// !pind->noReductions()[it_p]->isBroadcast()) { -// it_c++; -// } else { -// loops_adjusted.push_back(loops[it_c]); -// it_c++; -// it_p++; -// } -// } - -// TORCH_INTERNAL_ASSERT( -// loops_adjusted.size() == pind->noReductions().size(), -// "Dimensionality error in code generator while computing tensor -// indexes."); - -// std::vector indices(loops_adjusted.size()); -// std::transform( -// loops_adjusted.begin(), -// loops_adjusted.end(), -// indices.begin(), -// [](ForLoop* fl) { return fl->index(); }); -// std::vector computed_inds = IndexCompute::get(pind, indices); - -// auto root_domain = producer->getRootDomain(); - -// TORCH_INTERNAL_ASSERT( -// computed_inds.size() == root_domain->nDims(), -// "Dimensionality error in code generator while computing indexing."); - -// for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) { -// if (root_domain->axis(i)->isReduction() || -// root_domain->axis(i)->isBroadcast()) -// computed_inds.erase(computed_inds.begin() + i); -// } - -// std::vector strided_inds; -// for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) { -// std::stringstream ss; -// ss << "T" << producer->name() << ".stride[" << i << "]"; -// strided_inds.push_back( -// mul(computed_inds[i], new NamedScalar(ss.str(), DataType::Int))); -// } - -// // Probably shouldn't ever hit this -// if (strided_inds.size() == 0) -// strided_inds.push_back(new Int(0)); - -// return new TensorIndex(producer, strided_inds); -// } - -// // Producer index for either shared or local memory -// TensorIndex* Index::getProducerIndex_impl( -// TensorView* producer, -// TensorView* consumer, -// const std::vector& loops) { -// TORCH_INTERNAL_ASSERT( -// loops.size() == consumer->nDims(), -// "Dimensionality error in code generator while computing tensor -// indexes."); - -// std::vector loops_adjusted; -// size_t it_c = 0, it_p = 0; -// while (it_c < consumer->nDims() && it_p < producer->nDims()) { -// if (consumer->axis(it_c)->isBroadcast() && -// !producer->axis(it_p)->isBroadcast()) { -// it_c++; -// } else { -// loops_adjusted.push_back(loops[it_c]); -// it_c++; -// it_p++; -// } -// } - -// TORCH_INTERNAL_ASSERT( -// loops_adjusted.size() == producer->domain()->noReductions().size(), -// "Expected a tensor with ", -// loops_adjusted.size(), -// " dimensions but got one with ", -// producer->nDims()); - -// std::vector ranges(loops_adjusted.size()); -// std::transform( -// loops_adjusted.begin(), -// loops_adjusted.end(), -// ranges.begin(), -// [](ForLoop* fl) { return fl->iter_domain(); }); - -// std::vector 
indices(loops_adjusted.size()); -// std::transform( -// loops_adjusted.begin(), -// loops_adjusted.end(), -// indices.begin(), -// [](ForLoop* fl) { -// return fl->iter_domain()->isBroadcast() ? new Int(0) : fl->index(); -// }); - -// std::vector used_inds; -// std::vector used_ranges; -// bool unrolled = false; -// for (decltype(loops_adjusted.size()) i{0}; i < loops_adjusted.size(); i++) -// { -// if (ranges[i]->parallel_method() == ParallelType::Unroll) -// unrolled = true; -// if (!unrolled && producer->hasComputeAt() && -// i < producer->getThisComputeAtAxis()) -// continue; -// if (producer->getMemoryType() == MemoryType::Shared && -// ranges[i]->isBlockDim()) -// continue; -// if (producer->getMemoryType() == MemoryType::Local && -// ranges[i]->isThread()) -// continue; -// if (ranges[i]->isBroadcast()) -// continue; - -// used_inds.push_back(indices[i]); -// used_ranges.push_back(ranges[i]); -// } - -// for (decltype(used_inds.size()) i{0}; i < used_inds.size(); i++) { -// Val* ind = used_inds[i]; -// for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++) -// ind = mul(ind, used_ranges[j]->extent()); -// used_inds[i] = ind; -// } -// if (used_inds.size() == 0) -// used_inds.push_back(new Int(0)); - -// return new TensorIndex(producer, used_inds); -// } - -// TensorIndex* Index::getGlobalConsumerIndex( -// TensorView* consumer, -// const std::vector& loops) { -// // If we're initializing a reduction buffer, we won't have the reduction -// // loops. If we're actually performing the reduction, we will. - -// std::vector indices(loops.size()); -// std::transform(loops.begin(), loops.end(), indices.begin(), [](ForLoop* fl) -// { -// return fl->index(); -// }); - -// std::vector computed_inds = -// IndexCompute::get(consumer->domain(), indices); - -// TensorDomain* root_dom = consumer->getRootDomain(); -// TORCH_INTERNAL_ASSERT( -// computed_inds.size() == root_dom->nDims(), -// "Dimensionality error in code generator while computing indexing."); - -// for (decltype(root_dom->nDims()) i{0}; i < root_dom->nDims(); i++) { -// // Do this backwards so erase offset will be right -// auto axis = root_dom->nDims() - i - 1; -// if (root_dom->axis(axis)->isReduction() || -// root_dom->axis(i)->isBroadcast()) -// computed_inds.erase(computed_inds.begin() + axis); -// } - -// std::vector strided_inds; -// for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) { -// std::stringstream ss; -// ss << "T" << consumer->name() << ".stride[" << i << "]"; -// strided_inds.push_back( -// mul(computed_inds[i], new NamedScalar(ss.str(), DataType::Int))); -// } - -// // Probably shouldn't ever hit this -// if (strided_inds.size() == 0) -// strided_inds.push_back(new Int(0)); - -// return new TensorIndex(consumer, strided_inds); -// } - -// // Consumer index for either shared or local memory -// TensorIndex* Index::getConsumerIndex_impl( -// TensorView* consumer, -// const std::vector& loops) { -// // If we're initializing a reduction buffer, we won't have the reduction -// // loops. If we're actually performing the reduction, we will. 
- -// bool have_reduction_iters = loops.size() == consumer->nDims(); - -// if (!have_reduction_iters) { -// TORCH_INTERNAL_ASSERT( -// // Init reduction space -// loops.size() == consumer->domain()->noReductions().size(), -// "Expected a tensor with ", -// loops.size(), -// " dimensions but got one with ", -// consumer->domain()->noReductions().size()); -// } else { -// TORCH_INTERNAL_ASSERT( -// // Calling the reduction op -// loops.size() == consumer->nDims(), -// "Expected a tensor with ", -// loops.size(), -// " dimensions but got one with ", -// consumer->nDims()); -// } - -// std::vector ranges(loops.size()); -// std::transform(loops.begin(), loops.end(), ranges.begin(), [](ForLoop* fl) -// { -// return fl->iter_domain(); -// }); - -// std::vector indices(loops.size()); -// std::transform(loops.begin(), loops.end(), indices.begin(), [](ForLoop* fl) -// { -// return fl->iter_domain()->isBroadcast() ? new Int(0) : fl->index(); -// }); - -// std::vector used_inds; -// std::vector used_ranges; -// bool unrolled = false; -// for (decltype(loops.size()) i{0}; i < loops.size(); i++) { -// if (have_reduction_iters && consumer->axis(i)->isReduction()) -// continue; -// if (ranges[i]->parallel_method() == ParallelType::Unroll) -// unrolled = true; -// if (!unrolled && consumer->hasComputeAt() && -// i < consumer->getThisComputeAtAxis()) -// continue; -// if (consumer->getMemoryType() == MemoryType::Shared && -// ranges[i]->isBlockDim()) -// continue; -// if (consumer->getMemoryType() == MemoryType::Local && -// ranges[i]->isThread()) -// continue; -// if (ranges[i]->isBroadcast()) -// continue; - -// used_inds.push_back(indices[i]); -// used_ranges.push_back(ranges[i]); -// } - -// for (decltype(used_inds.size()) i{0}; i < used_inds.size(); i++) { -// Val* ind = used_inds[i]; -// for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++) -// ind = mul(ind, used_ranges[j]->extent()); -// used_inds[i] = ind; -// } - -// if (used_inds.size() == 0) -// used_inds.push_back(new Int(0)); - -// return new TensorIndex(consumer, used_inds); -// } - -// // Producer is the inputs of an expression -// TensorIndex* Index::getProducerIndex( -// TensorView* producer, -// TensorView* consumer, -// const std::vector& loops) { -// TORCH_INTERNAL_ASSERT( -// loops.size() == consumer->nDims() || -// loops.size() == consumer->domain()->noReductions().size()); - -// if (producer->getMemoryType() == MemoryType::Global) -// return getGlobalProducerIndex(producer, consumer, loops); -// return getProducerIndex_impl(producer, consumer, loops); -// } - -// // Consumer is the output of an expression -// TensorIndex* Index::getConsumerIndex( -// TensorView* consumer, -// const std::vector& loops) { -// TORCH_INTERNAL_ASSERT( -// loops.size() == consumer->nDims() || -// loops.size() == consumer->domain()->noReductions().size()); - -// if (consumer->getMemoryType() == MemoryType::Global) -// return getGlobalConsumerIndex(consumer, loops); -// return getConsumerIndex_impl(consumer, loops); -// } - -// } // namespace fuser -// } // namespace jit -// } // namespace torch +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { + +void IndexCompute::handle(Split* split) { + auto in_id = split->in(); + auto outer_id = split->outer(); + auto inner_id = split->inner(); + + auto outer_it = index_map_.find(outer_id); + auto inner_it = index_map_.find(inner_id); + TORCH_INTERNAL_ASSERT( + outer_it != index_map_.end() && inner_it != index_map_.end(), + "Error in index 
compute, did not compute a necessary intermediate value.");
+
+  auto outer_ind = outer_it->second;
+  auto inner_ind = inner_it->second;
+
+  auto ind = add(mul(outer_ind, split->factor()), inner_ind);
+  index_map_[in_id] = ind;
+}
+
+void IndexCompute::handle(Merge* merge) {
+  auto out_id = merge->out();
+  auto outer_id = merge->outer();
+  auto inner_id = merge->inner();
+
+  auto out_it = index_map_.find(out_id);
+  TORCH_INTERNAL_ASSERT(
+      out_it != index_map_.end(),
+      "Error in index compute, did not compute a necessary intermediate value.");
+
+  auto out_ind = out_it->second;
+
+  Val* I = inner_id->extent();
+  Val* outer_ind = div(out_ind, I);
+  Val* inner_ind = mod(out_ind, I);
+
+  index_map_[outer_id] = outer_ind;
+  index_map_[inner_id] = inner_ind;
+}
+
+void IndexCompute::handle(Expr* e) {
+  switch (e->getExprType().value()) {
+    case (ExprType::Split):
+    case (ExprType::Merge):
+      break;
+    default:
+      TORCH_INTERNAL_ASSERT(
+          false, "Invalid expr type found in transform traversal.");
+  }
+  BackwardVisitor::handle(e);
+}
+
+IndexCompute::IndexCompute(TensorDomain* td, const std::vector& indices) {
+  if (td->nDims() == 0 || indices.empty()) {
+    indices_.push_back(new Int(0));
+    return;
+  }
+
+  bool exclude_reduction = td->nDims() > indices.size();
+
+  TORCH_INTERNAL_ASSERT(
+      td->noReductions().size() == indices.size() ||
+          td->nDims() == indices.size(),
+      "For IndexCompute the number of axes should match the number of dimensions in the TensorDomain.");
+
+  TORCH_INTERNAL_ASSERT(!td->hasRFactor(), "Not implemented yet.");
+
+  {
+    size_t i = 0;
+    for (auto id : td->domain()) {
+      if (exclude_reduction && id->isReduction())
+        continue;
+      index_map_[id] = indices[i++];
+    }
+  }
+
+  std::vector domain_vals(td->domain().begin(), td->domain().end());
+
+  // Run the split/merge operations backwards. This will modify the index_map_
+  // so it can be used to index the root TensorDomain. Each entry in the root
+  // TensorDomain should have an entry in index_map_. We might not want to run
+  // these indices at the root of the domain, but actually at the rfactor root.
+  // Fortunately we can run them all the way back and simply grab the indices
+  // from the map at the rfactor IterDomains.
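+  //
+  // A small worked example (hypothetical schedule, shown only to illustrate
+  // the math in handle(Split*) / handle(Merge*) above): if td's root is
+  // [I0, I1] and I1 was split by a factor of 4, td is [I0, I1o, I1i{4}]
+  // with loop indices {i0, i1o, i1i}. Running the split backwards yields
+  //   index(I0) = i0
+  //   index(I1) = i1o * 4 + i1i
+  // Had [I0, I1] instead been merged into [I0*I1] with loop index i, the
+  // backward traversal would recover
+  //   index(I0) = i / I1->extent()
+  //   index(I1) = i % I1->extent()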
+  traverseFrom(indices[0]->fusion(), domain_vals, false);
+
+  for (auto id : td->rootDomain()) {
+    if (exclude_reduction && id->isReduction())
+      continue;
+    auto it = index_map_.find(id);
+    TORCH_INTERNAL_ASSERT(
+        it != index_map_.end(),
+        "Error during index compute, missed computing a value.");
+    indices_.push_back(it->second);
+  }
+}
+
+std::vector IndexCompute::get(
+    TensorDomain* td,
+    const std::vector& _indices) {
+  IndexCompute ic(td, _indices);
+  return ic.indices_;
+}
+
+TensorIndex* Index::getGlobalProducerIndex(
+    TensorView* producer,
+    TensorView* consumer,
+    const std::vector& loops) {
+  // This replay will ignore reduction dimensions on the producer
+  auto pind =
+      TransformReplay::replayPasC(producer->domain(), consumer->domain(), -1);
+
+  TORCH_INTERNAL_ASSERT(
+      loops.size() == consumer->nDims(),
+      "Dimensionality error in code generator while computing tensor indexes.");
+
+  std::vector loops_adjusted;
+  size_t it_c = 0, it_p = 0;
+  while (it_c < consumer->nDims() && it_p < pind->noReductions().size()) {
+    if (consumer->axis(it_c)->isBroadcast() &&
+        !pind->noReductions()[it_p]->isBroadcast()) {
+      it_c++;
+    } else {
+      loops_adjusted.push_back(loops[it_c]);
+      it_c++;
+      it_p++;
+    }
+  }
+
+  TORCH_INTERNAL_ASSERT(
+      loops_adjusted.size() == pind->noReductions().size(),
+      "Dimensionality error in code generator while computing tensor indexes.");
+
+  std::vector indices(loops_adjusted.size());
+  std::transform(
+      loops_adjusted.begin(),
+      loops_adjusted.end(),
+      indices.begin(),
+      [](ForLoop* fl) { return fl->index(); });
+  std::vector computed_inds = IndexCompute::get(pind, indices);
+
+  auto root_domain = producer->getRootDomain();
+
+  TORCH_INTERNAL_ASSERT(
+      computed_inds.size() == root_domain.size(),
+      "Dimensionality error in code generator while computing indexing.");
+
+  for (decltype(root_domain.size()) i{0}; i < root_domain.size(); i++) {
+    // Go backwards so the erase offset stays valid as entries are removed
+    auto axis = root_domain.size() - i - 1;
+    if (root_domain[axis]->isReduction() || root_domain[axis]->isBroadcast())
+      computed_inds.erase(computed_inds.begin() + axis);
+  }
+
+  std::vector strided_inds;
+  for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) {
+    std::stringstream ss;
+    ss << "T" << producer->name() << ".stride[" << i << "]";
+    strided_inds.push_back(
+        mul(computed_inds[i], new NamedScalar(ss.str(), DataType::Int)));
+  }
+
+  // Probably shouldn't ever hit this
+  if (strided_inds.size() == 0)
+    strided_inds.push_back(new Int(0));
+
+  return new TensorIndex(producer, strided_inds);
+}
+
+// Producer index for either shared or local memory
+TensorIndex* Index::getProducerIndex_impl(
+    TensorView* producer,
+    TensorView* consumer,
+    const std::vector& loops) {
+  TORCH_INTERNAL_ASSERT(
+      loops.size() == consumer->nDims(),
+      "Dimensionality error in code generator while computing tensor indexes.");
+
+  std::vector loops_adjusted;
+  size_t it_c = 0, it_p = 0;
+  while (it_c < consumer->nDims() && it_p < producer->nDims()) {
+    if (consumer->axis(it_c)->isBroadcast() &&
+        !producer->axis(it_p)->isBroadcast()) {
+      it_c++;
+    } else {
+      loops_adjusted.push_back(loops[it_c]);
+      it_c++;
+      it_p++;
+    }
+  }
+
+  TORCH_INTERNAL_ASSERT(
+      loops_adjusted.size() == producer->domain()->noReductions().size(),
+      "Expected a tensor with ",
+      loops_adjusted.size(),
+      " dimensions but got one with ",
+      producer->nDims());
+
+  std::vector ranges(loops_adjusted.size());
+  std::transform(
+      loops_adjusted.begin(),
+      loops_adjusted.end(),
+      ranges.begin(),
+      [](ForLoop* fl) { return fl->iter_domain(); });
+
+  std::vector indices(loops_adjusted.size());
indices(loops_adjusted.size()); + std::transform( + loops_adjusted.begin(), + loops_adjusted.end(), + indices.begin(), + [](ForLoop* fl) { + return fl->iter_domain()->isBroadcast() ? new Int(0) : fl->index(); + }); + + std::vector used_inds; + std::vector used_ranges; + bool unrolled = false; + for (decltype(loops_adjusted.size()) i{0}; i < loops_adjusted.size(); i++) { + if (ranges[i]->parallel_method() == ParallelType::Unroll) + unrolled = true; + if (!unrolled && producer->hasComputeAt() && + i < producer->getThisComputeAtAxis()) + continue; + if (producer->getMemoryType() == MemoryType::Shared && + ranges[i]->isBlockDim()) + continue; + if (producer->getMemoryType() == MemoryType::Local && ranges[i]->isThread()) + continue; + if (ranges[i]->isBroadcast()) + continue; + + used_inds.push_back(indices[i]); + used_ranges.push_back(ranges[i]); + } + + for (decltype(used_inds.size()) i{0}; i < used_inds.size(); i++) { + Val* ind = used_inds[i]; + for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++) + ind = mul(ind, used_ranges[j]->extent()); + used_inds[i] = ind; + } + if (used_inds.size() == 0) + used_inds.push_back(new Int(0)); + + return new TensorIndex(producer, used_inds); +} + +TensorIndex* Index::getGlobalConsumerIndex( + TensorView* consumer, + const std::vector& loops) { + // If we're initializing a reduction buffer, we won't have the reduction + // loops. If we're actually performing the reduction, we will. + + std::vector indices(loops.size()); + std::transform(loops.begin(), loops.end(), indices.begin(), [](ForLoop* fl) { + return fl->index(); + }); + + std::vector computed_inds = + IndexCompute::get(consumer->domain(), indices); + + auto root_dom = consumer->getRootDomain(); + TORCH_INTERNAL_ASSERT( + computed_inds.size() == root_dom.size(), + "Dimensionality error in code generator while computing indexing."); + + for (decltype(root_dom.size()) i{0}; i < root_dom.size(); i++) { + // Do this backwards so erase offset will be right + auto axis = root_dom.size() - i - 1; + if (root_dom[axis]->isReduction() || root_dom[i]->isBroadcast()) + computed_inds.erase(computed_inds.begin() + axis); + } + + std::vector strided_inds; + for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) { + std::stringstream ss; + ss << "T" << consumer->name() << ".stride[" << i << "]"; + strided_inds.push_back( + mul(computed_inds[i], new NamedScalar(ss.str(), DataType::Int))); + } + + // Probably shouldn't ever hit this + if (strided_inds.size() == 0) + strided_inds.push_back(new Int(0)); + + return new TensorIndex(consumer, strided_inds); +} + +// Consumer index for either shared or local memory +TensorIndex* Index::getConsumerIndex_impl( + TensorView* consumer, + const std::vector& loops) { + // If we're initializing a reduction buffer, we won't have the reduction + // loops. If we're actually performing the reduction, we will. 
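Two addressing schemes are in play in the functions above: global tensors multiply each surviving root index by a named stride (T.stride[i]), while shared/local buffers collapse the used axes row-major, multiplying each index by the extents of all later axes, as in the used_inds loops. A small self-contained sketch, with plain ints standing in for the Val* arithmetic (illustrative only, not part of this patch), showing the two agree on a contiguous layout:

#include <cassert>
#include <vector>

int main() {
  std::vector<int> extent{2, 3, 4};
  std::vector<int> idx{1, 2, 3};

  // Global-style: explicit strides; {12, 4, 1} is the contiguous case.
  std::vector<int> stride{12, 4, 1};
  int global_offset = 0;
  for (size_t i = 0; i < idx.size(); i++)
    global_offset += idx[i] * stride[i];

  // Local/shared-style: multiply each index by the extents after it,
  // as the used_inds flattening loops do.
  int local_offset = 0;
  for (size_t i = 0; i < idx.size(); i++) {
    int ind = idx[i];
    for (size_t j = i + 1; j < extent.size(); j++)
      ind *= extent[j];
    local_offset += ind;
  }

  assert(global_offset == local_offset); // agree for a contiguous layout
  assert(global_offset == 1 * 12 + 2 * 4 + 3);
  return 0;
}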
+ + bool have_reduction_iters = loops.size() == consumer->nDims(); + + if (!have_reduction_iters) { + TORCH_INTERNAL_ASSERT( + // Init reduction space + loops.size() == consumer->domain()->noReductions().size(), + "Expected a tensor with ", + loops.size(), + " dimensions but got one with ", + consumer->domain()->noReductions().size()); + } else { + TORCH_INTERNAL_ASSERT( + // Calling the reduction op + loops.size() == consumer->nDims(), + "Expected a tensor with ", + loops.size(), + " dimensions but got one with ", + consumer->nDims()); + } + + std::vector ranges(loops.size()); + std::transform(loops.begin(), loops.end(), ranges.begin(), [](ForLoop* fl) { + return fl->iter_domain(); + }); + + std::vector indices(loops.size()); + std::transform(loops.begin(), loops.end(), indices.begin(), [](ForLoop* fl) { + return fl->iter_domain()->isBroadcast() ? new Int(0) : fl->index(); + }); + + std::vector used_inds; + std::vector used_ranges; + bool unrolled = false; + for (decltype(loops.size()) i{0}; i < loops.size(); i++) { + if (have_reduction_iters && consumer->axis(i)->isReduction()) + continue; + if (ranges[i]->parallel_method() == ParallelType::Unroll) + unrolled = true; + if (!unrolled && consumer->hasComputeAt() && + i < consumer->getThisComputeAtAxis()) + continue; + if (consumer->getMemoryType() == MemoryType::Shared && + ranges[i]->isBlockDim()) + continue; + if (consumer->getMemoryType() == MemoryType::Local && ranges[i]->isThread()) + continue; + if (ranges[i]->isBroadcast()) + continue; + + used_inds.push_back(indices[i]); + used_ranges.push_back(ranges[i]); + } + + for (decltype(used_inds.size()) i{0}; i < used_inds.size(); i++) { + Val* ind = used_inds[i]; + for (decltype(used_ranges.size()) j{i + 1}; j < used_ranges.size(); j++) + ind = mul(ind, used_ranges[j]->extent()); + used_inds[i] = ind; + } + + if (used_inds.size() == 0) + used_inds.push_back(new Int(0)); + + return new TensorIndex(consumer, used_inds); +} + +// Producer is the inputs of an expression +TensorIndex* Index::getProducerIndex( + TensorView* producer, + TensorView* consumer, + const std::vector& loops) { + TORCH_INTERNAL_ASSERT( + loops.size() == consumer->nDims() || + loops.size() == consumer->domain()->noReductions().size()); + + if (producer->getMemoryType() == MemoryType::Global) + return getGlobalProducerIndex(producer, consumer, loops); + return getProducerIndex_impl(producer, consumer, loops); +} + +// Consumer is the output of an expression +TensorIndex* Index::getConsumerIndex( + TensorView* consumer, + const std::vector& loops) { + TORCH_INTERNAL_ASSERT( + loops.size() == consumer->nDims() || + loops.size() == consumer->domain()->noReductions().size()); + + if (consumer->getMemoryType() == MemoryType::Global) + return getGlobalConsumerIndex(consumer, loops); + return getConsumerIndex_impl(consumer, loops); +} + +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/index_compute.h b/torch/csrc/jit/codegen/cuda/index_compute.h index cb11932f600fc..4523234e49f5d 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.h +++ b/torch/csrc/jit/codegen/cuda/index_compute.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include @@ -55,56 +55,50 @@ namespace torch { namespace jit { namespace fuser { -struct IndexCompute : public TransformIter { +struct IndexCompute : public BackwardVisitor { private: - // TensorDomain* replayBackward(Split*, TensorDomain*) override; - // TensorDomain* replayBackward(Merge*, TensorDomain*) override; + void handle(Split*) 
override; + void handle(Merge*) override; + void handle(Expr*) override; - // TensorDomain* runBackward(std::vector history); + // Otherwise warning on runBackward as it hides an overloaded virtual + // using TransformIter::runBackward; - // // Otherwise warning on runBackward as it hides an overloaded virtual - // function using TransformIter::runBackward; - - // IndexCompute(TensorDomain* td, std::vector _indices); - // std::vector indices; + IndexCompute(TensorDomain* td, const std::vector& _indices); + std::unordered_map index_map_; + std::vector indices_; public: static std::vector get( TensorDomain* td, - const std::vector& _indices) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } + const std::vector& _indices); }; // Simple interface for IndexCompute -struct Index : public TransformIter { +struct Index { private: - // // Producer indexing if it's in shared or local memory - // static TensorIndex* getProducerIndex_impl( - // TensorView* producer, - // TensorView* consumer, - // const std::vector& loops); + // Producer indexing if it's in shared or local memory + static TensorIndex* getProducerIndex_impl( + TensorView* producer, + TensorView* consumer, + const std::vector& loops); - // // Consumer indexing if it's in shared or local memory - // static TensorIndex* getConsumerIndex_impl( - // TensorView* consumer, - // const std::vector& loops); + // Consumer indexing if it's in shared or local memory + static TensorIndex* getConsumerIndex_impl( + TensorView* consumer, + const std::vector& loops); public: // Producer if it's in global memory static TensorIndex* getGlobalProducerIndex( TensorView* producer, TensorView* consumer, - const std::vector& loops) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } + const std::vector& loops); // Consumer indexing if it's in global memory static TensorIndex* getGlobalConsumerIndex( TensorView* consumer, - const std::vector& loops) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } + const std::vector& loops); // Indexing functions // Consumer = Producer @@ -113,21 +107,15 @@ struct Index : public TransformIter { static TensorIndex* getProducerIndex( TensorView* producer, TensorView* consumer, - const std::vector& loops) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } + const std::vector& loops); // Consumer index dispatch static TensorIndex* getConsumerIndex( TensorView* consumer, - const std::vector& loops) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } + const std::vector& loops); // Will run inds through back prop index computation for tv - static TensorIndex* manualBackprop(TensorView tv, std::vector inds) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } + static TensorIndex* manualBackprop(TensorView tv, std::vector inds); }; } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index dcb5421a7cee6..17d431fa57d2f 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -396,9 +396,10 @@ bool TensorDomain::hasBroadcast() const { } bool TensorDomain::hasRFactor() const { - return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) { - return id->isRFactorProduct(); - }); + return std::any_of( + root_domain_.begin(), root_domain_.end(), [](IterDomain* id) { + return id->isRFactorProduct(); + }); } // i here is int, as we want to accept negative value and ::size_type can be a diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp index 1bd1d7e0c3fa5..2d32842f2ff69 100644 --- 
a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp @@ -7,6 +7,8 @@ namespace torch { namespace jit { namespace fuser { +/* ITER VISITOR */ + std::vector IterVisitor::next(Statement* statement) { if (statement->isVal()) return next(static_cast(statement)); @@ -18,6 +20,7 @@ std::vector IterVisitor::next(Statement* statement) { } std::vector IterVisitor::next(Val* v) { + FusionGuard::getCurFusion()->assertInFusion(v, "Cannot traverse val, "); if (FusionGuard::getCurFusion()->origin(v) != nullptr) return {FusionGuard::getCurFusion()->origin(v)}; return {}; @@ -160,7 +163,7 @@ struct Exprs : public IterVisitor { // expressions. struct Inputs : public IterVisitor { private: - std::set inputs; + std::unordered_set inputs; void handle(Val* val) { if (val->getOrigin() == nullptr) @@ -168,9 +171,9 @@ struct Inputs : public IterVisitor { } public: - static std::set getInputs(const std::vector& of) { + static std::unordered_set getInputs(const std::vector& of) { if (of.empty()) - return std::set(); + return std::unordered_set(); Inputs inps; inps.traverseFrom(of[0]->fusion(), of); return inps.inputs; @@ -178,10 +181,11 @@ struct Inputs : public IterVisitor { }; } // namespace -std::set IterVisitor::getTerminatingOutputs(Fusion* const fusion) { +std::unordered_set IterVisitor::getTerminatingOutputs( + Fusion* const fusion) { FusionGuard fg(fusion); - std::set used_vals; + std::unordered_set used_vals; for (auto expr : Exprs::getExprs( fusion, std::vector( @@ -190,7 +194,7 @@ std::set IterVisitor::getTerminatingOutputs(Fusion* const fusion) { used_vals.emplace(inp); } - std::set terminating_outputs; + std::unordered_set terminating_outputs; for (auto out : fusion->outputs()) if (used_vals.find(out) == used_vals.end()) terminating_outputs.emplace(out); @@ -198,19 +202,151 @@ std::set IterVisitor::getTerminatingOutputs(Fusion* const fusion) { return terminating_outputs; } -std::set IterVisitor::getInputsTo(const std::vector& vals) { +std::unordered_set IterVisitor::getInputsTo( + const std::vector& vals) { return Inputs::getInputs(vals); } +namespace { + +struct AllVals : public IterVisitor { + private: + std::unordered_set vals; + + void handle(Val* val) final { + vals.emplace(val); + } + + public: + static std::unordered_set get( + Fusion* fusion, + const std::vector& from) { + AllVals av; + av.traverseFrom(fusion, from, false); + return av.vals; + } +}; + +} // namespace + +/* BACKWARDS VISITOR */ + +std::vector BackwardVisitor::next(Statement* stmt) { + if (stmt->isVal()) + return next(static_cast(stmt)); + else if (stmt->isExpr()) + return next(static_cast(stmt)); + else + TORCH_INTERNAL_ASSERT( + false, "BackwardVisitor could not detect type in next_dispatch."); +} + +std::vector BackwardVisitor::next(Expr* expr) { + return std::vector( + expr->outputs().begin(), expr->outputs().end()); +} + +std::vector BackwardVisitor::next(Val* val) { + // Going to sort based on relative topological position + std::map exprs; + + for (auto expr : FusionGuard::getCurFusion()->uses(val)) + // Make sure it's an expr we can traverse + if (traversal_exprs_.find(expr) != traversal_exprs_.end()) + exprs[traversal_exprs_[expr]] = expr; + + std::vector next_stmts(exprs.size()); + std::transform( + exprs.begin(), + exprs.end(), + next_stmts.begin(), + [](std::pair pair) { return pair.second; }); + + return next_stmts; +} + +void BackwardVisitor::traverseFrom( + Fusion* const fusion, + const std::vector& from, + bool traverseAllPaths) { + FusionGuard fg(fusion); + + // Reset 
members
+  stmt_stack_.clear();
+  traversal_exprs_.clear();
+
+  if (from.empty())
+    return;
+
+  auto vals = AllVals::get(fusion, from);
+
+  auto exprs = Exprs::getExprs(fusion, from);
+
+  {
+    size_t pos = 0;
+    for (auto expr : exprs)
+      traversal_exprs_[expr] = pos++;
+  }
+
+  // All stmts we've called handle on
+  std::unordered_set visited_stmts_;
+
+  for (auto traversal_pair : traversal_exprs_)
+    for (auto out : traversal_pair.first->outputs())
+      TORCH_INTERNAL_ASSERT(
+          vals.find(out) != vals.end(),
+          "Invalid backward traversal found. Some output paths were not provided.");
+
+  auto inputs = InputsOf::getInputsTo(from);
+  stmt_stack_.emplace_back(inputs.begin(), inputs.end());
+
+  // The rest is basically copy-pasted from IterVisitor:
+  while (!stmt_stack_.empty()) {
+    auto next_stmts = next(stmt_stack_.back().back());
+
+    // Remove statements we already visited if we're not traversing all paths
+    if (!traverseAllPaths)
+      remove_visited(next_stmts, visited_stmts_);
+
+    // Traverse down until we get to a leaf
+    while (!next_stmts.empty()) {
+      stmt_stack_.emplace_back(next_stmts.rbegin(), next_stmts.rend());
+      next_stmts = next(stmt_stack_.back().back());
+      // Remove statements we already visited if we're not traversing all paths
+      if (!traverseAllPaths)
+        remove_visited(next_stmts, visited_stmts_);
+    }
+
+    // Traverse back up
+    // Mark visited
+    visited_stmts_.emplace(stmt_stack_.back().back());
+    // Handle
+    handle(stmt_stack_.back().back());
+    // Remove
+    stmt_stack_.back().pop_back();
+
+    while (!stmt_stack_.empty() && stmt_stack_.back().empty()) {
+      stmt_stack_.pop_back();
+      if (!stmt_stack_.empty()) {
+        // Mark visited
+        visited_stmts_.emplace(stmt_stack_.back().back());
+        // Handle
+        handle(stmt_stack_.back().back());
+        // Remove
+        stmt_stack_.back().pop_back();
+      }
+    }
+  }
+}
+
 /* DEPENDENCY CHECKING */
 
 namespace {
-// Looks for and returns
 struct DependencyChains : public IterVisitor {
   std::deque> dep_chains;
   bool is_dependency = false;
-  std::set dependencies_;
+  std::unordered_set dependencies_;
 
   void handle(Val* val) override {
     if (dependencies_.find(val) != dependencies_.end()) {
@@ -238,7 +374,9 @@ struct DependencyChains : public IterVisitor {
     traverse(_dependency->fusion(), false);
   }
 
-  DependencyChains(std::set _dependencies, bool all_chains_ = false)
+  DependencyChains(
+      std::unordered_set _dependencies,
+      bool all_chains_ = false)
       : dependencies_(std::move(_dependencies)) {
     if (dependencies_.empty())
       return;
@@ -273,7 +411,7 @@ struct DependencyChains : public IterVisitor {
   }
 
   static std::deque> getDependencyChainsTo(
-      const std::set& dependencies) {
+      const std::unordered_set& dependencies) {
     DependencyChains dp(dependencies, true);
     if (dp.dep_chains.empty())
       return std::deque>();
diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.h b/torch/csrc/jit/codegen/cuda/iter_visitor.h
index 24390721bc222..4e37404d37542 100644
--- a/torch/csrc/jit/codegen/cuda/iter_visitor.h
+++ b/torch/csrc/jit/codegen/cuda/iter_visitor.h
@@ -5,7 +5,7 @@
 #include
 #include
-#include
+#include
 #include
 
 namespace torch {
 namespace jit {
@@ -107,9 +107,83 @@ struct TORCH_CUDA_API IterVisitor : public OptOutDispatch {
       bool from_outputs_only = false,
       bool breadth_first = false);
 
-  static std::set getTerminatingOutputs(Fusion* const);
+  static std::unordered_set getTerminatingOutputs(Fusion* const);
 
-  static std::set getInputsTo(const std::vector& vals);
+  static std::unordered_set getInputsTo(const std::vector& vals);
+};
+
+/*
+ * Backward visitor. IterVisitor calls handle in reverse order from outputs
+ * to inputs. It would be really nice to unify this with IterVisitor; however,
+ * the challenge there is that we specify traversal from outputs towards inputs
+ * because it implicitly provides DCE. However, if users are not careful, they
+ * could miss necessary outputs to do a backward traversal.
+ *
+ * BackwardVisitor checks that all outputs of an Expr are visited before
+ * visiting the Expr. If we don't provide nodes to start from on all backward
+ * paths of those outputs we will never visit the Expr.
+ *
+ * The first step of BackwardVisitor is to make sure we've specified enough
+ * outputs to guarantee that we will traverse all outputs of all exprs during
+ * the backward traversal.
+ */
+struct TORCH_CUDA_API BackwardVisitor : public OptOutDispatch {
+  virtual ~BackwardVisitor() = default;
+
+  BackwardVisitor() = default;
+
+  BackwardVisitor(const BackwardVisitor& other) = default;
+  BackwardVisitor& operator=(const BackwardVisitor& other) = default;
+
+  BackwardVisitor(BackwardVisitor&& other) = default;
+  BackwardVisitor& operator=(BackwardVisitor&& other) = default;
+
+  // Functions return nodes in reverse order to be added to the to_visit
+  // queue. These functions will start at outputs and propagate up through the
+  // DAG to inputs based on depth first traversal. Next could be called on a
+  // node multiple times.
+  virtual std::vector next(Statement* stmt);
+
+  virtual std::vector next(Expr* expr);
+
+  virtual std::vector next(Val* val);
+
+  // This handle function is called on every Statement* in topological order,
+  // starting from outputs to inputs.
+  virtual void handle(Statement* stmt) override {
+    OptOutDispatch::handle(stmt);
+  }
+  // This handle function is called on every Expr* in topological order,
+  // starting from outputs to inputs.
+  virtual void handle(Expr* expr) override {
+    OptOutDispatch::handle(expr);
+  }
+  // This handle function is called on every Val* in topological order,
+  // starting from outputs to inputs.
+  virtual void handle(Val* val) override {
+    OptOutDispatch::handle(val);
+  }
+
+  // All exprs that need to be visited in this traversal. Labeled in
+  // topological order (size_t).
+  std::unordered_map traversal_exprs_;
+
+  // The entire stack during traversal. stmt_stack.back().back() is the node
+  // that is being called in handle(). stmt_stack.back() contains siblings (not
+  // guaranteed to be all siblings throughout traversal). stmt_stack.front()
+  // contains the inputs we started with (not guaranteed to be all outputs
+  // throughout traversal).
+  std::deque> stmt_stack_;
+
+  // Starts at nodes provided in from, traverses from these nodes to inputs.
+  // Calls handle on all Statement*s in topologically sorted order.
+  // traverseAllPaths = false only calls handle on each Statement* once.
+  // traverseAllPaths = true traverses all paths from nodes in from to inputs,
+  // calling handle on a Statement* for every path from "from" nodes to inputs.
+  void traverseFrom(
+      Fusion* const fusion,
+      const std::vector& from,
+      bool traverseAllPaths = false);
 };
 
 struct TORCH_CUDA_API DependencyCheck {
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp
index 18e0d4b8b7532..73578a0ca2732 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -317,11 +317,8 @@ void GPULower::replaceSizes() {
   // We should just be able to replace sizes in place, but mutator is setup to
   // do that as it set up to replace vals in Exprs, but
   // IterDomain/TensorDomain are vals.
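The ordering contract documented for BackwardVisitor above boils down to: number all exprs once in forward topological order, then handle them in descending order of that number, so every expr is handled only after all of its outputs. A toy standalone sketch of that invariant (not the fusion IR; the graph and names here are illustrative only):

#include <cassert>
#include <map>
#include <string>
#include <vector>

int main() {
  // Forward topological numbering, as traversal_exprs_ records it:
  // e0: a -> b, e1: b -> c, e2: c -> d, with d the traversal start.
  std::map<size_t, std::string> traversal_exprs{
      {0, "e0"}, {1, "e1"}, {2, "e2"}};

  // A backward walk handles exprs in descending topological position,
  // so consumers are always handled before their producers.
  std::vector<std::string> handled;
  for (auto it = traversal_exprs.rbegin(); it != traversal_exprs.rend(); ++it)
    handled.push_back(it->second);

  assert((handled == std::vector<std::string>{"e2", "e1", "e0"}));
  return 0;
}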
- std::vector axis_map(new_domain->nDims()); - std::iota(axis_map.begin(), axis_map.end(), 0); - TORCH_INTERNAL_ASSERT(false, "NIY."); - // new_domain = TransformIter::replaySelf( - // new_domain, TransformIter::getHistory(old_domain), axis_map); + + new_domain = TransformReplay::fullSelfReplay(new_domain, old_domain); TORCH_INTERNAL_ASSERT( old_domain->nDims() == new_domain->nDims(), diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.h b/torch/csrc/jit/codegen/cuda/transform_iter.h index f6812bc96ae0f..1334162f232f7 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.h +++ b/torch/csrc/jit/codegen/cuda/transform_iter.h @@ -195,33 +195,15 @@ struct TORCH_CUDA_API ReplayTransformations : public IterVisitor { * these events (generate_record=true) you can then replay them on another * tensor domain. */ -struct TORCH_CUDA_API TransformIter : public IterVisitor { +struct TORCH_CUDA_API TransformBackwardIter : public IterVisitor { protected: - // virtual TensorDomain* replayBackward(Split*, TensorDomain*); - // virtual TensorDomain* replayBackward(Merge*, TensorDomain*); - - // // dispatch - // TensorDomain* replayBackward(Expr*, TensorDomain*); - - // // Iterates td's history starting with td, then origin(td), - // origin(origin(td)) - // // etc. Returns root TensorDomain once it iterates through history. If - // // generate_record=true It will record the history of td in record. - // Record is - // // order operations root->td. - // virtual TensorDomain* runBackward(TensorDomain*); - - // virtual TensorDomain* replay(Split*, TensorDomain*); - // virtual TensorDomain* replay(Merge*, TensorDomain*); - - // // dispatch - // virtual TensorDomain* replay(Expr*, TensorDomain*); - - // // Runs through operations in history and applies them to TD, runs exprs - // from - // // begining to end - // virtual TensorDomain* runReplay(TensorDomain*, const - // std::vector&); + virtual TensorDomain* replayBackward(Split*, TensorDomain*); + virtual TensorDomain* replayBackward(Merge*, TensorDomain*); + + // dispatch + TensorDomain* replayBackward(Expr*, TensorDomain*); + + virtual TensorDomain* runBackward(TensorDomain*); public: // Returns transformation exprs in forward order diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index b083b1e1d5614..35a858a1389fa 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -43,6 +43,53 @@ namespace fuser { // return replayed; // } +// Self replay. +TensorDomain* TransformReplay::fullSelfReplay( + TensorDomain* new_self_root, + TensorDomain* self) { + TORCH_INTERNAL_ASSERT( + new_self_root->nDims() == self->rootDomain().size(), + "Invalid number of IterDomains provided."); + + // Map for replay, should be pretty simple. + std::unordered_map axis_map; + { + size_t i = 0; + for (auto id : self->rootDomain()) { + TORCH_INTERNAL_ASSERT( + new_self_root->axis(i)->start() == id->start(), + "Replay does not support IterDomains that do not start at 0."); + + TORCH_INTERNAL_ASSERT( + new_self_root->axis(i)->parallel_method() == id->parallel_method() && + new_self_root->axis(i)->isReduction() == id->isReduction() && + new_self_root->axis(i)->isRFactorProduct() == + id->isRFactorProduct() && + new_self_root->axis(i)->isBroadcast() == id->isBroadcast(), + "Axes do not match for self replay."); + axis_map[id] = new_self_root->axis(i); + i++; + } + } + + // Replay producer dimensions. 
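fullSelfReplay above rebuilds a domain on fresh root IterDomains by pairing each old root axis with its replacement and replaying the recorded transformations through ReplayTransformations. A toy analogue of that idea, tracking extents only with no IR (the replay helper below is a hypothetical stand-in for the real replay machinery, not part of this patch):

#include <cassert>
#include <vector>

struct Split {
  size_t pos; // axis to split
  int factor; // inner extent
};

// Replay a history of splits on a root domain of extents.
std::vector<int> replay(std::vector<int> dom, const std::vector<Split>& history) {
  for (const auto& s : history) {
    int extent = dom[s.pos];
    dom[s.pos] = extent / s.factor;                // outer
    dom.insert(dom.begin() + s.pos + 1, s.factor); // inner
  }
  return dom;
}

int main() {
  std::vector<Split> history{{1, 32}, {0, 4}};
  auto old_domain = replay({128, 64}, history);
  // A structurally identical "new self root" replays to the same leaf domain.
  auto new_domain = replay({128, 64}, history);
  assert(old_domain == new_domain);
  assert((old_domain == std::vector<int>{32, 4, 2, 32}));
  return 0;
}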
+ ReplayTransformations replay(self->domain(), axis_map); + std::vector new_domain(self->nDims(), nullptr); + + { + size_t i = 0; + for (auto id : self->domain()) { + auto it = replay.getReplay().find(id); + TORCH_INTERNAL_ASSERT( + it != replay.getReplay().end(), + "Error during replay, didn't replay an axis."); + new_domain[i++] = (*it).second; + } + } + + return new TensorDomain(new_self_root->domain(), new_domain); +} + // Replay producer as consumer. TensorDomain* TransformReplay::replayPasC( TensorDomain* producer, @@ -69,7 +116,7 @@ TensorDomain* TransformReplay::replayPasC( } // Figure out all inputs required to generate the compute_at dimensions - std::set consumer_root_ids = IterVisitor::getInputsTo( + std::unordered_set consumer_root_ids = IterVisitor::getInputsTo( std::vector(consumer_ids.begin(), consumer_ids.end())); // Map of consumer_root_ids to related producer_ids @@ -79,7 +126,7 @@ TensorDomain* TransformReplay::replayPasC( std::vector consumer_root = consumer->rootDomain(); std::vector producer_root = producer->rootDomain(); // Track which root axes in producer we send to replay - std::set producer_mapped_roots; + std::unordered_set producer_mapped_roots; // Map related axes from producer and consumer roots. Make sure we go to the // end of both. { @@ -152,7 +199,7 @@ TensorDomain* TransformReplay::replayPasC( std::vector new_IDs; // Add axes in (1) - std::set used_IDs; + std::unordered_set used_IDs; for (auto c_id : consumer_ids) { auto it = replay_producer.getReplay().find(c_id); TORCH_INTERNAL_ASSERT( @@ -341,7 +388,7 @@ TensorDomain* TransformReplay::replayCasP( } // Figure out all inputs required to generate the compute_at dimensions - std::set producer_root_ids = IterVisitor::getInputsTo( + std::unordered_set producer_root_ids = IterVisitor::getInputsTo( std::vector(producer_ids.begin(), producer_ids.end())); // Map of producer_root_ids to related producer_ids @@ -352,7 +399,7 @@ TensorDomain* TransformReplay::replayCasP( std::vector producer_root = producer->rootDomain(); // Track which root axes in consumer we send to replay - std::set consumer_mapped_roots; + std::unordered_set consumer_mapped_roots; // Map related axes from producer and consumer roots. Make sure we go to the // end of both. { @@ -426,7 +473,7 @@ TensorDomain* TransformReplay::replayCasP( std::vector new_IDs; // Add axes in (1) - std::set used_IDs; + std::unordered_set used_IDs; for (auto p_id : producer_ids) { auto it = replay_consumer.getReplay().find(p_id); TORCH_INTERNAL_ASSERT( diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.h b/torch/csrc/jit/codegen/cuda/transform_replay.h index 66a4b7b6ed774..65118b73de636 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.h +++ b/torch/csrc/jit/codegen/cuda/transform_replay.h @@ -122,13 +122,6 @@ struct TensorView; struct TORCH_CUDA_API TransformReplay { private: public: - // Self replay. - static TensorDomain* fullSelfReplay( - TensorDomain* self, - TensorDomain* self_copy) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } - // Replay producer as consumer. static TensorDomain* replayPasC( TensorDomain* producer, @@ -152,6 +145,11 @@ struct TORCH_CUDA_API TransformReplay { TensorView* consumer, TensorView* producer, int producer_compute_at_axis); + + // Self replay. 
+ static TensorDomain* fullSelfReplay( + TensorDomain* new_self_root, + TensorDomain* self); }; } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp index 909b462e9c9fa..425568fdd4d71 100644 --- a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp @@ -7,224 +7,224 @@ namespace torch { namespace jit { namespace fuser { -TensorDomain* TransformRFactor::runReplay( - TensorDomain* orig_td, - std::vector axes) { - int ndims = (int)orig_td->nDims(); - - // Adjust and check provided axes - std::transform(axes.begin(), axes.end(), axes.begin(), [ndims](int i) { - TORCH_CHECK( - i >= -ndims && i < ndims, - "Rfactor replay recieved an axis outside the number of dims in the tensor, acceptable inclusive range is ", - -ndims, - " to ", - ndims - 1); - return i < 0 ? i + ndims : i; - }); - - // remove duplicates, and put into a set for searching - std::set axes_set(axes.begin(), axes.end()); - - // Make a copy of orig_td as we're going to change its history: - bool found_rfactor = false; - std::vector domain_copy; - for (int i{0}; i < ndims; i++) { - IterDomain* orig_axis = orig_td->axis(i); - if (axes_set.find(i) != axes_set.end()) - TORCH_CHECK( - orig_axis->isReduction(), - "Tried to rFactor an axis that is not a reduction."); - - if (orig_axis->isReduction()) { - if (axes_set.find(i) == axes_set.end()) { - domain_copy.push_back(new IterDomain( - orig_axis->start(), - orig_axis->extent(), - orig_axis->parallel_method(), - false, - true)); - found_rfactor = true; - } else { - domain_copy.push_back(new IterDomain( - orig_axis->start(), - orig_axis->extent(), - orig_axis->parallel_method(), - true, - true)); - } - } else { - domain_copy.push_back(orig_td->axis(i)); - } - } - TORCH_CHECK(found_rfactor, "Could not find axis to rfactor out."); - - // TD that we will actually modify, - TensorDomain* out_td = new TensorDomain(domain_copy); - - // Axis map to create history for non-rfactor axes - std::vector axis_map(ndims, -1); - std::vector orig_rfactor_axis_map(ndims, -1); - std::set rfactor_ids; - for (decltype(out_td->nDims()) i{0}; i < out_td->nDims(); i++) - if (!out_td->axis(i)->isRFactorProduct()) { - axis_map[i] = i; - } else { - orig_rfactor_axis_map[i] = i; - } - - // Replay non-rfactor axes - auto running_td = TransformIter::replayBackward( - out_td, TransformIter::getHistory(orig_td), axis_map); - - // running_td has iteration domains on the right, but to find a valid rfactor - // root, we want those to be on the right. If we continued to replay backward - // we likely won't have a valid rfactor root. Lets manually insert a so we - // have a valid rfactor root. 
- - std::vector new2old(running_td->nDims()); - { - int running_pos = 0; - for (decltype(running_td->nDims()) i{0}; i < running_td->nDims(); i++) - if (!running_td->axis(i)->isRFactorProduct()) - new2old[i] = running_pos++; - - for (decltype(running_td->nDims()) i{0}; i < running_td->nDims(); i++) - if (running_td->axis(i)->isRFactorProduct()) - new2old[i] = running_pos++; - } - - // how do we find axes - // Need axis map from rfactor axes in running_td to corresponding axes in - // orig_td orig_rfactor_axis_map goes from orig_td to out_td we want it to - // go from orig_td to running_td - - // Go from IterDomain to its position in running_td - std::unordered_map new_pos; - for (decltype(running_td->nDims()) i{0}; i < running_td->nDims(); i++) { - new_pos[running_td->axis(i)] = i; - } - - for (decltype(out_td->nDims()) i{0}; i < out_td->nDims(); i++) - if (orig_rfactor_axis_map[i] != -1) { - // int orig_td_pos = i; - int out_td_pos = orig_rfactor_axis_map[i]; - TORCH_INTERNAL_ASSERT( - new_pos.find(out_td->axis(out_td_pos)) != new_pos.end(), - "Error aligning axes in rfactor first TD replay."); - int running_td_pos = new_pos[out_td->axis(out_td_pos)]; - orig_rfactor_axis_map[i] = running_td_pos; - } - - TransformIter::replayBackward( - running_td, TransformIter::getHistory(orig_td), orig_rfactor_axis_map); - - return out_td; -} - -TensorDomain* TransformRFactor::runReplay2( - TensorDomain* in_td, - std::vector axes) { - int ndims = (int)in_td->nDims(); - - // Adjust and check provided axes - std::transform(axes.begin(), axes.end(), axes.begin(), [ndims](int i) { - TORCH_CHECK( - i >= -ndims && i < ndims, - "Rfactor replay recieved an axis outside the number of dims in the tensor, acceptable inclusive range is ", - -ndims, - " to ", - ndims - 1); - return i < 0 ? i + ndims : i; - }); - - // remove duplicates, and put into a set for searching - std::set axes_set(axes.begin(), axes.end()); - - bool found_rfactor = false; - // Axes marked as rfactor, these will be removed from this domain - std::vector rfactor_axes(in_td->nDims(), false); - for (int i{0}; i < ndims; i++) { - bool in_set = axes_set.find(i) != axes_set.end(); - IterDomain* orig_axis = in_td->axis(i); - - if (in_set) { - TORCH_CHECK( - orig_axis->isReduction(), - "Tried to rFactor an axis that is not a reduction."); - rfactor_axes[i] = true; - found_rfactor = true; - } - } - - TORCH_CHECK(found_rfactor, "Could not find axis to rfactor out."); - auto root_rfactor_axes = TransformIter::getRootInfluence(in_td, rfactor_axes); - - // Root axes involved in rfactor, these axes should not be replayed, they need - // to be either removed completely, or part of the root domain - auto root_dom = TransformIter::getRoot(in_td); - TORCH_INTERNAL_ASSERT( - root_rfactor_axes.size() == root_dom->nDims(), - "Error backpropagating influence of rfactor."); - - // Forward propagate influence back to the end we want to mark everything - // that's part of the rfactor - rfactor_axes = TransformIter::replayInfluence( - TransformIter::getHistory(in_td), root_rfactor_axes); - - // Axes part of rfactor we need to keep - std::vector rfactor_axes_keep; - - for (int i{0}; i < ndims; i++) { - if (rfactor_axes[i] && axes_set.find(i) == axes_set.end()) { - TORCH_INTERNAL_ASSERT( - in_td->axis(i)->isReduction(), - "Error occured when tracking rfactor axes."); - rfactor_axes_keep.push_back(in_td->axis(i)); - } - } - - int root_ndims = (int)root_dom->nDims(); - std::vector domain_copy; - // These are the axes that are not involved in the rfactor. 
- for (int i{0}; i < root_ndims; i++) { - if (!root_rfactor_axes[i]) { - domain_copy.push_back(root_dom->axis(i)); - } - } - - TORCH_INTERNAL_ASSERT( - domain_copy.size() < root_dom->nDims(), - "Error during rfactor, didn't get any rfactor axes."); - - // Setup axis map before we add back in the rfactor_axes - std::vector replay_axis_map(root_dom->nDims(), -1); - { - decltype(domain_copy.size()) it = 0; - decltype(root_dom->nDims()) ir = 0; - while (it < domain_copy.size() && ir < root_dom->nDims()) { - if (root_rfactor_axes[ir]) { - ir++; - } else { - replay_axis_map[ir++] = it++; - } - } - TORCH_INTERNAL_ASSERT( - it == domain_copy.size(), - "Error during rfactor, missed an unmodified root domain."); - } - - // Push back the rfactor axes we need to keep - domain_copy.insert( - domain_copy.end(), rfactor_axes_keep.begin(), rfactor_axes_keep.end()); - - // TD that we will actually modify - TensorDomain* replay_root_td = new TensorDomain(domain_copy); - auto td = TransformIter::replay( - replay_root_td, TransformIter::getHistory(in_td), replay_axis_map); - - return td; -} +// TensorDomain* TransformRFactor::runReplay( +// TensorDomain* orig_td, +// std::vector axes) { +// int ndims = (int)orig_td->nDims(); + +// // Adjust and check provided axes +// std::transform(axes.begin(), axes.end(), axes.begin(), [ndims](int i) { +// TORCH_CHECK( +// i >= -ndims && i < ndims, +// "Rfactor replay recieved an axis outside the number of dims in the +// tensor, acceptable inclusive range is ", -ndims, " to ", ndims - 1); +// return i < 0 ? i + ndims : i; +// }); + +// // remove duplicates, and put into a set for searching +// std::set axes_set(axes.begin(), axes.end()); + +// // Make a copy of orig_td as we're going to change its history: +// bool found_rfactor = false; +// std::vector domain_copy; +// for (int i{0}; i < ndims; i++) { +// IterDomain* orig_axis = orig_td->axis(i); +// if (axes_set.find(i) != axes_set.end()) +// TORCH_CHECK( +// orig_axis->isReduction(), +// "Tried to rFactor an axis that is not a reduction."); + +// if (orig_axis->isReduction()) { +// if (axes_set.find(i) == axes_set.end()) { +// domain_copy.push_back(new IterDomain( +// orig_axis->start(), +// orig_axis->extent(), +// orig_axis->parallel_method(), +// false, +// true)); +// found_rfactor = true; +// } else { +// domain_copy.push_back(new IterDomain( +// orig_axis->start(), +// orig_axis->extent(), +// orig_axis->parallel_method(), +// true, +// true)); +// } +// } else { +// domain_copy.push_back(orig_td->axis(i)); +// } +// } +// TORCH_CHECK(found_rfactor, "Could not find axis to rfactor out."); + +// // TD that we will actually modify, +// TensorDomain* out_td = new TensorDomain(domain_copy); + +// // Axis map to create history for non-rfactor axes +// std::vector axis_map(ndims, -1); +// std::vector orig_rfactor_axis_map(ndims, -1); +// std::set rfactor_ids; +// for (decltype(out_td->nDims()) i{0}; i < out_td->nDims(); i++) +// if (!out_td->axis(i)->isRFactorProduct()) { +// axis_map[i] = i; +// } else { +// orig_rfactor_axis_map[i] = i; +// } + +// // Replay non-rfactor axes +// auto running_td = TransformIter::replayBackward( +// out_td, TransformIter::getHistory(orig_td), axis_map); + +// // running_td has iteration domains on the right, but to find a valid +// rfactor +// // root, we want those to be on the right. If we continued to replay +// backward +// // we likely won't have a valid rfactor root. Lets manually insert a so we +// // have a valid rfactor root. 
+ +// std::vector new2old(running_td->nDims()); +// { +// int running_pos = 0; +// for (decltype(running_td->nDims()) i{0}; i < running_td->nDims(); i++) +// if (!running_td->axis(i)->isRFactorProduct()) +// new2old[i] = running_pos++; + +// for (decltype(running_td->nDims()) i{0}; i < running_td->nDims(); i++) +// if (running_td->axis(i)->isRFactorProduct()) +// new2old[i] = running_pos++; +// } + +// // how do we find axes +// // Need axis map from rfactor axes in running_td to corresponding axes in +// // orig_td orig_rfactor_axis_map goes from orig_td to out_td we want it to +// // go from orig_td to running_td + +// // Go from IterDomain to its position in running_td +// std::unordered_map new_pos; +// for (decltype(running_td->nDims()) i{0}; i < running_td->nDims(); i++) { +// new_pos[running_td->axis(i)] = i; +// } + +// for (decltype(out_td->nDims()) i{0}; i < out_td->nDims(); i++) +// if (orig_rfactor_axis_map[i] != -1) { +// // int orig_td_pos = i; +// int out_td_pos = orig_rfactor_axis_map[i]; +// TORCH_INTERNAL_ASSERT( +// new_pos.find(out_td->axis(out_td_pos)) != new_pos.end(), +// "Error aligning axes in rfactor first TD replay."); +// int running_td_pos = new_pos[out_td->axis(out_td_pos)]; +// orig_rfactor_axis_map[i] = running_td_pos; +// } + +// TransformIter::replayBackward( +// running_td, TransformIter::getHistory(orig_td), orig_rfactor_axis_map); + +// return out_td; +// } + +// TensorDomain* TransformRFactor::runReplay2( +// TensorDomain* in_td, +// std::vector axes) { +// int ndims = (int)in_td->nDims(); + +// // Adjust and check provided axes +// std::transform(axes.begin(), axes.end(), axes.begin(), [ndims](int i) { +// TORCH_CHECK( +// i >= -ndims && i < ndims, +// "Rfactor replay recieved an axis outside the number of dims in the +// tensor, acceptable inclusive range is ", -ndims, " to ", ndims - 1); +// return i < 0 ? 
i + ndims : i; +// }); + +// // remove duplicates, and put into a set for searching +// std::set axes_set(axes.begin(), axes.end()); + +// bool found_rfactor = false; +// // Axes marked as rfactor, these will be removed from this domain +// std::vector rfactor_axes(in_td->nDims(), false); +// for (int i{0}; i < ndims; i++) { +// bool in_set = axes_set.find(i) != axes_set.end(); +// IterDomain* orig_axis = in_td->axis(i); + +// if (in_set) { +// TORCH_CHECK( +// orig_axis->isReduction(), +// "Tried to rFactor an axis that is not a reduction."); +// rfactor_axes[i] = true; +// found_rfactor = true; +// } +// } + +// TORCH_CHECK(found_rfactor, "Could not find axis to rfactor out."); +// auto root_rfactor_axes = TransformIter::getRootInfluence(in_td, +// rfactor_axes); + +// // Root axes involved in rfactor, these axes should not be replayed, they +// need +// // to be either removed completely, or part of the root domain +// auto root_dom = TransformIter::getRoot(in_td); +// TORCH_INTERNAL_ASSERT( +// root_rfactor_axes.size() == root_dom->nDims(), +// "Error backpropagating influence of rfactor."); + +// // Forward propagate influence back to the end we want to mark everything +// // that's part of the rfactor +// rfactor_axes = TransformIter::replayInfluence( +// TransformIter::getHistory(in_td), root_rfactor_axes); + +// // Axes part of rfactor we need to keep +// std::vector rfactor_axes_keep; + +// for (int i{0}; i < ndims; i++) { +// if (rfactor_axes[i] && axes_set.find(i) == axes_set.end()) { +// TORCH_INTERNAL_ASSERT( +// in_td->axis(i)->isReduction(), +// "Error occured when tracking rfactor axes."); +// rfactor_axes_keep.push_back(in_td->axis(i)); +// } +// } + +// int root_ndims = (int)root_dom->nDims(); +// std::vector domain_copy; +// // These are the axes that are not involved in the rfactor. 
+// for (int i{0}; i < root_ndims; i++) { +// if (!root_rfactor_axes[i]) { +// domain_copy.push_back(root_dom->axis(i)); +// } +// } + +// TORCH_INTERNAL_ASSERT( +// domain_copy.size() < root_dom->nDims(), +// "Error during rfactor, didn't get any rfactor axes."); + +// // Setup axis map before we add back in the rfactor_axes +// std::vector replay_axis_map(root_dom->nDims(), -1); +// { +// decltype(domain_copy.size()) it = 0; +// decltype(root_dom->nDims()) ir = 0; +// while (it < domain_copy.size() && ir < root_dom->nDims()) { +// if (root_rfactor_axes[ir]) { +// ir++; +// } else { +// replay_axis_map[ir++] = it++; +// } +// } +// TORCH_INTERNAL_ASSERT( +// it == domain_copy.size(), +// "Error during rfactor, missed an unmodified root domain."); +// } + +// // Push back the rfactor axes we need to keep +// domain_copy.insert( +// domain_copy.end(), rfactor_axes_keep.begin(), rfactor_axes_keep.end()); + +// // TD that we will actually modify +// TensorDomain* replay_root_td = new TensorDomain(domain_copy); +// auto td = TransformIter::replay( +// replay_root_td, TransformIter::getHistory(in_td), replay_axis_map); + +// return td; +// } } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.h b/torch/csrc/jit/codegen/cuda/transform_rfactor.h index b0482129e94b5..64e3e6b8b0152 100644 --- a/torch/csrc/jit/codegen/cuda/transform_rfactor.h +++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.h @@ -18,8 +18,12 @@ struct TORCH_CUDA_API TransformRFactor { public: // Create a copy of td, change its history by presrving axes so they appear in // the root domain - static TensorDomain* runReplay(TensorDomain*, std::vector axes); - static TensorDomain* runReplay2(TensorDomain*, std::vector axes); + static TensorDomain* runReplay(TensorDomain*, std::vector axes) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); + } + static TensorDomain* runReplay2(TensorDomain*, std::vector axes) { + TORCH_INTERNAL_ASSERT(false, "Not implemented yet."); + } }; } // namespace fuser From d472f4315568c2942e73755be2bb0ddfda51ffc5 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Wed, 3 Jun 2020 11:39:35 -0400 Subject: [PATCH 8/8] Get refactor back to original state with passing tests. 
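Several of the tests updated below exercise rFactor scheduling. As a schematic of its semantics, mirroring the inline comments in testGPU_FusionReduction further down (this fragment assumes a surrounding Fusion and is not a standalone program): the axes passed to rFactor become the reduction axes of the new intermediate tensor, and the remaining reduction axes stay on the original output, turning one reduction into two stages.

TensorView* tv0 = makeDummyTensor(2); // tv0[I0, I1]
fusion.addInput(tv0);
TensorView* tv1 = static_cast<TensorView*>(sum(tv0, {1})); // tv1[I0, R1]
tv1->split(1, 128);
// tv1[I0, R1o, R1i{128}]
TensorView* tv2 = tv1->rFactor({1});
// tv2[I0, R1o, iR1i{128}] = tv0  (partial sums, reducing R1o)
// tv1[I0, R1i{128}]       = tv2  (final reduction over R1i)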
--- test/cpp/jit/test_gpu.cpp | 488 +++++++------- torch/csrc/jit/codegen/cuda/fusion.cpp | 4 +- torch/csrc/jit/codegen/cuda/index_compute.cpp | 20 +- torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp | 14 +- torch/csrc/jit/codegen/cuda/ir_base_nodes.h | 10 +- .../jit/codegen/cuda/ir_interface_nodes.h | 3 - .../csrc/jit/codegen/cuda/ir_internal_nodes.h | 26 +- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 188 +++--- torch/csrc/jit/codegen/cuda/iter_visitor.h | 1 + torch/csrc/jit/codegen/cuda/lower2device.cpp | 1 - torch/csrc/jit/codegen/cuda/lower_loops.cpp | 1 - .../jit/codegen/cuda/predicate_compute.cpp | 9 +- torch/csrc/jit/codegen/cuda/tensor_view.cpp | 13 +- .../csrc/jit/codegen/cuda/transform_iter.cpp | 194 +++++- torch/csrc/jit/codegen/cuda/transform_iter.h | 261 ++------ .../jit/codegen/cuda/transform_replay.cpp | 575 +++++++---------- .../codegen/cuda/transform_replay_rfactor.h | 229 +++++++ .../jit/codegen/cuda/transform_rfactor.cpp | 601 +++++++++++------- .../csrc/jit/codegen/cuda/transform_rfactor.h | 9 +- 19 files changed, 1495 insertions(+), 1152 deletions(-) create mode 100644 torch/csrc/jit/codegen/cuda/transform_replay_rfactor.h diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index e9013284fa06f..a52e90fa29af0 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -413,8 +413,6 @@ void testGPU_FusionTVReorder() { Fusion fusion; FusionGuard fg(&fusion); - TensorView* dummyTensor = makeDummyTensor(3); - std::unordered_map shift_right{{-1, 0}}; std::unordered_map shift_left{{0, -1}}; @@ -422,28 +420,40 @@ void testGPU_FusionTVReorder() { std::unordered_map shift_left_2{{0, -1}, {1, 0}, {2, 1}}; std::unordered_map swap{{0, 2}, {2, 0}}; - TensorView* ref = dummyTensor->clone(); - TensorView* tv = dummyTensor->clone(); - TensorView* s_left1 = tv->reorder(shift_left); - for (int i = 0; i < (int)tv->nDims(); i++) - TORCH_CHECK(ref->axis(i)->sameAs(s_left1->axis(i - 1))); + auto tv = makeDummyTensor(3); + std::vector ref; + ref = std::vector( + tv->domain()->domain().begin(), tv->domain()->domain().end()); - tv = dummyTensor->clone(); - TensorView* s_left2 = tv->reorder(shift_left); + tv->reorder(shift_left); for (int i = 0; i < (int)tv->nDims(); i++) - TORCH_CHECK(ref->axis(i)->sameAs(s_left2->axis(i - 1))); + TORCH_CHECK(ref[i]->sameAs(tv->axis(i - 1))); - tv = dummyTensor->clone(); - TensorView* s_right = tv->reorder(shift_right); - for (int i = 0; i < (int)tv->nDims(); i++) - TORCH_CHECK(ref->axis(i - 1)->sameAs(s_right->axis(i))); + tv = makeDummyTensor(3); + ref = std::vector( + tv->domain()->domain().begin(), tv->domain()->domain().end()); - tv = dummyTensor->clone(); - TensorView* rswap = tv->reorder(swap); - TORCH_CHECK(ref->axis(0)->sameAs(rswap->axis(2))); - TORCH_CHECK(ref->axis(2)->sameAs(rswap->axis(0))); - TORCH_CHECK(ref->axis(1)->sameAs(rswap->axis(1))); + tv->reorder(shift_left); + for (int i = 0; i < (int)tv->nDims(); i++) + TORCH_CHECK(ref[i]->sameAs(tv->axis(i - 1))); + + tv = makeDummyTensor(3); + ref = std::vector( + tv->domain()->domain().begin(), tv->domain()->domain().end()); + + tv->reorder(shift_right); + TORCH_CHECK(ref[ref.size() - 1]->sameAs(tv->axis(0))); + for (int i = 1; i < (int)tv->nDims(); i++) + TORCH_CHECK(ref[i - 1]->sameAs(tv->axis(i))); + + tv = makeDummyTensor(3); + ref = std::vector( + tv->domain()->domain().begin(), tv->domain()->domain().end()); + tv->reorder(swap); + TORCH_CHECK(ref[0]->sameAs(tv->axis(2))); + TORCH_CHECK(ref[2]->sameAs(tv->axis(0))); + TORCH_CHECK(ref[1]->sameAs(tv->axis(1))); } void 
testGPU_FusionEquality() { @@ -513,6 +523,76 @@ void testGPU_FusionReplaceAll() { TORCH_CHECK(static_cast(bop->lhs())->sameAs(new Float{2.f})); } +void testGPU_FusionDependency() { + Fusion fusion; + FusionGuard fg(&fusion); + + Float* f0 = new Float(0.f); + Float* f1 = new Float(1.f); + auto f2 = add(f0, f1); + + auto f3 = add(f2, f2); + + Float* f4 = new Float(4.f); + Float* f5 = new Float(5.f); + auto f6 = add(f4, f5); + + Float* f7 = new Float(7.f); + Float* f8 = new Float(8.f); + auto f9 = add(f7, f8); + + auto f10 = add(f6, f9); + + auto f11 = add(f3, f10); + + TORCH_CHECK(DependencyCheck::isDependencyOf(f0, f11)); + TORCH_CHECK(DependencyCheck::isDependencyOf(f1, f11)); + TORCH_CHECK(DependencyCheck::isDependencyOf(f2, f11)); + TORCH_CHECK(DependencyCheck::isDependencyOf(f3, f11)); + TORCH_CHECK(DependencyCheck::isDependencyOf(f6, f11)); + TORCH_CHECK(DependencyCheck::isDependencyOf(f9, f11)); + TORCH_CHECK(DependencyCheck::isDependencyOf(f0, f2)); + TORCH_CHECK(DependencyCheck::isDependencyOf(f2, f3)); + TORCH_CHECK(DependencyCheck::isDependencyOf(f4, f6)); + TORCH_CHECK(DependencyCheck::isDependencyOf(f8, f10)); + + TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f0)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f1)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f2)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f3)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f4)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f5)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(f2, f0)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(f3, f2)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(f6, f4)); + TORCH_CHECK(!DependencyCheck::isDependencyOf(f10, f8)); + + auto dep_chain = DependencyCheck::getSingleDependencyChain(f0, f11); + TORCH_CHECK(dep_chain.back() == f11); + dep_chain.pop_back(); + TORCH_CHECK(dep_chain.back() == f3); + dep_chain.pop_back(); + TORCH_CHECK(dep_chain.back() == f2); + dep_chain.pop_back(); + + dep_chain = DependencyCheck::getSingleDependencyChain(f6, f11); + TORCH_CHECK(dep_chain.back() == f11); + dep_chain.pop_back(); + TORCH_CHECK(dep_chain.back() == f10); + dep_chain.pop_back(); + + dep_chain = DependencyCheck::getSingleDependencyChain(f4, f11); + TORCH_CHECK(dep_chain.back() == f11); + dep_chain.pop_back(); + TORCH_CHECK(dep_chain.back() == f10); + dep_chain.pop_back(); + TORCH_CHECK(dep_chain.back() == f6); + dep_chain.pop_back(); + + dep_chain = DependencyCheck::getSingleDependencyChain(f11, f2); + TORCH_CHECK(dep_chain.empty()); +} + void testGPU_FusionParser() { auto g = std::make_shared(); const auto graph0_string = R"IR( @@ -549,32 +629,32 @@ void testGPU_FusionParser() { ref << "__global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3){\n" << " float T2[4];\n" << " if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n" - << " for(size_t i108 = 0; i108 < 4; ++i108 ) {\n" - << " T2[ i108 ]\n" - << " = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i108 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]\n" - << " * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i108 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];\n" + << " for(size_t i60 = 0; i60 < 4; ++i60 ) {\n" + << " T2[ i60 ]\n" + << " = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i60 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]\n" + << " * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i60 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];\n" << " }\n" << " } else { \n" - << " for(size_t i108 = 0; i108 < 4; ++i108 ) {\n" - << " if ( ( ( ( ( ( blockIdx.x * 4 ) + i108 ) * 128 ) + 
threadIdx.x ) < T3.size[0] ) ) { \n" - << " T2[ i108 ]\n" - << " = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i108 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]\n" - << " * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i108 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];\n" + << " for(size_t i60 = 0; i60 < 4; ++i60 ) {\n" + << " if ( ( ( ( ( ( blockIdx.x * 4 ) + i60 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n" + << " T2[ i60 ]\n" + << " = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i60 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]\n" + << " * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i60 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];\n" << " }\n" << " }\n" << " }\n" << " if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n" - << " for(size_t i109 = 0; i109 < 4; ++i109 ) {\n" - << " T3[ ( ( ( ( ( blockIdx.x * 4 ) + i109 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]\n" - << " = T2[ i109 ]\n" - << " * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i109 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];\n" + << " for(size_t i61 = 0; i61 < 4; ++i61 ) {\n" + << " T3[ ( ( ( ( ( blockIdx.x * 4 ) + i61 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]\n" + << " = T2[ i61 ]\n" + << " * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i61 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];\n" << " }\n" << " } else { \n" - << " for(size_t i109 = 0; i109 < 4; ++i109 ) {\n" - << " if ( ( ( ( ( ( blockIdx.x * 4 ) + i109 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n" - << " T3[ ( ( ( ( ( blockIdx.x * 4 ) + i109 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]\n" - << " = T2[ i109 ]\n" - << " * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i109 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];\n" + << " for(size_t i61 = 0; i61 < 4; ++i61 ) {\n" + << " if ( ( ( ( ( ( blockIdx.x * 4 ) + i61 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { \n" + << " T3[ ( ( ( ( ( blockIdx.x * 4 ) + i61 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]\n" + << " = T2[ i61 ]\n" + << " * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i61 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];\n" << " }\n" << " }\n" << " }\n" @@ -594,74 +674,39 @@ void testGPU_FusionParser() { } } -void testGPU_FusionDependency() { +void testGPU_FusionForLoop() { Fusion fusion; FusionGuard fg(&fusion); - Float* f0 = new Float(0.f); - Float* f1 = new Float(1.f); - auto f2 = add(f0, f1); - - auto f3 = add(f2, f2); - - Float* f4 = new Float(4.f); - Float* f5 = new Float(5.f); - auto f6 = add(f4, f5); - - Float* f7 = new Float(7.f); - Float* f8 = new Float(8.f); - auto f9 = add(f7, f8); - - auto f10 = add(f6, f9); - - auto f11 = add(f3, f10); - - TORCH_CHECK(DependencyCheck::isDependencyOf(f0, f11)); - TORCH_CHECK(DependencyCheck::isDependencyOf(f1, f11)); - TORCH_CHECK(DependencyCheck::isDependencyOf(f2, f11)); - TORCH_CHECK(DependencyCheck::isDependencyOf(f3, f11)); - TORCH_CHECK(DependencyCheck::isDependencyOf(f6, f11)); - TORCH_CHECK(DependencyCheck::isDependencyOf(f9, f11)); - TORCH_CHECK(DependencyCheck::isDependencyOf(f0, f2)); - TORCH_CHECK(DependencyCheck::isDependencyOf(f2, f3)); - TORCH_CHECK(DependencyCheck::isDependencyOf(f4, f6)); - TORCH_CHECK(DependencyCheck::isDependencyOf(f8, f10)); + const auto TV0 = new TensorView( + new TensorDomain({new IterDomain(new Int(0), new Int(16))}), + DataType::Float); + const auto TV1 = new TensorView( + new TensorDomain({new IterDomain(new Int(0), new Int(16))}), + DataType::Float); - TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f0)); - TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f1)); - TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f2)); - 
TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f3));
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f4));
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(f11, f5));
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(f2, f0));
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(f3, f2));
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(f6, f4));
-  TORCH_CHECK(!DependencyCheck::isDependencyOf(f10, f8));
+  fusion.addInput(TV0);
+  fusion.addInput(TV1);
-  auto dep_chain = DependencyCheck::getSingleDependencyChain(f0, f11);
-  TORCH_CHECK(dep_chain.back() == f11);
-  dep_chain.pop_back();
-  TORCH_CHECK(dep_chain.back() == f3);
-  dep_chain.pop_back();
-  TORCH_CHECK(dep_chain.back() == f2);
-  dep_chain.pop_back();
+  auto ID0 = new IterDomain(new Int(0), new Int(8));
-  dep_chain = DependencyCheck::getSingleDependencyChain(f6, f11);
-  TORCH_CHECK(dep_chain.back() == f11);
-  dep_chain.pop_back();
-  TORCH_CHECK(dep_chain.back() == f10);
-  dep_chain.pop_back();
+  TensorView* TV2 = static_cast(add(TV0, TV1));
+  BinaryOp* op = static_cast(TV2->getOrigin());
+  fusion.addOutput(TV2);
-  dep_chain = DependencyCheck::getSingleDependencyChain(f4, f11);
-  TORCH_CHECK(dep_chain.back() == f11);
-  dep_chain.pop_back();
-  TORCH_CHECK(dep_chain.back() == f10);
-  dep_chain.pop_back();
-  TORCH_CHECK(dep_chain.back() == f6);
-  dep_chain.pop_back();
+  ForLoop* fl = new ForLoop(new Int(), ID0, {op});
+  std::stringstream result;
+  std::stringstream ref;
+  result << fl;
+  ref << "for(size_t i3{0}; i3 < iS{8}; ++i3 ) {\nT2[ iS{16} ] = T0[ iS{16} ] + T1[ iS{16} ]\n}";
-  dep_chain = DependencyCheck::getSingleDependencyChain(f11, f2);
-  TORCH_CHECK(dep_chain.empty());
+  // Error on mismatch, so the comparison must be != 0, not == 0.
+  if (result.str().compare(ref.str()) != 0) {
+    std::stringstream err_msg;
+    err_msg << "ForLoop printing has changed or something has gone wrong. 
" + << result.str() << "\n does not match reference: " << ref.str() + << std::endl; + TORCH_CHECK(false, err_msg.str()); + } } void testGPU_FusionCodeGen() { @@ -1308,8 +1353,8 @@ void testGPU_FusionLoopUnroll() { FusionGuard fg(&fusion); // Set up your input tensor views - TensorView* tv0 = makeDummyTensor(1); - TensorView* tv1 = makeDummyTensor(1); + TensorView* tv0 = makeDummyTensor(3); + TensorView* tv1 = makeDummyTensor(3); // Register your inputs fusion.addInput(tv0); @@ -1325,6 +1370,9 @@ void testGPU_FusionLoopUnroll() { int block_size = 16; + tv3->merge(0, 1); + tv3->merge(0, 1); + tv3->split(0, block_size); tv3->split(0, 4); @@ -1340,7 +1388,10 @@ void testGPU_FusionLoopUnroll() { tv3->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(0)->parallelize(ParallelType::BIDx); - int inp_size = 129; + int inp_size = 129 * 13 * 3; + + // GPULower lower(&fusion); + // lower.printKernel(std::cout); torch::jit::fuser::cuda::CudaKernel prog; prog.device_ = 0; @@ -1349,52 +1400,15 @@ void testGPU_FusionLoopUnroll() { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor input1 = at::ones({inp_size}, options); - at::Tensor input2 = at::ones_like(input1); + at::Tensor input0 = at::rand({129, 13, 3}, options); + at::Tensor input1 = at::rand({129, 13, 3}, options); at::Tensor output = at::empty_like(input1); torch::jit::fuser::cuda::compileKernel(fusion, &prog); - torch::jit::fuser::cuda::runTestKernel(&prog, {input1, input2}, {output}); + torch::jit::fuser::cuda::runTestKernel(&prog, {input0, input1}, {output}); - at::Tensor check = at::full({inp_size}, 4, options); - - TORCH_CHECK(output.equal(check)); -} - -void testGPU_FusionForLoop() { - Fusion fusion; - FusionGuard fg(&fusion); - - const auto TV0 = new TensorView( - new TensorDomain({new IterDomain(new Int(0), new Int(16))}), - DataType::Float); - const auto TV1 = new TensorView( - new TensorDomain({new IterDomain(new Int(0), new Int(16))}), - DataType::Float); - - fusion.addInput(TV0); - fusion.addInput(TV1); - - auto ID0 = new IterDomain(new Int(0), new Int(8)); - - TensorView* TV2 = static_cast(add(TV0, TV1)); - BinaryOp* op = static_cast(TV2->getOrigin()); - fusion.addOutput(TV2); - - ForLoop* fl = new ForLoop(new Int(), ID0, {op}); - std::stringstream result; - std::stringstream ref; - result << fl; - ref << "for(size_t i3{0}; i3 < iS{8}; ++i3 ) {\nT2[ iS{16} ] = T0[ iS{16} ] + T1[ iS{16} ]\n}"; - - if (result.str().compare(ref.str()) == 0) { - std::stringstream err_msg; - err_msg << "ForLoop printing has changed or something has gone wrong. 
" - << result.str() << "\n does not match reference: " << ref.str() - << std::endl; - TORCH_CHECK(false, err_msg.str()); - } + TORCH_CHECK(output.equal(input0.add(input1.add(2.0)))); } /* @@ -1903,94 +1917,92 @@ void testGPU_FusionCastOps() { // We want split/merge/reorder all tested both on and off rfactor domains, also // want compute at into the rfactor domain, and into its consumer void testGPU_FusionRFactorReplay() { - // Fusion fusion; - // FusionGuard fg(&fusion); - - // // Set up your input tensor views - // TensorView* tv0 = makeDummyTensor(2); - - // // Register your inputs - // fusion.addInput(tv0); - - // // Do math with it, it returns a `Val*` but can be static_casted back to - // // TensorView - // TensorView* tv1 = static_cast(sum(tv0, {1})); - // // tv1[I0, R1] - // tv1->split(0, 32); - // // tv1[I0o, I0i{32}, R1] - // tv1->split(0, 16); - // // tv1[I0oo, I0oi{16}, I0i{32}, R1] - // tv1->split(-1, 8); - // // tv1[I0oo, I0oi{16}, I0i{32}, R1o, R1i{8}] - // tv1->split(-2, 4); - // // tv1[I0oo, I0oi{16}, I0i{32}, R1oo, R1oi{4}, R1i{8}] - - // tv1->reorder({{0, -2}, {2, -1}, {-3, 0}, {-1, 1}}); - // // tv1[R1oo, R1i{8}, I0oi{16}, R1oi{4}, I0oo, I0i{32}] - - // tv1->merge(0); - // tv1->merge(-2); - - // // tv1[R1oo*R1i{8}, I0oi{16}, R1oi{4}, I0oo*I0i{32}] - // TensorDomain* new_domain = TransformRFactor::runReplay(tv1->domain(), {0}); - // TensorDomain* new_domain2 = TransformRFactor::runReplay2(tv1->domain(), - // {0}); - // // new_domain[R(R1oo*R1i{8})rf, I0oi{16}, ir1oi{4}rf, I0oo*I0i{32}] - // // new_domain2[ I0oi{16}, , I0oo*I0i{32}, - // R1oi{4}] - - // // Move rfactor axis to end, keep iter rfactor axis - // auto reordered_new_domain = new_domain->reorder({{0, -1}, {2, 2}}); - // // reordered_new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, - // R(R1oo*R1i{8})rf] + Fusion fusion; + FusionGuard fg(&fusion); + + // Set up your input tensor views + TensorView* tv0 = makeDummyTensor(2); - // TensorDomain* casp = - // TransformReplay::replayCasP(new_domain2, reordered_new_domain, 2); + // Register your inputs + fusion.addInput(tv0); + + // Do math with it, it returns a `Val*` but can be static_casted back to + // TensorView + TensorView* tv1 = static_cast(sum(tv0, {1})); + // tv1[I0, R1] + tv1->split(0, 32); + // tv1[I0o, I0i{32}, R1] + tv1->split(0, 16); + // tv1[I0oo, I0oi{16}, I0i{32}, R1] + tv1->split(-1, 8); + // tv1[I0oo, I0oi{16}, I0i{32}, R1o, R1i{8}] + tv1->split(-2, 4); + // tv1[I0oo, I0oi{16}, I0i{32}, R1oo, R1oi{4}, R1i{8}] + tv1->reorder({{0, -2}, {2, -1}, {-3, 0}, {-1, 1}}); + // tv1[R1oo, R1i{8}, I0oi{16}, R1oi{4}, I0oo, I0i{32}] + + tv1->merge(0); + tv1->merge(-2); + + // tv1[R1oo*R1i{8}, I0oi{16}, R1oi{4}, I0oo*I0i{32}] + TensorDomain* new_domain = TransformRFactor::runReplay(tv1->domain(), {0}); + // new_domain[r(R1oo*R1i{8})rf, I0oi{16}, ir1oi{4}rf, I0oo*I0i{32}] + + TensorDomain* new_domain2 = TransformRFactor::runReplay2(tv1->domain(), {0}); + // new_domain2[ I0oi{16}, , I0oo*I0i{32}, R1oi{4}] + + // Move rfactor axis to end, keep iter rfactor axis + new_domain->reorder({{0, -1}, {2, 2}}); + + // Replay casp, replay new_domain2 as new_domain + // reordered_new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, R(R1oo*R1i{8})rf] + TensorDomain* casp = TransformReplay::replayCasP(new_domain2, new_domain, 2); // // new_domain[I0oi{16}, I0oo*I0i{32}, ir1oi{4}rf, R(R1oo*R1i{8})rf] - // // casp[I0oi{16}, I0oo*I0i{32}, R1oi{4}] + // // casp[I0oi{16}, I0oo*I0i{32}, R1oi{4}] - // casp = casp->split(1, 2); - // // casp [I0oi{16}, (I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4}] + 
casp->split(1, 2); + // // casp [I0oi{16}, (I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4} ] // // new_domain[I0oi{16}, I0oo*I0i{32} , ir1oi{4}rf, - // // R(R1oo*R1i{8})rf] - // TensorDomain* pasc = TransformReplay::replayPasC(new_domain, casp, 2); + // R(R1oo*R1i{8})rf] + + TensorDomain* pasc = TransformReplay::replayPasC(new_domain, casp, 2); // // pasc [I0oi{16}, (I0oo*I0i{32})o, I(Ioo*I0i)i{2}, ir1oi{4}rf, - // // R(R1oo*R1i{8})rf] - - // TORCH_CHECK( - // new_domain->nDims() - 1 == new_domain2->nDims(), - // casp->nDims() == new_domain2->nDims() + 1, - // pasc->nDims() == new_domain->nDims() + 1, - // "Error in rfactor, number of dimensions is not correct."); - - // TORCH_CHECK( - // !casp->sameAs(new_domain2) && !pasc->sameAs(new_domain) && - // !new_domain->sameAs(new_domain2) && - // !tv1->domain()->sameAs(new_domain) && - // !tv1->domain()->sameAs(new_domain2), - // "Error in rfactor, number of dimensions is not correct."); - - // auto dom = new_domain->rootDomain()->domain(); - // TORCH_CHECK( - // !new_domain->rootDomain()->axis(0)->isReduction() && - // std::any_of( - // dom.begin(), - // dom.end(), - // [](IterDomain* id) { return id->isReduction(); }) && - // std::any_of( - // dom.begin(), - // dom.end(), - // [](IterDomain* id) { return id->isRFactorProduct(); }), - // "Error in rFactor, there seems to be something wrong in root domain."); - - // auto dom2 = new_domain2->rootDomain()->domain(); - // TORCH_CHECK( - // !new_domain2->rootDomain()->axis(0)->isReduction() && - // std::any_of( - // dom2.begin(), - // dom2.end(), - // [](IterDomain* id) { return id->isReduction(); }), - // "Error in rFactor, there seems to be something wrong in root domain."); + // R(R1oo*R1i{8})rf] + + TORCH_CHECK( + new_domain->nDims() - 1 == new_domain2->nDims(), + casp->nDims() == new_domain2->nDims() + 1, + pasc->nDims() == new_domain->nDims() + 1, + "Error in rfactor, number of dimensions is not correct."); + + TORCH_CHECK( + !casp->sameAs(new_domain2) && !pasc->sameAs(new_domain) && + !new_domain->sameAs(new_domain2) && + !tv1->domain()->sameAs(new_domain) && + !tv1->domain()->sameAs(new_domain2), + "Error in rfactor, number of dimensions is not correct."); + + auto dom = new_domain->rootDomain(); + TORCH_CHECK( + !dom[0]->isReduction() && + std::any_of( + dom.begin(), + dom.end(), + [](IterDomain* id) { return id->isReduction(); }) && + std::any_of( + dom.begin(), + dom.end(), + [](IterDomain* id) { return id->isRFactorProduct(); }), + "Error in rFactor, there seems to be something wrong in root domain."); + + auto dom2 = new_domain2->rootDomain(); + TORCH_CHECK( + !dom2[0]->isReduction() && + std::any_of( + dom2.begin(), + dom2.end(), + [](IterDomain* id) { return id->isReduction(); }), + "Error in rFactor, there seems to be something wrong in root domain."); } // Start off simple, block on the outer dim @@ -2033,14 +2045,14 @@ void testGPU_FusionReduction() { tv0->computeAt(tv1, 1); tv2->axis(2)->parallelize(ParallelType::Unroll); - tv3->axis(0)->parallelize(ParallelType::BIDx); + tv1->axis(0)->parallelize(ParallelType::BIDx); tv1->axis(-1)->parallelize(ParallelType::TIDx); tv2->axis(-1)->parallelize(ParallelType::TIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); // for(auto expr : fusion.exprs(true)) - // std::cout<split(1, tidx); - // tv3[I0, R1o, R1i{tdix}] = tv2[I0, I1] + // tv3[I0, R1o, R1i{tidx}] = tv2[I0, I1] TensorView* tv6 = tv3->rFactor({-2}); - // tv6[I0, R1o, iR1i{tdix}] = tv2[I0, I1] - // tv3[I0, R1i{tdix}] = tv3[I0, I1] - + // tv6[I0, R1o, iR1i{tidx}] = tv2[I0, I1] + // 
tv3[I0, R1i{tidx}] = tv3[I0, I1] tv2->computeAt(tv6, 2); // // Compute at inline with tv5 (only 1D) @@ -2251,9 +2264,6 @@ void testGPU_FusionReduction3() { tv3->axis(-1)->parallelize(ParallelType::TIDx); tv6->axis(-1)->parallelize(ParallelType::TIDx); - // GPULower lower(&fusion); - // lower.printKernel(std::cout); - int numel_x = 1025; int numel_y = 129; int bidx = numel_x; @@ -2320,12 +2330,8 @@ void testGPU_FusionSimpleBCast() { TensorView* tv4 = static_cast(mul(tv3, tv2)); fusion.addOutput(tv4); - // tv0->computeAt(tv4, -1); - // tv1->computeAt(tv4, -1); - - // for (auto expr : fusion.exprs(true)) { - // std::cout << expr << std::endl; - // } + tv0->computeAt(tv4, -1); + tv1->computeAt(tv4, -1); // GPULower lower(&fusion); // lower.printKernel(std::cout); diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp index 142f83a0b8e2a..eef671f618b5e 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion.cpp @@ -249,7 +249,7 @@ StmtNameType Fusion::registerExpr(Expr* expr) { } for (Val* input : expr->inputs()) { - registerVal(input); + assertInFusion(input, "Input to expr is invalid, "); if (uses_.find(input) == uses_.end()) { uses_[input] = {expr}; } else { @@ -258,7 +258,7 @@ StmtNameType Fusion::registerExpr(Expr* expr) { } for (Val* output : expr->outputs()) { - registerVal(output); + assertInFusion(output, "Output to expr is invalid, "); auto it = origin_.find(output); if (it != origin_.end()) { removeExpr(it->second); // will also remove origin entry diff --git a/torch/csrc/jit/codegen/cuda/index_compute.cpp b/torch/csrc/jit/codegen/cuda/index_compute.cpp index 098dce541eec6..ce08528c3d07c 100644 --- a/torch/csrc/jit/codegen/cuda/index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/index_compute.cpp @@ -71,10 +71,6 @@ IndexCompute::IndexCompute(TensorDomain* td, const std::vector& indices) { td->nDims() == indices.size(), "For IndexCompute the number of axes should match the number of dimensions in the TensorDomain."); - TORCH_INTERNAL_ASSERT( - indices.size() == td->nDims(), - "Attempted to modify indices for IndexCompute, but didn't work."); - TORCH_INTERNAL_ASSERT(!td->hasRFactor(), "Not implemented yet."); { @@ -272,15 +268,17 @@ TensorIndex* Index::getGlobalConsumerIndex( auto root_dom = consumer->getRootDomain(); TORCH_INTERNAL_ASSERT( - computed_inds.size() == root_dom.size(), + computed_inds.size() == TensorDomain::noReductions(root_dom).size() || + computed_inds.size() == root_dom.size(), "Dimensionality error in code generator while computing indexing."); - for (decltype(root_dom.size()) i{0}; i < root_dom.size(); i++) { - // Do this backwards so erase offset will be right - auto axis = root_dom.size() - i - 1; - if (root_dom[axis]->isReduction() || root_dom[i]->isBroadcast()) - computed_inds.erase(computed_inds.begin() + axis); - } + if (computed_inds.size() == root_dom.size()) + for (decltype(root_dom.size()) i{0}; i < root_dom.size(); i++) { + // Do this backwards so erase offset will be right + auto axis = root_dom.size() - i - 1; + if (root_dom[axis]->isReduction() || root_dom[i]->isBroadcast()) + computed_inds.erase(computed_inds.begin() + axis); + } std::vector strided_inds; for (decltype(computed_inds.size()) i{0}; i < computed_inds.size(); i++) { diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp index 5e9d645d22127..3027784bc5186 100644 --- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp 
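The getGlobalConsumerIndex hunk above now accepts computed indices arriving either with or without reduction axes, and strips reduction and broadcast dimensions by erasing backwards. Note the new loop tests root_dom[axis]->isReduction() but root_dom[i]->isBroadcast(); the sketch below assumes the reversed axis index is intended for both checks. It is a standalone illustration only: plain ints stand in for Val* indices and the flag vector is invented for the example.

#include <cassert>
#include <vector>

int main() {
  std::vector<int> computed_inds = {10, 11, 12, 13};
  std::vector<bool> is_reduction = {false, true, false, true};

  // Walk the root domain backwards so erasing an element never shifts
  // the offset of an axis still to be visited.
  for (size_t i = 0; i < is_reduction.size(); i++) {
    size_t axis = is_reduction.size() - i - 1;
    if (is_reduction[axis])
      computed_inds.erase(computed_inds.begin() + axis);
  }

  // Only the iteration-axis indices remain, in their original order.
  assert((computed_inds == std::vector<int>{10, 12}));
  return 0;
}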
@@ -30,14 +30,14 @@ Expr* Statement::asExpr() {
 }
 
 // When we create a Val we immediately register them with the active fusion.
-Val::Val(ValType _vtype, DataType _dtype) : vtype_{_vtype}, dtype_{_dtype} {
+Val::Val(ValType _vtype, DataType _dtype, bool register_val)
+    : vtype_{_vtype}, dtype_{_dtype} {
   Fusion* fusion = FusionGuard::getCurFusion();
-  if (fusion != nullptr) {
-    this->name_ = fusion->registerVal(this);
-    this->fusion_ = fusion;
-  } else {
-    TORCH_CHECK(false, "No active fusion group found when creating a Val.");
-  }
+  TORCH_CHECK(
+      fusion != nullptr, "No active fusion group found when creating a Val.");
+  this->fusion_ = fusion;
+  if (register_val)
+    this->name_ = this->fusion_->registerVal(this);
 }
 
 // Traverse origin of all values involved in constructing the provided val.
diff --git a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
index eabdab6115e34..92283acac53bc 100644
--- a/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_base_nodes.h
@@ -152,7 +152,15 @@ struct TORCH_CUDA_API Val : public Statement {
   virtual ~Val() = default;
 
   Val() = delete;
-  Val(ValType _vtype, DataType _dtype = DataType::Null);
+
+  // We may not want to register this value during Val's constructor. The
+  // reason for this is that if we register the val and a derived class's
+  // constructor then throws, fusion's destructor will get called, but the
+  // pointer to this Val will be invalid. When fusion tries to delete this
+  // value it will cause a segfault instead of surfacing the thrown error.
+  Val(ValType _vtype,
+      DataType _dtype = DataType::Null,
+      bool register_val = true);
 
   // TODO: Values are unique and not copyable
   Val(const Val& other) = delete;
diff --git a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
index 4c167e4dd18af..9e487a60bbd9a 100644
--- a/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -189,9 +189,6 @@ struct TORCH_CUDA_API TensorView : public Val {
   TensorView(const std::shared_ptr<Value>& jit_value)
       : TensorView(jit_value->type()->cast<c10::TensorType>()) {}
 
-  // Make an exact copy of this tensor with the same dtype and same domain
-  TensorView* clone() const;
-
   TensorDomain* domain() const noexcept {
     return domain_;
   }
diff --git a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
index 94d508b9493ab..bbdc8699d040e 100644
--- a/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
+++ b/torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -359,16 +359,20 @@ struct TORCH_CUDA_API TensorDomain : public Val {
   TensorDomain& operator=(TensorDomain&& other) = delete;
 
   TensorDomain(std::vector<IterDomain*> _domain);
+
+  TensorDomain(
+      std::vector<IterDomain*> _root_domain,
+      std::vector<IterDomain*> _domain);
+
   TensorDomain(
       std::vector<IterDomain*> _root_domain,
+      std::vector<IterDomain*> _rfactor_domain,
       std::vector<IterDomain*> _domain);
 
   std::vector<IterDomain*>::size_type nDims() const {
     return domain_.size();
   }
 
-  TensorDomain* clone() const;
-
   bool sameAs(const TensorDomain* const other) const;
 
   static bool sameAs(
@@ -384,19 +388,24 @@ struct TORCH_CUDA_API TensorDomain : public Val {
   bool hasRFactor() const;
 
   const std::vector<IterDomain*>& noReductions() const noexcept {
-    return noReductionDomain_;
+    return no_reduction_domain_;
   }
+
   const std::vector<IterDomain*>& noBroadcasts() const noexcept {
-    return noBCastDomain_;
+    return no_bcast_domain_;
   }
 
   const std::vector<IterDomain*>& rootDomain() const noexcept {
     return root_domain_;
   };
 
+  const std::vector<IterDomain*>& rfactorDomain() const noexcept {
+    return 
rfactor_domain_; + }; + void resetDomains() { - noReductionDomain_ = noReductions(domain_); - noBCastDomain_ = noBroadcasts(domain_); + no_reduction_domain_ = noReductions(domain_); + no_bcast_domain_ = noBroadcasts(domain_); } // i here is int, as we want to accept negative value and ::size_type can be a @@ -432,8 +441,9 @@ struct TORCH_CUDA_API TensorDomain : public Val { private: const std::vector root_domain_; std::vector domain_; - std::vector noBCastDomain_; - std::vector noReductionDomain_; + std::vector no_bcast_domain_; + std::vector no_reduction_domain_; + const std::vector rfactor_domain_; }; /* diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 17d431fa57d2f..dc7fb9de56e4a 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -214,7 +214,7 @@ IterDomain::IterDomain( bool _reduction_domain, bool _rfactor_domain, bool _broadcast_domain) - : Val(ValType::IterDomain, DataType::Int), + : Val(ValType::IterDomain, DataType::Int, false), start_(_start), extent_(_extent), parallel_method_(_parallel_method), @@ -238,6 +238,7 @@ IterDomain::IterDomain( "Cannot create an iter domain with a start that is not an int but recieved ", _extent, " ."); + this->name_ = fusion_->registerVal(this); } bool IterDomain::sameAs(const IterDomain* const other) const { @@ -290,12 +291,11 @@ std::pair IterDomain::split( Int* fact = new Int(factor); // outer loop size Val* vo = ceilDiv(in->extent(), fact); - Int* so = static_cast(vo); // outer loop IterDomain IterDomain* ido = new IterDomain( new Int(0), - so, + static_cast(vo), in->parallel_method(), in->isReduction(), in->isRFactorProduct(), @@ -334,7 +334,7 @@ TensorDomain::TensorDomain(std::vector _domain) TensorDomain::TensorDomain( std::vector _root_domain, std::vector _domain) - : Val(ValType::TensorDomain), + : Val(ValType::TensorDomain, DataType::Null, false), root_domain_(std::move(_root_domain)), domain_(std::move(_domain)) { std::vector domain_vals(domain_.begin(), domain_.end()); @@ -343,7 +343,7 @@ TensorDomain::TensorDomain( // Validate that the root domain consists of all inputs to _domain // Uncertain if this will hold for RFactor - std::set root_vals(root_domain_.begin(), root_domain_.end()); + std::unordered_set root_vals(root_domain_.begin(), root_domain_.end()); std::for_each(inps.begin(), inps.end(), [root_vals](Val* inp) { TORCH_INTERNAL_ASSERT( root_vals.find(inp) != root_vals.end(), @@ -353,25 +353,68 @@ TensorDomain::TensorDomain( }); resetDomains(); + + this->name_ = fusion_->registerVal(this); +} + +TensorDomain::TensorDomain( + std::vector _root_domain, + std::vector _rfactor_domain, + std::vector _domain) + : Val(ValType::TensorDomain, DataType::Null, false), + root_domain_(std::move(_root_domain)), + domain_(std::move(_domain)), + rfactor_domain_(std::move(_rfactor_domain)) { + auto inps = IterVisitor::getInputsTo( + std::vector(domain_.begin(), domain_.end())); + + // Validate that the root domain consists of all inputs to _domain + // Uncertain if this will hold for RFactor + + std::unordered_set root_vals(root_domain_.begin(), root_domain_.end()); + std::for_each(inps.begin(), inps.end(), [root_vals](Val* inp) { + TORCH_INTERNAL_ASSERT( + root_vals.find(inp) != root_vals.end(), + "Invalid tensor domain, ", + inp, + " is an input of domain, but it is not found in the root domain."); + }); + + inps = IterVisitor::getInputsTo( + std::vector(rfactor_domain_.begin(), rfactor_domain_.end())); + std::for_each(inps.begin(), 
inps.end(), [root_vals](Val* inp) { + TORCH_INTERNAL_ASSERT( + root_vals.find(inp) != root_vals.end(), + "Invalid tensor domain, ", + inp, + " is an input of the rfactor domain, but it is not found in the root domain."); + }); + + resetDomains(); + this->name_ = fusion_->registerVal(this); } bool TensorDomain::sameAs(const TensorDomain* const other) const { if (nDims() != other->nDims()) return false; + if (rootDomain().size() != other->rootDomain().size()) + return false; + if (rfactorDomain().size() != other->rfactorDomain().size()) + return false; - for (decltype(nDims()) i = 0; i < nDims(); i++) + for (size_t i = 0; i < nDims(); i++) if (!(axis(i)->sameAs(other->axis(i)))) return false; - return true; -} + for (size_t i = 0; i < rootDomain().size(); i++) + if (!(rootDomain()[i]->sameAs(other->rootDomain()[i]))) + return false; -TensorDomain* TensorDomain::clone() const { - std::vector domain(domain_.size()); - size_t i = 0; - for (auto dom : domain_) - domain[i++] = dom->clone(); - return new TensorDomain(domain); + for (size_t i = 0; i < rfactorDomain().size(); i++) + if (!(rfactorDomain()[i]->sameAs(other->rfactorDomain()[i]))) + return false; + + return true; } bool TensorDomain::sameAs( @@ -388,18 +431,15 @@ bool TensorDomain::sameAs( } bool TensorDomain::hasReduction() const { - return noReductionDomain_.size() != domain_.size(); + return no_reduction_domain_.size() != domain_.size(); } bool TensorDomain::hasBroadcast() const { - return noBCastDomain_.size() != domain_.size(); + return no_bcast_domain_.size() != domain_.size(); } bool TensorDomain::hasRFactor() const { - return std::any_of( - root_domain_.begin(), root_domain_.end(), [](IterDomain* id) { - return id->isRFactorProduct(); - }); + return !rfactor_domain_.empty(); } // i here is int, as we want to accept negative value and ::size_type can be a @@ -506,16 +546,15 @@ std::vector TensorDomain::orderedAs( }); // Check if any adjusted values are < 0, or >= nDims, which are invalid - bool out_of_range = std::any_of( - old2new.begin(), - old2new.end(), - [ndims](std::unordered_map::value_type entry) { - return entry.first < 0 || (unsigned int)entry.first >= ndims || - entry.second < 0 || (unsigned int)entry.second >= ndims; - }); TORCH_CHECK( - !out_of_range, + std::none_of( + old2new.begin(), + old2new.end(), + [ndims](std::unordered_map::value_type entry) { + return entry.first < 0 || (unsigned int)entry.first >= ndims || + entry.second < 0 || (unsigned int)entry.second >= ndims; + }), "Reorder axes are not within the number of dimensions of the provided domain."); // Going to use sets, to see if any duplicate values are in the map. 
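TensorDomain::orderedAs validates a reorder request in two steps: wrap negative positions by adding ndims, reject anything outside [0, ndims) with the std::none_of form the hunk below switches to, and let sets expose duplicate sources or destinations (the "Going to use sets" comment above). A standalone sketch of those checks, with a made-up old2new map:

#include <algorithm>
#include <cassert>
#include <set>
#include <unordered_map>

int main() {
  constexpr size_t ndims = 4;

  // Hypothetical reorder request: swap axes 0 and 2, using a negative
  // index for the second entry; -2 on a 4-D domain names axis 2.
  std::unordered_map<int, int> old2new = {{0, 2}, {-2, 0}};

  // Normalize: negative positions count from the end.
  std::unordered_map<int, int> normalized;
  for (auto entry : old2new)
    normalized[entry.first < 0 ? entry.first + (int)ndims : entry.first] =
        entry.second < 0 ? entry.second + (int)ndims : entry.second;

  // Bounds check over the normalized map.
  assert(std::none_of(
      normalized.begin(),
      normalized.end(),
      [](const std::pair<const int, int>& entry) {
        return entry.first < 0 || (size_t)entry.first >= ndims ||
            entry.second < 0 || (size_t)entry.second >= ndims;
      }));

  // Duplicate check: distinct positions collapse in a set, so a size
  // mismatch means two axes mapped to the same source or destination.
  std::set<int> old_pos, new_pos;
  for (auto entry : normalized) {
    old_pos.insert(entry.first);
    new_pos.insert(entry.second);
  }
  assert(old_pos.size() == normalized.size());
  assert(new_pos.size() == normalized.size());
  return 0;
}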
@@ -552,14 +591,6 @@ std::vector TensorDomain::orderedAs( for (std::pair elem : old2new) { int old_pos = elem.first; int new_pos = elem.second; - - TORCH_INTERNAL_ASSERT( - old_pos >= 0 && old_pos < ndims && new_pos >= 0 && new_pos < ndims, - "Error occured in reorder, somehow axes are not in expected range."); - - if (new2old[new_pos] != -1) - TORCH_CHECK(false, "Reorder found duplicate destination positions."); - new2old[new_pos] = old_pos; } @@ -567,11 +598,6 @@ std::vector TensorDomain::orderedAs( std::set old_positions(new2old.begin(), new2old.end()); old_positions.erase(-1); - // Make sure we have all of them, and no duplicates were found - if (old_positions.size() != old2new.size()) - TORCH_INTERNAL_ASSERT( - false, "Reorder found duplicate destination positions."); - // All available new positions std::set all_positions; for (decltype(ndims) i{0}; i < ndims; i++) @@ -653,46 +679,46 @@ bool TensorDomain::hasReduction(const std::vector& td) { // pair is in order where second is the consumer of first std::pair TensorDomain::rFactor( const std::vector& axes_) { - // std::vector axes(axes_.size()); - - // auto ndims = nDims(); - // std::transform(axes_.begin(), axes_.end(), axes.begin(), [ndims](int i) { - // return i < 0 ? i + ndims : i; - // }); - - // TORCH_CHECK( - // std::none_of( - // axes.begin(), - // axes.end(), - // [ndims](int i) { return i < 0 || (unsigned int)i >= ndims; }), - // "RFactor axes less than 0 or >= ndims."); - - // TORCH_CHECK( - // !hasRFactor(), "Cannot call rfactor on the same tensor domain twice."); - - // std::set axes_set(axes.begin(), axes.end()); - - // bool rfactor_found = false; - // bool reduction_found = false; - // for (decltype(nDims()) i{0}; i < nDims(); i++) { - // if (axis(i)->isReduction()) { - // if (axes_set.find(i) != axes_set.end()) - // rfactor_found = true; - // else - // reduction_found = true; - // } - // } - - // TORCH_CHECK( - // rfactor_found && reduction_found, - // "Invalid rfactor found, rfactor must be provided at least one reduction - // axis, but not all reduction axes."); - - // return std::pair{ - // TransformRFactor::runReplay(this, axes), - // TransformRFactor::runReplay2(this, axes)}; - - TORCH_INTERNAL_ASSERT(false, "NIY."); + std::vector axes(axes_.size()); + + auto ndims = nDims(); + std::transform(axes_.begin(), axes_.end(), axes.begin(), [ndims](int i) { + return i < 0 ? i + ndims : i; + }); + + TORCH_CHECK( + std::none_of( + axes.begin(), + axes.end(), + [ndims](int i) { return i < 0 || (unsigned int)i >= ndims; }), + "RFactor axes less than 0 or >= ndims."); + + // We might be able to lift this constraint in some instances, but needs more + // investigation. 
+ TORCH_CHECK( + !hasRFactor(), "Cannot call rfactor on the same tensor domain twice."); + + std::unordered_set axes_set(axes.begin(), axes.end()); + + bool rfactor_found = false; + bool reduction_found = false; + for (decltype(nDims()) i{0}; i < nDims(); i++) { + if (axis(i)->isReduction()) { + if (axes_set.find(i) != axes_set.end()) { + rfactor_found = true; + } else { + reduction_found = true; + } + } + } + + TORCH_CHECK( + rfactor_found && reduction_found, + "Invalid rfactor found, rfactor must be provided at least one reduction axis, but not all reduction axes."); + + return std::pair{ + TransformRFactor::runReplay(this, axes), + TransformRFactor::runReplay2(this, axes)}; } Split::Split( diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.h b/torch/csrc/jit/codegen/cuda/iter_visitor.h index 4e37404d37542..1dcb6bd3f0faf 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.h +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.h @@ -43,6 +43,7 @@ struct TORCH_CUDA_API IterVisitor : public OptOutDispatch { IterVisitor(IterVisitor&& other) = default; IterVisitor& operator=(IterVisitor&& other) = default; + protected: // Functions return nodes in reverse order to be added to the to_visit queue // These functions will start at outputs and propagate up through the DAG // to inputs based on depth first traversal. Next could be called on a node diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp index 73578a0ca2732..ef0ff560a5a26 100644 --- a/torch/csrc/jit/codegen/cuda/lower2device.cpp +++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp @@ -83,7 +83,6 @@ Statement* GPULower::mutate(ForLoop* fl) { } active_scope = prev_scope; - if (is_mutated) { auto newFL = new ForLoop( fl->index(), fl->iter_domain(), mutated_exprs, fl->parentScope()); diff --git a/torch/csrc/jit/codegen/cuda/lower_loops.cpp b/torch/csrc/jit/codegen/cuda/lower_loops.cpp index 257e2a43572b7..755882fae7c91 100644 --- a/torch/csrc/jit/codegen/cuda/lower_loops.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_loops.cpp @@ -155,7 +155,6 @@ void UnrollPass::handle(ForLoop* fl) { // Make predicates for the unrolling, and the epilogue Bool* unroll_predicate = getPredicate(out, unroll_pred_inds); - // Make the IfThenElse controlling the unrolling IfThenElse* unroll_ite = new IfThenElse(unroll_predicate, {}, {}, first_unroll->parentScope()); diff --git a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp index d12b14cfdc5b6..8fd944c9117a9 100644 --- a/torch/csrc/jit/codegen/cuda/predicate_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/predicate_compute.cpp @@ -21,8 +21,13 @@ std::vector PredicateCompute::computePredicates(const TensorIndex* ti) { const std::vector& root = tv->getRootDomain(); std::vector preds; - if (FusionGuard::getCurFusion()->origin(tv->domain()) == nullptr && - tv->nDims() == ti->nDims()) + + bool no_pred_needed = true; + for (auto id : tv->domain()->domain()) + if (id->getOrigin() != nullptr) + no_pred_needed = false; + + if (no_pred_needed) return preds; TORCH_INTERNAL_ASSERT(root.size() == ti->nDims()); diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 0c707b8d4d454..80c4d98defb3b 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -22,7 +22,9 @@ TensorView::TensorView(TensorDomain* _domain, DataType dtype) : Val(ValType::TensorView, dtype), domain_(_domain) {} TensorView::TensorView(const std::shared_ptr& 
tensor_type) - : Val(ValType::TensorView, aten_opt_type_map(tensor_type->scalarType())) { + : Val(ValType::TensorView, + aten_opt_type_map(tensor_type->scalarType()), + false) { std::vector sizes; TORCH_CHECK( tensor_type->dim().has_value(), "Requires static rank for Tensor"); @@ -32,15 +34,8 @@ TensorView::TensorView(const std::shared_ptr& tensor_type) sizes.push_back(new IterDomain(new Int(0), new Int())); } domain_ = new TensorDomain(sizes); -} - -TensorView* TensorView::clone() const { - TensorView* new_view = - new TensorView(domain()->clone(), getDataType().value()); - new_view->setComputeAt(compute_at_view_, (int)relative_compute_at_axis_); - new_view->memory_type_ = getMemoryType(); - return new_view; + this->name_ = fusion_->registerVal(this); } bool TensorView::hasReduction() const { diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.cpp b/torch/csrc/jit/codegen/cuda/transform_iter.cpp index 94229ed41c73c..9eaf835df8f10 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_iter.cpp @@ -1,12 +1,192 @@ -// #include -// #include -// #include -// #include - -#include +#include namespace torch { namespace jit { -namespace fuser {} // namespace fuser +namespace fuser { + +// Transform dispatch +void ReplayTransformations::handle(Expr* e) { + switch (e->getExprType().value()) { + case (ExprType::Split): + case (ExprType::Merge): + break; + default: + TORCH_INTERNAL_ASSERT( + false, "Invalid expr type found in transform traversal."); + } + IterVisitor::handle(e); +} + +// We're going to replay this split operation on the corresponding ID +void ReplayTransformations::handle(Split* s) { + // Grab our input to the split node + auto id_in = s->in(); + + // Make sure we have a corresponding entry in our map pointing to the ID we're + // going to replay the split on + auto it = id_map_.find(id_in); + if (it == id_map_.end()) { + if (check_all_ops_run_) { + TORCH_INTERNAL_ASSERT( + false, "Transform traversal failed, dependencies not met."); + } else { + return; + } + } + + auto mapped = (*it).second; + TORCH_INTERNAL_ASSERT( + s->factor()->isConst(), + "Transform traversal does not support splitting on non-const values."); + // Make sure this ID is a leaf ID (meaning it has no uses we generated) + TORCH_INTERNAL_ASSERT( + leaf_ids_.find(mapped) != leaf_ids_.end(), + "Transform traversal failed, modified a node but it was not a leaf node."); + + // Replay the split onto mapped + auto outs = IterDomain::split(mapped, s->factor()->value().value()); + // Remove mapped from the leaf IDs + leaf_ids_.erase(mapped); + + // Add outputs to leaf IDs + leaf_ids_[outs.first] = counter++; + leaf_ids_[outs.second] = counter++; + + // Update our ID map to include these outputs + id_map_[s->outer()] = outs.first; + id_map_[s->inner()] = outs.second; +} + +// We're going to replay this merge operation on the corresponding IDs +void ReplayTransformations::handle(Merge* m) { + // Grab the inputs to the merge node + auto id_outer = m->outer(); + auto id_inner = m->inner(); + + // Make sure we have a corresponding entry in our map pointing to the IDs + // we're going to replay the merge on + auto it_outer = id_map_.find(id_outer); + auto it_inner = id_map_.find(id_inner); + if (it_outer == id_map_.end() || it_inner == id_map_.end()) { + if (check_all_ops_run_) { + TORCH_INTERNAL_ASSERT( + false, "Transform traversal failed, dependencies not met."); + } else { + return; + } + } + + // Grab the IDs we're going to replay this merge on + auto id_outer_mapped = 
(*it_outer).second; + auto id_inner_mapped = (*it_inner).second; + + // Make sure these IDs are leaf IDs (meaning they have no uses we generated) + TORCH_INTERNAL_ASSERT( + leaf_ids_.find(id_outer_mapped) != leaf_ids_.end() && + leaf_ids_.find(id_inner_mapped) != leaf_ids_.end(), + "Transform traversal failed, tried to replay with ", + id_outer_mapped, + " and ", + id_inner_mapped, + " however one or both are not leaf nodes."); + + // Replay the merge operation + auto out = IterDomain::merge(id_outer_mapped, id_inner_mapped); + + // Remove inputs from the leaf IDs + leaf_ids_.erase(id_outer_mapped); + leaf_ids_.erase(id_inner_mapped); + + // Add the output to the leaf IDs + leaf_ids_[out] = counter++; + + // Update our ID map with the replayed output + id_map_[m->out()] = out; +} + +ReplayTransformations::ReplayTransformations( + const std::vector& _target_domain, + std::unordered_map _id_map, + bool _check_all_ops_run) + : target_domain_(std::move(_target_domain)), + id_map_(std::move(_id_map)), + check_all_ops_run_(_check_all_ops_run) { + // Make sure id_map has all the inputs needed to replay target_domain + auto inps = IterVisitor::getInputsTo( + std::vector(target_domain_.begin(), target_domain_.end())); + + if (check_all_ops_run_) + std::for_each(inps.begin(), inps.end(), [this](Val* val) { + TORCH_INTERNAL_ASSERT( + val->getValType().value() == ValType::IterDomain, + "Expected IterDomain only for Replay Transformations, but found ", + val); + IterDomain* id = static_cast(val); + TORCH_INTERNAL_ASSERT( + this->id_map_.find(id) != this->id_map_.end(), + "Could not find required input: ", + id, + " in provided id_map."); + }); + + // Set all the leaf nodes for tracking, all ids start as a leaf and will be + // updated based on the transformations + for (auto entry : id_map_) + leaf_ids_[entry.second] = counter++; +} + +// Replays outputs that were generated from ids.first on ids.second +void ReplayTransformations::runReplay() { + TORCH_INTERNAL_ASSERT( + !ran_replay, + "Cannot run replay twice without creating a new Replay Class."); + ran_replay = true; + if (target_domain_.empty() || id_map_.empty()) + return; + + // Switch outDomain to a vector to start the traversal + std::vector traversal_vals( + target_domain_.begin(), target_domain_.end()); + traverseFrom(traversal_vals[0]->fusion(), traversal_vals); + + if (check_all_ops_run_) + TORCH_INTERNAL_ASSERT( + leaf_ids_.size() >= target_domain_.size(), + "Transform traversal failed, did not find enough output IterDomains."); + + // Validate replay + for (auto out : target_domain_) { + auto it_replayed = id_map_.find(out); + if (it_replayed == id_map_.end()) { + if (check_all_ops_run_) { + TORCH_INTERNAL_ASSERT( + false, + "Transform traversal failed, could not find expected output."); + } + continue; + } + + auto id_replayed = (*it_replayed).second; + auto it_leaf = leaf_ids_.find(id_replayed); + TORCH_INTERNAL_ASSERT( + it_leaf != leaf_ids_.end(), + "Transform Traversal failed, expected matched output to be a leaf of the replay, but was not."); + } + + // Populate leaf_vec_ in a deterministic manner. This is deterministic + // because size_t in leaf_ids is filled based on operation order. 
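The "deterministic manner" comment above relies on each leaf carrying the counter value assigned when it was created, so a set ordered by that counter (id_int_lt in transform_iter.h) recovers creation order from the unordered map. A standalone sketch with strings standing in for IterDomain*:

#include <cassert>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Same shape as id_int_lt: order pairs by the size_t counter only.
struct id_int_lt {
  bool operator()(
      const std::pair<std::string, size_t>& a,
      const std::pair<std::string, size_t>& b) const {
    return a.second < b.second;
  }
};

int main() {
  // Counters record creation order; the map itself is unordered.
  std::unordered_map<std::string, size_t> leaf_ids = {
      {"I0oo", 2}, {"I0oi", 0}, {"I0i", 1}};

  std::set<std::pair<std::string, size_t>, id_int_lt> ordered(
      leaf_ids.begin(), leaf_ids.end());

  std::vector<std::string> leaf_vec;
  for (const auto& entry : ordered)
    leaf_vec.push_back(entry.first);

  // Creation order is recovered deterministically.
  assert((leaf_vec == std::vector<std::string>{"I0oi", "I0i", "I0oo"}));
  return 0;
}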
+ std::set, id_int_lt> ordered_set; + for (auto entry : leaf_ids_) + ordered_set.emplace(entry); + + leaf_vec_.clear(); + leaf_vec_.resize(ordered_set.size()); + std::transform( + ordered_set.begin(), + ordered_set.end(), + leaf_vec_.begin(), + [](std::pair entry) { return entry.first; }); +} +} // namespace fuser } // namespace jit } // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.h b/torch/csrc/jit/codegen/cuda/transform_iter.h index 1334162f232f7..ed491cfa1ff18 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.h +++ b/torch/csrc/jit/codegen/cuda/transform_iter.h @@ -3,8 +3,8 @@ #include #include +#include #include - #include namespace torch { @@ -13,6 +13,7 @@ namespace fuser { namespace { +// Enable pair in a set, size_t must be unique in set struct id_int_lt { bool operator()( const std::pair& first, @@ -24,257 +25,67 @@ struct id_int_lt { } // namespace struct TORCH_CUDA_API ReplayTransformations : public IterVisitor { - private: - std::vector target_domain_; + protected: + const std::vector& target_domain_; std::unordered_map id_map_; std::unordered_map leaf_ids_; std::vector leaf_vec_; size_t counter = 0; - + bool check_all_ops_run_ = true; + bool ran_replay = false; // Mark if replay has been run using IterVisitor::handle; - void handle(Expr* e) override { - switch (e->getExprType().value()) { - case (ExprType::Split): - case (ExprType::Merge): - break; - default: - TORCH_INTERNAL_ASSERT( - false, "Invalid expr type found in transform traversal."); - } - IterVisitor::handle(e); - } + // Transform dispatch + void handle(Expr* e) override; // TODO: HANDLE RFACTOR DOMAINS - virtual void handle(Split* s) override { - auto id_in = s->in(); - auto it = id_map_.find(id_in); - TORCH_INTERNAL_ASSERT( - it != id_map_.end(), - "Transform traversal failed, dependencies not met."); - auto mapped = (*it).second; - TORCH_INTERNAL_ASSERT( - s->factor()->isConst(), - "Transform traversal does not support splitting on non-const values."); - auto outs = IterDomain::split(mapped, s->factor()->value().value()); - TORCH_INTERNAL_ASSERT( - leaf_ids_.find(mapped) != leaf_ids_.end(), - "Transform traversal failed, modified a node but it was not a leaf node."); - leaf_ids_.erase(mapped); - leaf_ids_[outs.first] = counter++; - leaf_ids_[outs.second] = counter++; - id_map_[s->outer()] = outs.first; - id_map_[s->inner()] = outs.second; - } - - virtual void handle(Merge* m) override { - auto id_outer = m->outer(); - auto id_inner = m->inner(); - auto it_outer = id_map_.find(id_outer); - auto it_inner = id_map_.find(id_inner); - TORCH_INTERNAL_ASSERT( - it_outer != id_map_.end() && it_inner != id_map_.end(), - "Transform traversal failed, dependencies not met."); + // We're going to replay this split operation on the corresponding ID + virtual void handle(Split* s) override; - auto id_outer_mapped = (*it_outer).second; - auto id_inner_mapped = (*it_inner).second; - - auto out = IterDomain::merge(id_outer_mapped, id_inner_mapped); - - TORCH_INTERNAL_ASSERT( - leaf_ids_.find(id_outer_mapped) != leaf_ids_.end() && - leaf_ids_.find(id_inner_mapped) != leaf_ids_.end(), - "Transform traversal failed, modified ", - id_outer_mapped, - " and ", - id_inner_mapped, - " however one or both are not leaf nodes."); - - leaf_ids_.erase(id_outer_mapped); - leaf_ids_.erase(id_inner_mapped); - leaf_ids_[out] = counter++; - id_map_[m->out()] = out; - } - - // Replays outputs that were generated from ids.first on ids.second - void runReplay() { - if 
(target_domain_.empty() || id_map_.empty()) - return; - - // Switch outDomain to a vector to start the traversal - std::vector traversal_vals( - target_domain_.begin(), target_domain_.end()); - traverseFrom(traversal_vals[0]->fusion(), traversal_vals); - - TORCH_INTERNAL_ASSERT( - leaf_ids_.size() >= target_domain_.size(), - "Transform traversal failed, did not find enough output IterDomains."); - - // Validate replay - size_t it = 0; - for (auto out : target_domain_) { - auto it_replayed = id_map_.find(out); - TORCH_INTERNAL_ASSERT( - it_replayed != id_map_.end(), - "Transform traversal failed, could not find expected output."); - auto id_replayed = (*it_replayed).second; - auto it_leaf = leaf_ids_.find(id_replayed); - TORCH_INTERNAL_ASSERT( - it_leaf != leaf_ids_.end(), - "Transform Traversal failed, expected matched output to be a leaf of the replay, but was not."); - } - } + // We're going to replay this merge operation on the corresponding IDs + virtual void handle(Merge* m) override; public: + // Uses the history of _target_domain, and replays that history using the + // provided map target_domain contains the history we want replayed, and + // id_map maps IterDomains in that history to the IterDomains we want it + // replayed on. check_all_ops_run will cause the replay to error if we can't + // play any operation in target_domain's history because the IDs are not in + // the id_map. If check_all_ops_run = false, replay will replay everything it + // can, and ignore operations it can't. ReplayTransformations( - std::vector _target_domain, - std::unordered_map _id_map) - : target_domain_(std::move(_target_domain)), id_map_(std::move(_id_map)) { - // Make sure id_map has all the inputs needed to replay target_domain - auto inps = IterVisitor::getInputsTo( - std::vector(target_domain_.begin(), target_domain_.end())); - - std::for_each(inps.begin(), inps.end(), [this](Val* val) { - TORCH_INTERNAL_ASSERT( - val->getValType().value() == ValType::IterDomain, - "Expected IterDomain only for Replay Transformations, but found ", - val); - IterDomain* id = static_cast(val); - TORCH_INTERNAL_ASSERT( - this->id_map_.find(id) != this->id_map_.end(), - "Could not find required input: ", - id, - " in provided id_map."); - }); - - // Set all the leaf nodes for tracking, all ids start as a leaf and will be - // updated based on the transformations - for (auto entry : id_map_) - leaf_ids_[entry.second] = counter++; + const std::vector& _target_domain, + std::unordered_map _id_map, + bool _check_all_ops_run = true); - runReplay(); - - // Populate leaf_vec_ in a deterministic manner. This is deterministic - // because size_t in leaf_ids is filled based on operation order. 
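In the restructured header that follows, the constructor no longer calls runReplay; getReplay, getUnorderedLeafIDs, and getLeafIDs run it on first use, and ran_replay prevents a second run. A minimal sketch of that lazy-accessor shape, using a stand-in struct rather than the real ReplayTransformations:

#include <cassert>
#include <unordered_map>

struct Replay {
  bool ran_replay = false;
  std::unordered_map<int, int> id_map_;

  void runReplay() {
    assert(!ran_replay && "Cannot run replay twice.");
    ran_replay = true;
    id_map_[0] = 42; // stand-in for the real traversal
  }

  // Accessor triggers the replay exactly once.
  const std::unordered_map<int, int>& getReplay() {
    if (!ran_replay)
      runReplay();
    return id_map_;
  }
};

int main() {
  Replay r;
  assert(r.getReplay().at(0) == 42); // first call runs the replay
  assert(r.getReplay().size() == 1); // second call reuses the result
  return 0;
}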
- std::set, id_int_lt> ordered_set; - for (auto entry : leaf_ids_) - ordered_set.emplace(entry); - - leaf_vec_.clear(); - leaf_vec_.resize(ordered_set.size()); - std::transform( - ordered_set.begin(), - ordered_set.end(), - leaf_vec_.begin(), - [](std::pair entry) { return entry.first; }); - } + // Replays outputs that were generated from ids.first on ids.second + void runReplay(); // Returns map from provided target domain to their corresponding IDs - const std::unordered_map& getReplay() const - noexcept { + const std::unordered_map& getReplay() { + if (!ran_replay) + runReplay(); return id_map_; } - // - const std::unordered_map& getUnorderedLeafIDs() const - noexcept { + // Returns leaf_ids_ the size_t marks the order in which they were put into + // the map, this is part of the structure because it's used to generate the + // order from 'getLeafIDs' + const std::unordered_map& getUnorderedLeafIDs() { + if (!ran_replay) + runReplay(); return leaf_ids_; } // Returns all terminating IDs that resulted from the replay. Leaf IDs are run // to run deterministic, but otherwise in no specific order. - const std::vector& getLeafIDs() const noexcept { + const std::vector& getLeafIDs() { + if (!ran_replay) + runReplay(); return leaf_vec_; } }; -/* - * TransformIter iterates on the split/merge graph of TensorDomain - * - * Running backward will execute these Exprs in reverse order. If you record - * these events (generate_record=true) you can then replay them on another - * tensor domain. - */ -struct TORCH_CUDA_API TransformBackwardIter : public IterVisitor { - protected: - virtual TensorDomain* replayBackward(Split*, TensorDomain*); - virtual TensorDomain* replayBackward(Merge*, TensorDomain*); - - // dispatch - TensorDomain* replayBackward(Expr*, TensorDomain*); - - virtual TensorDomain* runBackward(TensorDomain*); - - public: - // Returns transformation exprs in forward order - static std::vector getHistory(TensorDomain*) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } - - // TODO: make td const - static TensorDomain* getRoot(TensorDomain* td) { - { TORCH_INTERNAL_ASSERT(false, "NIY."); } - // TransformIter ti; - // return ti.runBackward(td); - } - - // Takes influence vector of bools, tracks them back to propagate true to root - // axes that were modified into td axes matching marked influence vector. - static std::vector getRootInfluence( - TensorDomain* td, - const std::vector& influence) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } - - static std::vector replayBackwardInfluence( - const std::vector& history, - const std::vector& td_influence) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } - - // Runs through history, applying only on influence to track how modifications - // would influence the original axes. - static std::vector replayInfluence( - const std::vector& history, - const std::vector& td_influence) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } - - // Goes through history and applies it to td, with the axis_map provided. - // Axis_map entries of -1 mean those axes won't be modified - static TensorDomain* replay( - TensorDomain* td, - const std::vector& history, - const std::vector& axis_map) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } - - // Takes td, and replays history backwards on it to create a new root tensor - // domain using axis_map. 
Entries in axis_map == -1 will not be modified - static TensorDomain* replayBackward( - TensorDomain* td, - const std::vector& history, - const std::vector& axis_map) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } - - static TensorDomain* replaySelf( - TensorDomain* td, - const std::vector& history, - const std::vector& axis_map) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } - - // getRFactorRoot does not reapply any transformations. It simply searches - // through the history of td from its root domain and tries to find a point - // where we stop transforming axes marked as rfactor. This works because all - // rfactor transformations are pushed to the begining of td's history by the - // RFactor transformation itself. - static TensorDomain* getRFactorRoot(TensorDomain* td) { - TORCH_INTERNAL_ASSERT(false, "NIY."); - } -}; - } // namespace fuser } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/transform_replay.cpp b/torch/csrc/jit/codegen/cuda/transform_replay.cpp index 35a858a1389fa..3288268dc8ea2 100644 --- a/torch/csrc/jit/codegen/cuda/transform_replay.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_replay.cpp @@ -1,8 +1,10 @@ #include +#include #include #include #include #include +#include #include @@ -10,38 +12,121 @@ namespace torch { namespace jit { namespace fuser { -// // Replay producer as consumer. -// TensorDomain* TransformReplay::fullSelfReplay( -// TensorDomain* self, -// TensorDomain* self_copy) { -// // Want producer root with no reductions, rfactor included -// TensorDomain* self_root = self->rootDomain(); +namespace { -// // Want full consumer root, even before rfactor -// TensorDomain* self_copy_root = self_copy->rootDomain(); +struct TORCH_CUDA_API ReplaySelf : public ReplayTransformations { + private: + // Took a good bit of this from ReplayTransformations::handle(Split...) + void handle(Split* s) override { + // Grab input to the split operation + auto id_in = s->in(); -// TORCH_INTERNAL_ASSERT( -// self_root->nDims(), self_copy_root->nDims(), "Invalid self replay."); + // Grab our mapping of that ID to the one we're replaying + auto it = id_map_.find(id_in); -// for (decltype(self_root->nDims()) i{0}; i < self_root->nDims(); i++) -// TORCH_INTERNAL_ASSERT( -// self_root->axis(i)->parallel_method() == -// self_copy_root->axis(i)->parallel_method() && -// self_root->axis(i)->isReduction() == -// self_copy_root->axis(i)->isReduction() && -// self_root->axis(i)->start() == self_copy_root->axis(i)->start(), -// "Invalid self replay detected, root domain does not match."); + // Make sure it exists in the map + TORCH_INTERNAL_ASSERT( + it != id_map_.end(), + "Transform traversal failed, dependencies not met."); + // Grab the ID we're going to replay on + auto mapped = it->second; + + // This ID should be a leaf ID (meaning it has no uses we generated) + TORCH_INTERNAL_ASSERT( + leaf_ids_.find(mapped) != leaf_ids_.end(), + "Transform traversal failed, modified a node but it was not a leaf node."); + + // outer loop size + Val* oe = ceilDiv(mapped->extent(), s->factor()); + + // Manually replay the split, following the output of the operations. + // This is so rfactor ops are replayed correctly. 
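The outer extent computed just above is ceilDiv(extent, factor). For integer extents this is the usual round-up division, so the replayed outer domain always covers at least the original extent and the partial tail is masked by predicates. A quick standalone sanity check (the real codegen computes this on Val*s):

#include <cassert>

int ceilDiv(int a, int b) {
  return (a + b - 1) / b;
}

int main() {
  // Splitting an extent-129 axis by 16: 9 outer iterations cover
  // 9 * 16 = 144 >= 129 elements.
  assert(ceilDiv(129, 16) == 9);
  // An exact multiple needs no extra iteration.
  assert(ceilDiv(128, 16) == 8);
  return 0;
}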
+ IterDomain* ido = new IterDomain( + new Int(0), + static_cast(oe), + s->outer()->parallel_method(), + s->outer()->isReduction(), + s->outer()->isRFactorProduct(), + s->outer()->isBroadcast()); + + // inner IterDomain + IterDomain* idi = new IterDomain( + new Int(0), + s->factor(), + s->inner()->parallel_method(), + s->inner()->isReduction(), + s->inner()->isRFactorProduct(), + s->inner()->isBroadcast()); + + // Generate the split node + new Split(ido, idi, mapped, s->factor()); + + // Remove mapped id from leaf IDs + leaf_ids_.erase(mapped); + + // Add outputs to leaf IDs + leaf_ids_[ido] = counter++; + leaf_ids_[idi] = counter++; + + // Update our ID map to include these outputs + id_map_[s->outer()] = ido; + id_map_[s->inner()] = idi; + } -// std::vector axis_map(self_root->nDims()); -// std::iota(axis_map.begin(), axis_map.end(), 0); + void handle(Merge* m) override { + auto id_outer = m->outer(); + auto id_inner = m->inner(); -// // Finally replay producer as consumer on marked axes + auto it_outer = id_map_.find(id_outer); + auto it_inner = id_map_.find(id_inner); -// auto replayed = TransformIter::replay( -// self_copy_root, TransformIter::getHistory(self), axis_map); + TORCH_INTERNAL_ASSERT( + it_outer != id_map_.end() && it_inner != id_map_.end(), + "Transform traversal failed, dependencies not met."); -// return replayed; -// } + auto id_outer_mapped = it_outer->second; + auto id_inner_mapped = it_inner->second; + + TORCH_INTERNAL_ASSERT( + leaf_ids_.find(id_outer_mapped) != leaf_ids_.end() && + leaf_ids_.find(id_inner_mapped) != leaf_ids_.end(), + "Transform traversal failed, modified ", + id_outer_mapped, + " and ", + id_inner_mapped, + " however one or both are not leaf nodes."); + + Val* merged_id_size = + mul(id_outer_mapped->extent(), id_inner_mapped->extent()); + + IterDomain* merged_id = new IterDomain( + new Int(0), + static_cast(merged_id_size), + m->out()->parallel_method(), + m->out()->isReduction(), + m->out()->isRFactorProduct(), + m->out()->isBroadcast()); + + new Merge(merged_id, id_outer_mapped, id_inner_mapped); + + // Remove inputs from the leaf IDs + leaf_ids_.erase(id_outer_mapped); + leaf_ids_.erase(id_inner_mapped); + + // Add the output to the leaf IDs + leaf_ids_[merged_id] = counter++; + + id_map_[m->out()] = merged_id; + } + + public: + ReplaySelf( + const std::vector& _target_domain, + std::unordered_map _id_map) + : ReplayTransformations(_target_domain, _id_map, false) {} +}; + +} // namespace // Self replay. TensorDomain* TransformReplay::fullSelfReplay( @@ -73,7 +158,7 @@ TensorDomain* TransformReplay::fullSelfReplay( } // Replay producer dimensions. - ReplayTransformations replay(self->domain(), axis_map); + ReplaySelf replay(self->domain(), axis_map); std::vector new_domain(self->nDims(), nullptr); { @@ -83,7 +168,7 @@ TensorDomain* TransformReplay::fullSelfReplay( TORCH_INTERNAL_ASSERT( it != replay.getReplay().end(), "Error during replay, didn't replay an axis."); - new_domain[i++] = (*it).second; + new_domain[i++] = it->second; } } @@ -91,6 +176,11 @@ TensorDomain* TransformReplay::fullSelfReplay( } // Replay producer as consumer. +// Producer could have rfactor axes which consumer may want replayed. We can +// "replay" them as long as it doesn't modify the root rfactor axes. What we +// really want to do is validate if we replayed these axes to the ones they +// mapped to in the consumer the operations would all be the same. then we want +// to start the replay of the producer from the rfactor root axes, not the root. 
 TensorDomain* TransformReplay::replayPasC(
     TensorDomain* producer,
     TensorDomain* consumer,
@@ -125,6 +215,9 @@ TensorDomain* TransformReplay::replayPasC(
   // Grab root domains of producer and consumer
   std::vector<IterDomain*> consumer_root = consumer->rootDomain();
   std::vector<IterDomain*> producer_root = producer->rootDomain();
+  if (producer->hasRFactor())
+    producer_root = producer->rfactorDomain();
+
   // Track which root axes in producer we send to replay
   std::unordered_set<IterDomain*> producer_mapped_roots;
   // Map related axes from producer and consumer roots. Make sure we go to the
@@ -154,13 +247,23 @@ TensorDomain* TransformReplay::replayPasC(
     }
   }
 
+  // Instead of replaying from the root, let's try to forward the history of
+  // consumer if they match ops on producer. Enforce if we modify an rfactor
+  // axis that those ops match.
+  root_axis_map =
+      BestEffortReplay::replay(producer->domain(), consumer_ids, root_axis_map);
+
   // Replay producer dimensions.
-  ReplayTransformations replay_producer(consumer_ids, root_axis_map);
+  ReplayTransformations replay_producer(consumer_ids, root_axis_map, false);
 
-  // replay_producer now contains mappings from consumer axes to their replayed
-  // counter parts in producer (including intermediate IDs, not just those in
-  // consumer_root or conusmer). replay_producer also has all the leaf
-  // IterDomains from the replay.
+  // We could try to continue replaying anything that isn't required for the
+  // computeAt, but for now we don't.
+
+  // replay_producer now contains mappings from consumer axes to their
+  // replayed counterparts in producer (including intermediate IDs, not
+  // just those in consumer_root or consumer). replay_producer also has all
+  // the leaf IterDomains from the replay. Some of the leaf iter domains could
+  // be the same as those in the original domain.
 
   // Find all axes that were not modified during the replay.
   std::vector<IterDomain*> unmodified_producer_axes;
@@ -176,9 +279,7 @@ TensorDomain* TransformReplay::replayPasC(
         continue;
       IterDomain* inp_id = static_cast<IterDomain*>(inp);
       // if ( we sent this root id to replay && it was modified )
-      if (producer_mapped_roots.find(inp_id) != producer_mapped_roots.end() &&
-          replay_producer.getUnorderedLeafIDs().find(inp_id) ==
-              replay_producer.getUnorderedLeafIDs().end()) {
+      if (producer_mapped_roots.find(inp_id) != producer_mapped_roots.end()) {
         modified = true;
         break;
       }
@@ -189,17 +290,30 @@ TensorDomain* TransformReplay::replayPasC(
       unmodified_producer_axes.emplace_back(producer_id);
   }
 
-  // (1) replay_producer.getReplay holds mappings from axes in consumer ->
-  // generated axes in producer (2) replay_producer.getLeafIDs holds a
-  // determinstica ordering of axes in (1), and all other leaf axes created in
-  // generating the above (3) unmodified_producer_axes holds axes that didn't
-  // have to be modified to generate (1)
+  /*
+   * Accumulate axes into the new domain in the following order, making sure to
+   * avoid any duplicates:
+   *
+   * (1) replay_producer.getReplay holds mappings from axes in consumer ->
+   * generated axes in producer
+   *
+   * (2) replay_producer.getLeafIDs holds a deterministic ordering of axes in
+   * (1), and all other leaf axes created in generating the above. Next will be
+   * any leaves that can be mapped to the original consumer domain. This is not
+   * an order we should guarantee but can make life simpler.
+   *
+   * (3) Any axes in getLeafIds that were in the original producer domain.
+   *
+   * (4) Remaining leaf axes. 
+ * + * (5) unmodified_producer_axes holds axes that didn't have to be modified to + * generate (1) + */ - // Accumulate new domain in this vector: std::vector new_IDs; + std::unordered_set used_IDs; // Add axes in (1) - std::unordered_set used_IDs; for (auto c_id : consumer_ids) { auto it = replay_producer.getReplay().find(c_id); TORCH_INTERNAL_ASSERT( @@ -207,159 +321,43 @@ TensorDomain* TransformReplay::replayPasC( "Could not find axis, ", c_id, ", requested in replay."); - new_IDs.push_back((*it).second); - used_IDs.emplace((*it).second); + new_IDs.push_back(it->second); + used_IDs.emplace(it->second); } // Add axes in (2) + for (auto c_id : consumer->domain()) { + auto it = replay_producer.getReplay().find(c_id); + if (it == replay_producer.getReplay().end()) + continue; + auto id = it->second; + if (used_IDs.find(id) == used_IDs.end()) { + new_IDs.push_back(id); + used_IDs.emplace(id); + } + } + + // Add axes in (3) + for (auto id : producer->domain()) + if (replay_producer.getUnorderedLeafIDs().find(id) != + replay_producer.getUnorderedLeafIDs().end() && + used_IDs.find(id) == used_IDs.end()) { + new_IDs.push_back(id); + used_IDs.emplace(id); + } + + // Add axes in (4) for (auto leaf : replay_producer.getLeafIDs()) if (used_IDs.find(leaf) == used_IDs.end()) new_IDs.push_back(leaf); - // Add axes in (3) + // Add axes in (5) for (auto unmodified : unmodified_producer_axes) - if (used_IDs.find(unmodified) == used_IDs.end()) - new_IDs.push_back(unmodified); - - TensorDomain* replayed = new TensorDomain(producer_root, new_IDs); + new_IDs.push_back(unmodified); + TensorDomain* replayed = new TensorDomain( + producer->rootDomain(), producer->rfactorDomain(), new_IDs); return replayed; - - // KEEPING BELOW FOR NOW FOR REFERENCE WHEN DOING RFACTOR! - // // Consumer in rfactor cases is based off producer's rfactor root, not - // // producer's root - // TensorDomain* producer_rfactor_root = - // TransformIter::getRFactorRoot(producer); - - // // Want full consumer root, even before rfactor - // TensorDomain* consumer_root = TransformIter::getRoot(consumer); - - // // We want to see which axes in the consumer root were modified to create - // axes - // // < consumer_compute_at_axis - // std::vector consumer_influence(consumer->nDims(), false); - // for (int i = 0; i < consumer_compute_at_axis; i++) - // consumer_influence[i] = true; - - // // Check which axes in ref_root need to be modified to honor - // transformations - // // to compute at axis - // std::vector consumer_root_influence = - // TransformIter::getRootInfluence(consumer, consumer_influence); - - // // We have the influence on the consumer root, we need it on producer, we - // // want to keep those axes that don't need to be modified by the replay - // std::vector producer_rfactor_root_influence( - // producer_rfactor_root->nDims(), false); - - // // Map is based on producer - // std::vector replay_axis_map(consumer_root->nDims(), -1); - // // Setup producer_rfactor_root_influence vector on root for replay - // size_t ip = 0, ic = 0; - - // while (ip < producer_rfactor_root_influence.size() && - // ic < consumer_root->nDims()) { - // bool is_reduction = producer_rfactor_root->axis(ip)->isReduction(); - // bool is_bcast = consumer_root->axis(ic)->isBroadcast(); - // if (is_reduction) { - // producer_rfactor_root_influence[ip++] = false; - // } else if (is_bcast) { - // replay_axis_map[ic++] = -1; - // } else { - // if (consumer_root_influence[ic]) { - // replay_axis_map[ic] = ip; - // } else { - // replay_axis_map[ic] = -1; - // } - 
// producer_rfactor_root_influence[ip++] = consumer_root_influence[ic++]; - // } - // } - - // for (decltype(producer_rfactor_root->nDims()) i{0}; - // i < producer_rfactor_root->nDims(); - // i++) - // TORCH_INTERNAL_ASSERT( - // !(producer_rfactor_root_influence[i] && - // producer_rfactor_root->axis(i)->isRFactorProduct()), - // "An illegal attempt to modify an rfactor axis detected."); - - // // We should have hit the end of the consumer root domain - // TORCH_INTERNAL_ASSERT( - // ic == consumer_root->nDims() || - // (ic < consumer_root->nDims() ? - // consumer_root->axis(ic)->isBroadcast() - // : false), - // "Error when trying to run replay, didn't reach end of consumer/target - // root."); - - // TORCH_INTERNAL_ASSERT( - // producer_rfactor_root_influence.size() == - // producer_rfactor_root->nDims(), "Error detected during replay, expected - // matching sizes of influence map to root dimensions."); - - // auto producer_root_influence = TransformIter::getRootInfluence( - // producer_rfactor_root, producer_rfactor_root_influence); - - // TensorDomain* producer_root = - // TransformIter::getRoot(producer_rfactor_root); - - // std::vector producer_replay_map(producer_root->nDims()); - // for (decltype(producer_replay_map.size()) i{0}; - // i < producer_replay_map.size(); - // i++) { - // if (producer_root->axis(i)->isRFactorProduct()) { - // producer_replay_map[i] = i; - // } else { - // producer_replay_map[i] = producer_root_influence[i] ? -1 : i; - // } - // } - - // // Replay axes that won't be modified by transform replay - // TensorDomain* producer_replay_root = TransformIter::replaySelf( - // producer, TransformIter::getHistory(producer), producer_replay_map); - - // // Record axes positions. - // std::unordered_map new_position; - // for (decltype(producer_replay_root->nDims()) i{0}; - // i < producer_replay_root->nDims(); - // i++) - // new_position[producer_replay_root->axis(i)] = i; - - // std::unordered_map root_axis_map; - // // reorder producer_replay_root to respect replay_axis_map - // for (decltype(replay_axis_map.size()) i{0}; i < replay_axis_map.size(); - // i++) { - // if (replay_axis_map[i] == -1) - // continue; - // auto ax = producer_root->axis(replay_axis_map[i]); - // TORCH_INTERNAL_ASSERT( - // new_position.find(ax) != new_position.end(), - // "Error hit during transform replay, could not find ", - // ax, - // " expected in root domain."); - // root_axis_map[new_position[ax]] = replay_axis_map[i]; - // } - - // // root_axis_map is now mapping from producer_replay_root -> consumer_root - // // Take producer_replay_root transform for all modified axes are in correct - // // relative order, matching how it was in replay_axis_map - // producer_replay_root = producer_replay_root->reorder(root_axis_map); - - // // Finally replay producer as consumer on marked axes - // TensorDomain* replayed = TransformIter::replay( - // producer_replay_root, - // TransformIter::getHistory(consumer), - // replay_axis_map); - - // TORCH_INTERNAL_ASSERT( - // std::none_of( - // replayed->domain().begin(), - // replayed->domain().begin() + consumer_compute_at_axis, - // [](IterDomain* id) { return id->isReduction(); }), - // "Reduction axes found within consumer_compute_at_axis in replay of - // producer."); - - // return replayed; } // Replay consumer as producer. 
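Both replayPasC above and replayCasP below build the new domain by funnelling every candidate axis through used_IDs, so each axis is emitted exactly once, in the five-step priority order spelled out in the block comments. A standalone sketch of the pattern, with strings standing in for IterDomain* and invented source lists:

#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
  // Candidate axes arrive from several sources in priority order; some
  // sources repeat axes that an earlier source already contributed.
  std::vector<std::vector<std::string>> sources = {
      {"c0", "c1"}, // (1) axes mapped from the consumer
      {"c1", "l0"}, // (2)-(4) leaf axes, one already taken
      {"u0"}};      // (5) unmodified axes

  std::vector<std::string> new_IDs;
  std::unordered_set<std::string> used_IDs;
  for (const auto& source : sources)
    for (const auto& id : source)
      // insert().second is true only the first time an axis appears.
      if (used_IDs.insert(id).second)
        new_IDs.push_back(id);

  assert((new_IDs == std::vector<std::string>{"c0", "c1", "l0", "u0"}));
  return 0;
}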
@@ -372,7 +370,7 @@ TensorDomain* TransformReplay::replayCasP(
   TORCH_INTERNAL_ASSERT(
       producer_compute_at_axis >= 0 &&
           (unsigned int)producer_compute_at_axis <= producer->nDims(),
-      "Invalid axis in transform replayPasC.");
+      "Invalid axis in transform replayCasP.");

   // producer ids we need to match in consumer
   std::vector producer_ids;
@@ -397,6 +395,10 @@ TensorDomain* TransformReplay::replayCasP(
   // Grab root domains of producer and consumer
   std::vector consumer_root = consumer->rootDomain();
   std::vector producer_root = producer->rootDomain();
+  // If producer has an rfactor root, that's the one that will match the
+  // consumer
+  if (producer->hasRFactor())
+    producer_root = producer->rfactorDomain();

   // Track which root axes in consumer we send to replay
   std::unordered_set consumer_mapped_roots;
@@ -427,8 +429,14 @@ TensorDomain* TransformReplay::replayCasP(
     }
   }

+  // Instead of replaying from the root, let's try to forward the history of
+  // the consumer where it matches ops on the producer. If we modify an
+  // rfactor axis, enforce that those ops match.
+  root_axis_map = BestEffortReplay::replay(
+      consumer->domain(), producer->domain(), root_axis_map);
+
   // Replay producer dimensions.
-  ReplayTransformations replay_consumer(producer_ids, root_axis_map);
+  ReplayTransformations replay_consumer(producer_ids, root_axis_map, false);

   // replay_consumer now contains mappings from producer axes to their replayed
   // counter parts in consumer (including intermediate IDs, not just those in
@@ -449,31 +457,43 @@ TensorDomain* TransformReplay::replayCasP(
         continue;
       IterDomain* inp_id = static_cast(inp);

-      // if ( we sent this root id to replay && it was modified )
-      if (consumer_mapped_roots.find(inp_id) != consumer_mapped_roots.end() &&
-          replay_consumer.getUnorderedLeafIDs().find(inp_id) ==
-              replay_consumer.getUnorderedLeafIDs().end()) {
+      // if ( we sent this root id to be replayed )
+      if (consumer_mapped_roots.find(inp_id) != consumer_mapped_roots.end()) {
         modified = true;
         break;
       }
     }

-    // If not impacted, lets put it back in the replayed domain.
+    // If none of the inputs were replayed, we want to insert the id in the
+    // replayed domain.
     if (!modified)
       unmodified_consumer_axes.emplace_back(consumer_id);
   }

-  // (1) replay_consumer.getReplay holds mappings from axes in producer ->
-  // generated axes in consumer (2) replay_consumer.getLeafIDs holds a
-  // determinstica ordering of axes in (1), and all other leaf axes created in
-  // generating the above (3) unmodified_consumer_axes holds axes that didn't
-  // have to be modified to generate (1)
+  /*
+   * Accumulate axes into the new domain in the following order, making sure to
+   * avoid any duplicates:
+   *
+   * (1) replay_consumer.getReplay holds mappings from axes in producer ->
+   * generated axes in consumer
+   *
+   * (2) replay_consumer.getLeafIDs holds a deterministic ordering of axes in
+   * (1), and all other leaf axes created in generating the above. Next will be
+   * any leaves that can be mapped to the original producer domain. This is not
+   * an order we should guarantee, but it can make life simpler.
+   *
+   * (3) Any axes in getLeafIDs that were in the original consumer domain.
+   *
+   * (4) Remaining leaf axes.
+ *
+ * (5) unmodified_consumer_axes holds axes that didn't have to be modified to
+ * generate (1)
+ */
-  // Accumulate new domain in this vector:
   std::vector new_IDs;
+  std::unordered_set used_IDs;

   // Add axes in (1)
-  std::unordered_set used_IDs;
   for (auto p_id : producer_ids) {
     auto it = replay_consumer.getReplay().find(p_id);
     TORCH_INTERNAL_ASSERT(
@@ -481,149 +501,46 @@
         "Could not find axis, ",
         p_id,
         ", requested in replay.");
-    new_IDs.push_back((*it).second);
-    used_IDs.emplace((*it).second);
+    new_IDs.push_back(it->second);
+    used_IDs.emplace(it->second);
   }

   // Add axes in (2)
+  for (auto p_id : producer->domain()) {
+    auto it = replay_consumer.getReplay().find(p_id);
+    if (it == replay_consumer.getReplay().end())
+      continue;
+    auto id = it->second;
+    if (used_IDs.find(id) == used_IDs.end()) {
+      new_IDs.push_back(it->second);
+      used_IDs.emplace(it->second);
+    }
+  }
+
+  // Add axes in (3)
+  for (auto id : consumer->domain())
+    if (replay_consumer.getUnorderedLeafIDs().find(id) !=
+            replay_consumer.getUnorderedLeafIDs().end() &&
+        used_IDs.find(id) == used_IDs.end()) {
+      new_IDs.push_back(id);
+      used_IDs.emplace(id);
+    }
+
+  // Add axes in (4)
   for (auto leaf : replay_consumer.getLeafIDs())
     if (used_IDs.find(leaf) == used_IDs.end())
       new_IDs.push_back(leaf);

-  // Add axes in (3)
+  // Add axes in (5)
   for (auto unmodified : unmodified_consumer_axes)
-    if (used_IDs.find(unmodified) == used_IDs.end())
-      new_IDs.push_back(unmodified);
+    new_IDs.push_back(unmodified);

-  TensorDomain* replayed = new TensorDomain(consumer_root, new_IDs);
+  TensorDomain* replayed =
+      new TensorDomain(consumer_root, consumer->rfactorDomain(), new_IDs);

   return replayed;
 }

-// KEEPING BELOW FOR NOW FOR REFERENCE WHEN DOING RFACTOR!
-// // Want producer root with no reductions, rfactor included
-// TensorDomain* producer_rfactor_root =
-// TransformIter::getRFactorRoot(producer); TensorDomain* producer_root =
-// TransformIter::getRoot(producer);
-// // Producer root still has reductions
-
-// // Want full consumer root, even before rfactor
-// TensorDomain* consumer_root = TransformIter::getRoot(consumer);
-
-// // We want to see which axes in the producer root were modified to create
-// axes
-// // < producer_compute_at_axis
-// std::vector producer_influence(producer->nDims(), false);
-// for (int i = 0; i < producer_compute_at_axis; i++)
-//   producer_influence[i] = true;
-
-// // Check which axes in ref_root need to be modified to honor
-// transformations
-// // to compute at axis
-// std::vector producer_root_influence =
-//     TransformIter::getRootInfluence(producer, producer_influence);
-
-// for (decltype(producer_root->nDims()) i{0}; i < producer_root->nDims();
-//      i++) {
-//   TORCH_INTERNAL_ASSERT(
-//       !(producer_root_influence[i] &&
-//       producer_root->axis(i)->isReduction()), "Error during replay, likely
-//       due to an illegal bad computeAt.");
-// }
-
-// std::vector producer_rfactor_root_influence =
-//     TransformIter::replayInfluence(
-//         TransformIter::getHistory(producer_rfactor_root),
-//         producer_root_influence);
-
-// // We have the influence on the producer root, we need it on consumer, we
-// // want to keep those axes that don't need to be modified by the replay
-// std::vector consumer_root_influence(
-//     consumer->rootDomain()->nDims(), false);
-
-// // Producer -> consumer axis map
-// std::vector replay_axis_map(producer_rfactor_root->nDims(), -1);
-
-// // Setup consumer_root_influence vector on root for replay
-// decltype(consumer_root_influence.size()) ip = 0, ic = 0;
-// while (ic < consumer_root_influence.size() &&
-//        ip < producer_rfactor_root->nDims()) {
-//   bool is_reduction = producer_rfactor_root->axis(ip)->isReduction();
-//   if (is_reduction) {
-//     replay_axis_map[ip++] = -1;
-//     continue;
-//   }
-//   if (producer_rfactor_root_influence[ip] &&
-//       !consumer_root->axis(ic)->isRFactorProduct()) {
-//     replay_axis_map[ip] = ic;
-//   } else {
-//     replay_axis_map[ip] = -1;
-//   }
-//   consumer_root_influence[ic++] = producer_rfactor_root_influence[ip++];
-// }
-
-// // Unlike PasC if last axes in producer_rfactor_root is a reduction we
-// won't
-// // be guarneteed that ip == producer_rfactor_root->nDims(), that's why we
-// // initialize replay_axis_map with -1
-
-// TORCH_INTERNAL_ASSERT(
-//     consumer_root_influence.size() == consumer_root->nDims(),
-//     "Error detected during replay, expected matching sizes of influence map
-//     to root dimensions.");
-
-// std::vector consumer_replay_map(consumer_root->nDims());
-// for (decltype(consumer_replay_map.size()) i{0};
-//      i < consumer_replay_map.size();
-//      i++) {
-//   if (consumer_root->axis(i)->isRFactorProduct()) {
-//     consumer_replay_map[i] = i;
-//   } else {
-//     consumer_replay_map[i] = consumer_root_influence[i] ? -1 : i;
-//   }
-// }
-
-// // Replay axes that won't be modified by transform replay
-// TensorDomain* consumer_replay_root = TransformIter::replaySelf(
-//     consumer, TransformIter::getHistory(consumer), consumer_replay_map);
-
-// // Record axes positions.
-// std::unordered_map new_position;
-// for (decltype(consumer_replay_root->nDims()) i{0};
-//      i < consumer_replay_root->nDims();
-//      i++)
-//   new_position[consumer_replay_root->axis(i)] = i;
-
-// std::unordered_map root_axis_map;
-// // reorder consumer_replay_root to respect replay_axis_map
-// for (decltype(replay_axis_map.size()) i{0}; i < replay_axis_map.size();
-//      i++) {
-//   if (replay_axis_map[i] == -1)
-//     continue;
-//   auto ax = consumer_root->axis(replay_axis_map[i]);
-//   TORCH_INTERNAL_ASSERT(
-//       new_position.find(ax) != new_position.end(),
-//       "Error hit during transform replay, could not find ",
-//       ax,
-//       " expected in root domain.");
-//   root_axis_map[new_position[ax]] = replay_axis_map[i];
-// }
-
-// auto replay_history = TransformIter::getHistory(producer);
-// auto rfactor_history = TransformIter::getHistory(producer_rfactor_root);
-// replay_history.erase(
-//     replay_history.begin(), replay_history.begin() +
-//     rfactor_history.size());
-
-// consumer_replay_root = consumer_replay_root->reorder(root_axis_map);
-// // Finally replay consumer as producer on marked axes
-
-// auto replayed = TransformIter::replay(
-//     consumer_replay_root, replay_history, replay_axis_map);
-
-// return replayed;
-// }
-
 // replay Producer as Consumer
 TensorView* TransformReplay::replayPasC(
     TensorView* producer,
diff --git a/torch/csrc/jit/codegen/cuda/transform_replay_rfactor.h b/torch/csrc/jit/codegen/cuda/transform_replay_rfactor.h
new file mode 100644
index 0000000000000..ccca6175f597d
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/transform_replay_rfactor.h
@@ -0,0 +1,229 @@
+#pragma once
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+namespace torch {
+namespace jit {
+namespace fuser {
+
+namespace {
+
+// Simply grabs all exprs needed to produce provided outputs.
+struct Exprs : public IterVisitor {
+ private:
+  std::vector exprs;
+  void handle(Expr* e) override {
+    exprs.push_back(e);
+  }
+
+ public:
+  static std::vector getFrom(std::vector outputs) {
+    if (outputs.empty())
+      return std::vector();
+
+    Exprs inst;
+    inst.traverseFrom(outputs[0]->fusion(), outputs);
+    return inst.exprs;
+  }
+};
+
+} // namespace
+
+/*
+ * Consider the following program:
+ *
+ * T1[I0, R1] = T0[I0, I1]
+ * T2[I0] = T1[I0, R1i]
+ *
+ * T1->split(1, factor)
+ * T1->rFactor(2)
+ *
+ * T4[I0, R1orf, I1irf] = T0[I0, I1]
+ * T1[I0, R1i] = T4[I0, R1orf, I1irf]
+ * T2[I0] = T1[I0, R1i]
+ *
+ * There's an issue when we call replayCasP on
+ * T4[I0, R1o, I1i] = T0[I0, I1]
+ *
+ * This would try to replay T4 as T0, and it could include the rfactor domains.
+ * For example we compute T0 inline with T4. The way computeAt is set up, this
+ * would call replayPasC(T0, T4, -1) then replayCasP(T4, T0, -1).
+ *
+ * We might assume that the only way we will hit this is if we call
+ * T4->computeAt(T0...) so it might be safe to assume that the right
+ * transformations would be replayed. However, we want to preserve the rfactor
+ * domain, so since it would replay T4 at root, it would produce IterDomains
+ * that wouldn't correspond to those in rfactor. Also, I don't know if this
+ * assumption is correct.
+ *
+ * Therefore, we will assume it is not correct, and we will validate here that
+ * if we replay a domain, it is transformed in a way consistent with any
+ * defined RFactor domains; then we will update the replay map so that RFactor
+ * roots are mapped to intermediate IterDomains in the target and start replay
+ * from there.
+ *
+ * This class will validate/do the above. It will also run through
+ * transformations in target according to root_map. If equal transformations
+ * already exist in replay_domain history, we will not redo those
+ * transformations, but instead update root_map to reflect forwarding the
+ * existing transformations. This latter part is the "best effort" replay,
+ * though we include rfactor replay and validation here.
+ */
+
+struct TORCH_CUDA_API BestEffortReplay {
+  static std::unordered_map replay(
+      const std::vector& replay_domain,
+      const std::vector& target_domain,
+      const std::unordered_map& root_map) {
+    std::vector t_exprs = Exprs::getFrom(
+        std::vector(target_domain.begin(), target_domain.end()));
+
+    // If we check how an IterDomain was generated, it should only use an
+    // IterDomain in an expression once. We pull a map from the input
+    // IterDomains to the expression consuming them that generated the
+    // replay_domain. This will be used to propagate the target_domain to
+    // replay_domain map.
+
+    std::vector r_exprs = Exprs::getFrom(
+        std::vector(replay_domain.begin(), replay_domain.end()));
+    std::unordered_map replay_expr_map;
+    for (auto r_expr : r_exprs)
+      for (auto inp : r_expr->inputs())
+        if (inp->getValType().value() == ValType::IterDomain) {
+          auto id = static_cast(inp);
+          TORCH_INTERNAL_ASSERT(
+              replay_expr_map.find(id) == replay_expr_map.end(),
+              "Error trying to map rfactor root domain during replay. IterDomain's shouldn't have more than one use.");
+          // Only want to forward rfactor in map
+          replay_expr_map[id] = r_expr;
+        }
+
+    // Map we can update as we progress through exprs. We will only maintain
+    // leaf nodes.
+    std::unordered_map updated_id_map(
+        root_map.begin(), root_map.end());
+
+    std::string err_str(
+        "Error during replay, a computeAt was called that conflicts with an rfactor call.");
+
+    for (auto t_expr : t_exprs) {
+      // Going to map the target_domain inputs/outputs to replay_domain
+      // inputs/outputs
+      std::vector r_inps;
+      std::vector t_inps;
+
+      for (auto inp : t_expr->inputs()) {
+        if (inp->getValType() == ValType::IterDomain) {
+          auto t_inp = static_cast(inp);
+          t_inps.push_back(t_inp);
+          // There might not be a mapping; that could be okay.
+          auto it = updated_id_map.find(t_inp);
+          if (it != updated_id_map.end())
+            r_inps.push_back(it->second);
+        }
+      }
+
+      bool has_rfactor =
+          std::any_of(r_inps.begin(), r_inps.end(), [](IterDomain* id) {
+            return id->isRFactorProduct();
+          });
+
+      if (r_inps.size() != t_inps.size() || r_inps.empty()) {
+        // If any replay_domain inputs are an rfactor product, all inputs
+        // should match.
+        TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
+        continue;
+      }
+
+      if (replay_expr_map.find(r_inps[0]) == replay_expr_map.end()) {
+        TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
+        continue;
+      }
+
+      auto r_expr = replay_expr_map[r_inps[0]];
+      bool mismatched_inputs = false;
+      {
+        size_t i = 0;
+        for (auto r_inp : r_expr->inputs()) {
+          // Guard the index before using it; i == r_inps.size() would already
+          // be out of bounds.
+          if (i >= r_inps.size()) {
+            mismatched_inputs = true;
+            break;
+          }
+          mismatched_inputs = mismatched_inputs || r_inp != r_inps[i];
+          i++;
+        }
+      }
+
+      if (mismatched_inputs) {
+        TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
+        continue;
+      }
+
+      if (t_expr->nOutputs() != r_expr->nOutputs()) {
+        TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
+        continue;
+      }
+
+      if (r_expr->getExprType().value() != t_expr->getExprType().value()) {
+        TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
+        continue;
+      }
+
+      // If the expression is a split, make sure it's split by the same amount.
+      if (r_expr->getExprType().value() == ExprType::Split) {
+        // Compare the replay split's factor against the target split's factor.
+        if (!static_cast(r_expr)->factor()->sameAs(
+                static_cast(t_expr)->factor())) {
+          TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
+          continue;
+        }
+      }
+
+      bool missing_input = std::any_of(
+          t_expr->inputs().begin(),
+          t_expr->inputs().end(),
+          [updated_id_map](Val* inp) {
+            if (inp->getValType() == ValType::IterDomain) {
+              return updated_id_map.find(static_cast(inp)) ==
+                  updated_id_map.end();
+            }
+            return false;
+          });
+
+      if (missing_input) {
+        TORCH_INTERNAL_ASSERT(!has_rfactor, err_str);
+        continue;
+      }
+
+      // Take target_domain inputs out of map:
+      for (auto inp : t_expr->inputs()) {
+        if (inp->getValType() == ValType::IterDomain) {
+          auto t_inp = static_cast(inp);
+          auto it = updated_id_map.find(t_inp);
+          updated_id_map.erase(it);
+        }
+      }
+
+      // Add outputs to map.
+      for (size_t i = 0; i < t_expr->nOutputs(); i++) {
+        auto t_out = t_expr->output(i);
+        auto r_out = r_expr->output(i);
+        if (t_out->getValType() == ValType::IterDomain &&
+            r_out->getValType() == ValType::IterDomain) {
+          updated_id_map[static_cast(t_out)] =
+              static_cast(r_out);
+        }
+      }
+    }
+
+    return updated_id_map;
+  }
+};
+
+} // namespace fuser
+} // namespace jit
+} // namespace torch
\ No newline at end of file
diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
index 425568fdd4d71..0b40956e6e3f0 100644
--- a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
+++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp
@@ -1,230 +1,395 @@
 #include
+#include
 #include
 #include
+#include

 namespace torch {
 namespace jit {
 namespace fuser {

-// TensorDomain* TransformRFactor::runReplay(
-//     TensorDomain* orig_td,
-//     std::vector axes) {
-//   int ndims = (int)orig_td->nDims();

-//   // Adjust and check provided axes
-//   std::transform(axes.begin(), axes.end(), axes.begin(), [ndims](int i) {
-//     TORCH_CHECK(
-//         i >= -ndims && i < ndims,
-//         "Rfactor replay recieved an axis outside the number of dims in the
-//         tensor, acceptable inclusive range is ", -ndims, " to ", ndims - 1);
-//     return i < 0 ? i + ndims : i;
-//   });

-//   // remove duplicates, and put into a set for searching
-//   std::set axes_set(axes.begin(), axes.end());

-//   // Make a copy of orig_td as we're going to change its history:
-//   bool found_rfactor = false;
-//   std::vector domain_copy;
-//   for (int i{0}; i < ndims; i++) {
-//     IterDomain* orig_axis = orig_td->axis(i);
-//     if (axes_set.find(i) != axes_set.end())
-//       TORCH_CHECK(
-//           orig_axis->isReduction(),
-//           "Tried to rFactor an axis that is not a reduction.");

-//     if (orig_axis->isReduction()) {
-//       if (axes_set.find(i) == axes_set.end()) {
-//         domain_copy.push_back(new IterDomain(
-//             orig_axis->start(),
-//             orig_axis->extent(),
-//             orig_axis->parallel_method(),
-//             false,
-//             true));
-//         found_rfactor = true;
-//       } else {
-//         domain_copy.push_back(new IterDomain(
-//             orig_axis->start(),
-//             orig_axis->extent(),
-//             orig_axis->parallel_method(),
-//             true,
-//             true));
-//       }
-//     } else {
-//       domain_copy.push_back(orig_td->axis(i));
-//     }
-//   }
-//   TORCH_CHECK(found_rfactor, "Could not find axis to rfactor out.");

-//   // TD that we will actually modify,
-//   TensorDomain* out_td = new TensorDomain(domain_copy);

-//   // Axis map to create history for non-rfactor axes
-//   std::vector axis_map(ndims, -1);
-//   std::vector orig_rfactor_axis_map(ndims, -1);
-//   std::set rfactor_ids;
-//   for (decltype(out_td->nDims()) i{0}; i < out_td->nDims(); i++)
-//     if (!out_td->axis(i)->isRFactorProduct()) {
-//       axis_map[i] = i;
-//     } else {
-//       orig_rfactor_axis_map[i] = i;
-//     }

-//   // Replay non-rfactor axes
-//   auto running_td = TransformIter::replayBackward(
-//       out_td, TransformIter::getHistory(orig_td), axis_map);

-//   // running_td has iteration domains on the right, but to find a valid
-// rfactor
-//   // root, we want those to be on the right. If we continued to replay
-// backward
-//   // we likely won't have a valid rfactor root. Lets manually insert a so we
-//   // have a valid rfactor root.
-
-//   std::vector new2old(running_td->nDims());
-//   {
-//     int running_pos = 0;
-//     for (decltype(running_td->nDims()) i{0}; i < running_td->nDims(); i++)
-//       if (!running_td->axis(i)->isRFactorProduct())
-//         new2old[i] = running_pos++;

-//     for (decltype(running_td->nDims()) i{0}; i < running_td->nDims(); i++)
-//       if (running_td->axis(i)->isRFactorProduct())
-//         new2old[i] = running_pos++;
-//   }

-//   // how do we find axes
-//   // Need axis map from rfactor axes in running_td to corresponding axes in
-//   // orig_td orig_rfactor_axis_map goes from orig_td to out_td we want it to
-//   // go from orig_td to running_td

-//   // Go from IterDomain to its position in running_td
-//   std::unordered_map new_pos;
-//   for (decltype(running_td->nDims()) i{0}; i < running_td->nDims(); i++) {
-//     new_pos[running_td->axis(i)] = i;
-//   }

-//   for (decltype(out_td->nDims()) i{0}; i < out_td->nDims(); i++)
-//     if (orig_rfactor_axis_map[i] != -1) {
-//       // int orig_td_pos = i;
-//       int out_td_pos = orig_rfactor_axis_map[i];
-//       TORCH_INTERNAL_ASSERT(
-//           new_pos.find(out_td->axis(out_td_pos)) != new_pos.end(),
-//           "Error aligning axes in rfactor first TD replay.");
-//       int running_td_pos = new_pos[out_td->axis(out_td_pos)];
-//       orig_rfactor_axis_map[i] = running_td_pos;
-//     }

-//   TransformIter::replayBackward(
-//       running_td, TransformIter::getHistory(orig_td), orig_rfactor_axis_map);

-//   return out_td;
-// }

-// TensorDomain* TransformRFactor::runReplay2(
-//     TensorDomain* in_td,
-//     std::vector axes) {
-//   int ndims = (int)in_td->nDims();

-//   // Adjust and check provided axes
-//   std::transform(axes.begin(), axes.end(), axes.begin(), [ndims](int i) {
-//     TORCH_CHECK(
-//         i >= -ndims && i < ndims,
-//         "Rfactor replay recieved an axis outside the number of dims in the
-//         tensor, acceptable inclusive range is ", -ndims, " to ", ndims - 1);
-//     return i < 0 ? i + ndims : i;
-//   });

-//   // remove duplicates, and put into a set for searching
-//   std::set axes_set(axes.begin(), axes.end());

-//   bool found_rfactor = false;
-//   // Axes marked as rfactor, these will be removed from this domain
-//   std::vector rfactor_axes(in_td->nDims(), false);
-//   for (int i{0}; i < ndims; i++) {
-//     bool in_set = axes_set.find(i) != axes_set.end();
-//     IterDomain* orig_axis = in_td->axis(i);

-//     if (in_set) {
-//       TORCH_CHECK(
-//           orig_axis->isReduction(),
-//           "Tried to rFactor an axis that is not a reduction.");
-//       rfactor_axes[i] = true;
-//       found_rfactor = true;
-//     }
-//   }

-//   TORCH_CHECK(found_rfactor, "Could not find axis to rfactor out.");
-//   auto root_rfactor_axes = TransformIter::getRootInfluence(in_td,
-//   rfactor_axes);

-//   // Root axes involved in rfactor, these axes should not be replayed, they
-// need
-//   // to be either removed completely, or part of the root domain
-//   auto root_dom = TransformIter::getRoot(in_td);
-//   TORCH_INTERNAL_ASSERT(
-//       root_rfactor_axes.size() == root_dom->nDims(),
-//       "Error backpropagating influence of rfactor.");

-//   // Forward propagate influence back to the end we want to mark everything
-//   // that's part of the rfactor
-//   rfactor_axes = TransformIter::replayInfluence(
-//       TransformIter::getHistory(in_td), root_rfactor_axes);

-//   // Axes part of rfactor we need to keep
-//   std::vector rfactor_axes_keep;

-//   for (int i{0}; i < ndims; i++) {
-//     if (rfactor_axes[i] && axes_set.find(i) == axes_set.end()) {
-//       TORCH_INTERNAL_ASSERT(
-//           in_td->axis(i)->isReduction(),
-//           "Error occured when tracking rfactor axes.");
-//       rfactor_axes_keep.push_back(in_td->axis(i));
-//     }
-//   }

-//   int root_ndims = (int)root_dom->nDims();
-//   std::vector domain_copy;
-//   // These are the axes that are not involved in the rfactor.
-//   for (int i{0}; i < root_ndims; i++) {
-//     if (!root_rfactor_axes[i]) {
-//       domain_copy.push_back(root_dom->axis(i));
-//     }
-//   }

-//   TORCH_INTERNAL_ASSERT(
-//       domain_copy.size() < root_dom->nDims(),
-//       "Error during rfactor, didn't get any rfactor axes.");

-//   // Setup axis map before we add back in the rfactor_axes
-//   std::vector replay_axis_map(root_dom->nDims(), -1);
-//   {
-//     decltype(domain_copy.size()) it = 0;
-//     decltype(root_dom->nDims()) ir = 0;
-//     while (it < domain_copy.size() && ir < root_dom->nDims()) {
-//       if (root_rfactor_axes[ir]) {
-//         ir++;
-//       } else {
-//         replay_axis_map[ir++] = it++;
-//       }
-//     }
-//     TORCH_INTERNAL_ASSERT(
-//         it == domain_copy.size(),
-//         "Error during rfactor, missed an unmodified root domain.");
-//   }

-//   // Push back the rfactor axes we need to keep
-//   domain_copy.insert(
-//       domain_copy.end(), rfactor_axes_keep.begin(), rfactor_axes_keep.end());

-//   // TD that we will actually modify
-//   TensorDomain* replay_root_td = new TensorDomain(domain_copy);
-//   auto td = TransformIter::replay(
-//       replay_root_td, TransformIter::getHistory(in_td), replay_axis_map);

-//   return td;
-// }
+namespace {
+
+struct ReplayRFactor : public ReplayTransformations {
+ private:
+  // Took a good bit of this from ReplayTransformations::handle(Split...)
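+  // Replays a split on the mapped input. When neither output is an rfactor
+  // axis and the input is not an rfactor product, this simply defers to the
+  // base class; otherwise the outputs are rebuilt by hand so the reduction
+  // and rfactor flags can be set per output.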
+  void handle(Split* s) override {
+    // Grab input to the split operation
+    auto id_in = s->in();
+    // Grab our mapping of that ID to the one we're replaying
+    auto it = id_map_.find(id_in);
+    // Make sure it exists in the map
+    TORCH_INTERNAL_ASSERT(
+        it != id_map_.end(),
+        "Transform traversal failed, dependencies not met.");
+    // Grab the ID we're going to replay on
+    auto mapped = (*it).second;
+    TORCH_INTERNAL_ASSERT(
+        s->factor()->isConst(),
+        "Transform traversal does not support splitting on non-const values.");
+    // This ID should be a leaf ID (meaning it has no uses we generated)
+    TORCH_INTERNAL_ASSERT(
+        leaf_ids_.find(mapped) != leaf_ids_.end(),
+        "Transform traversal failed, modified a node but it was not a leaf node.");
+
+    // Check if either output of the split is going to be an rfactored axis
+    bool rfactor_outer = false;
+    bool rfactor_inner = false;
+    if (rfactor_axes_.find(s->outer()) != rfactor_axes_.end())
+      rfactor_outer = true;
+
+    if (rfactor_axes_.find(s->inner()) != rfactor_axes_.end())
+      rfactor_inner = true;
+
+    bool rfactor_input = mapped->isRFactorProduct();
+
+    // If nothing is going to be rfactored, replay a normal split
+    if (!rfactor_inner && !rfactor_outer && !rfactor_input)
+      return ReplayTransformations::handle(s);
+
+    // outer loop size
+    Val* oe = ceilDiv(mapped->extent(), s->factor());
+
+    // Manually replay the split, marking both outputs as rfactor products; an
+    // output stays a reduction only if it is one of the rfactor axes
+    // outer IterDomain
+    IterDomain* ido = new IterDomain(
+        new Int(0),
+        static_cast(oe),
+        mapped->parallel_method(),
+        rfactor_outer,
+        true,
+        mapped->isBroadcast());
+
+    // inner IterDomain
+    IterDomain* idi = new IterDomain(
+        new Int(0),
+        s->factor(),
+        mapped->parallel_method(),
+        rfactor_inner,
+        true,
+        mapped->isBroadcast());
+
+    // Generate the split node
+    new Split(ido, idi, mapped, s->factor());
+
+    // Remove mapped id from leaf IDs
+    leaf_ids_.erase(mapped);
+    // Add outputs to leaf IDs
+    leaf_ids_[ido] = counter++;
+    leaf_ids_[idi] = counter++;
+
+    // Update our ID map to include these outputs
+    id_map_[s->outer()] = ido;
+    id_map_[s->inner()] = idi;
+  }
+
+  void handle(Merge* m) override {
+    auto id_outer = m->outer();
+    auto id_inner = m->inner();
+    auto it_outer = id_map_.find(id_outer);
+    auto it_inner = id_map_.find(id_inner);
+    TORCH_INTERNAL_ASSERT(
+        it_outer != id_map_.end() && it_inner != id_map_.end(),
+        "Transform traversal failed, dependencies not met.");
+
+    auto id_outer_mapped = (*it_outer).second;
+    auto id_inner_mapped = (*it_inner).second;
+
+    TORCH_INTERNAL_ASSERT(
+        leaf_ids_.find(id_outer_mapped) != leaf_ids_.end() &&
+            leaf_ids_.find(id_inner_mapped) != leaf_ids_.end(),
+        "Transform traversal failed, modified ",
+        id_outer_mapped,
+        " and ",
+        id_inner_mapped,
+        " however one or both are not leaf nodes.");
+
+    bool rfactor_output = false;
+    if (rfactor_axes_.find(m->out()) != rfactor_axes_.end())
+      rfactor_output = true;
+
+    bool rfactor_input = id_inner_mapped->isRFactorProduct() ||
+        id_outer_mapped->isRFactorProduct();
+
+    if (!rfactor_output && !rfactor_input)
+      return ReplayTransformations::handle(m);
+
+    Val* merged_id_size =
+        mul(id_outer_mapped->extent(), id_inner_mapped->extent());
+    IterDomain* merged_id = new IterDomain(
+        new Int(0),
+        static_cast(merged_id_size),
+        id_outer_mapped->parallel_method(),
+        rfactor_output,
+        true,
+        id_outer_mapped->isBroadcast() && id_inner_mapped->isBroadcast());
+
+    new Merge(merged_id, id_outer_mapped, id_inner_mapped);
+
+    // Remove inputs from the leaf IDs
+    leaf_ids_.erase(id_outer_mapped);
+    leaf_ids_.erase(id_inner_mapped);
+
+    // Add the output to the leaf IDs
+    leaf_ids_[merged_id] = counter++;
+
+    id_map_[m->out()] = merged_id;
+  }
+
+  std::unordered_set rfactor_axes_;
+
+ public:
+  ReplayRFactor(
+      const std::vector& _target_domain,
+      std::unordered_map _id_map,
+      std::unordered_set _rfactor_axes)
+      : ReplayTransformations(_target_domain, _id_map, false),
+        rfactor_axes_(std::move(_rfactor_axes)) {}
+};
+
+} // namespace
+
+// Take any axes not provided, that are reductions, and convert them to
+// iteration axes. Any axes that share inputs to the axes provided should be
+// marked as rfactorProduct.
+TensorDomain* TransformRFactor::runReplay(
+    TensorDomain* orig_td,
+    std::vector axes) {
+  TORCH_CHECK(!axes.empty(), "No axes provided to rfactor replay.");
+
+  int ndims = (int)orig_td->nDims();
+
+  // Adjust and check provided axes
+  std::transform(axes.begin(), axes.end(), axes.begin(), [ndims](int i) {
+    TORCH_CHECK(
+        i >= -ndims && i < ndims,
+        "Rfactor replay received an axis outside the number of dims in the tensor, acceptable inclusive range is ",
+        -ndims,
+        " to ",
+        ndims - 1);
+    return i < 0 ? i + ndims : i;
+  });
+
+  // remove duplicates, and put into a set for searching
+  std::unordered_set axes_set(axes.begin(), axes.end());
+
+  TORCH_INTERNAL_ASSERT(
+      std::all_of(
+          axes_set.begin(),
+          axes_set.end(),
+          [orig_td](int i) { return orig_td->axis(i)->isReduction(); }),
+      "Cannot rfactor axes that are not reduction axes.");
+
+  // RFactor requires at least one reduction axis to be marked as factored out,
+  // and at least one reduction axis that won't be. Otherwise it's just a
+  // pointwise caching operation.
+  bool found_non_rfactor_reduction = false;
+
+  // Make a set of final axes that are marked to be rfactored
+  std::unordered_set rfactor_axes(axes_set.size());
+  {
+    size_t i = 0;
+    for (auto id : orig_td->domain()) {
+      if (axes_set.find(i++) != axes_set.end())
+        rfactor_axes.emplace(id);
+      // Only reductions outside the rfactor set count toward the check below
+      else if (id->isReduction())
+        found_non_rfactor_reduction = true;
+    }
+  }
+
+  TORCH_CHECK(
+      found_non_rfactor_reduction,
+      "Must have at least one reduction axis not marked as rfactor.");
+
+  // Get root IterDomains of the rfactor domains, these will be the ones we will
+  // replay marked as rfactor axes, those marked in the axes set will be
+  // reduction=false
+
+  auto rfactor_root_vals = IterVisitor::getInputsTo(
+      std::vector(rfactor_axes.begin(), rfactor_axes.end()));
+
+  // Make sure they're all IterDomains.
+  TORCH_INTERNAL_ASSERT(
+      std::all_of(
+          rfactor_root_vals.begin(),
+          rfactor_root_vals.end(),
+          [](Val* v) {
+            return v->getValType().value() == ValType::IterDomain;
+          }),
+      "Found invalid input domain axes.");
+
+  // Put in a set to make searching easy
+  std::unordered_set rfactor_root_axes;
+  std::transform(
+      rfactor_root_vals.begin(),
+      rfactor_root_vals.end(),
+      std::inserter(rfactor_root_axes, rfactor_root_axes.end()),
+      [](Val* val) {
+        TORCH_INTERNAL_ASSERT(
+            val->getValType().value() == ValType::IterDomain,
+            "Invalid value type found in rfactor axes inputs.");
+        return static_cast(val);
+      });
+
+  auto orig_td_root = orig_td->rootDomain();
+
+  // Generate a new TensorDomain and set up map from one root to this one.
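+  // Rfactor roots stay reductions here, other reduction roots are converted
+  // to iteration domains, and everything else is cloned as-is.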
+  size_t rfactor_axes_found = 0;
+  std::vector new_root(orig_td_root.size(), nullptr);
+  std::unordered_map replay_map;
+
+  {
+    size_t i = 0;
+    for (auto id : orig_td_root) {
+      // If this is an rfactor root, it will be a reduction in this stage
+      if (rfactor_root_axes.find(id) != rfactor_root_axes.end()) {
+        new_root[i] = new IterDomain(
+            id->start(),
+            id->extent(),
+            id->parallel_method(),
+            true,
+            true,
+            false);
+        // If this is not an rfactor root, but a reduction root, it should be
+        // turned into an iteration domain
+      } else if (id->isReduction()) {
+        new_root[i] = new IterDomain(
+            id->start(),
+            id->extent(),
+            id->parallel_method(),
+            false,
+            true,
+            false);
+      } else {
+        new_root[i] = id->clone();
+      }
+      replay_map[id] = new_root[i++];
+    }
+  }
+
+  // Replay the original domain's transformations, marking rfactor axes as we go.
+  ReplayRFactor replay_rfactor(orig_td->domain(), replay_map, rfactor_axes);
+
+  std::unordered_map replayed =
+      replay_rfactor.getReplay();
+
+  std::vector new_domain(orig_td->nDims(), nullptr);
+  {
+    size_t i = 0;
+    for (auto id : orig_td->domain()) {
+      TORCH_INTERNAL_ASSERT(
+          replayed.find(id) != replayed.end(),
+          "Error during rfactor replay, missing an axis.");
+      new_domain[i++] = replayed[id];
+    }
+  }
+
+  // We need a root to match up with the consumer of this domain; it should
+  // hold the rfactor axes post-transformation, and the remaining axes
+  // pre-transformation.
+  std::vector rfactor_root;
+  for (auto dom : new_root)
+    if (!dom->isRFactorProduct())
+      rfactor_root.push_back(dom);
+
+  for (auto dom : new_domain)
+    if (dom->isRFactorProduct())
+      rfactor_root.push_back(dom);
+
+  return new TensorDomain(new_root, rfactor_root, new_domain);
+}
+
+// We want to take any axes marked in axes and remove them from the
+// TensorDomain completely; any other reduction axes found should remain.
+TensorDomain* TransformRFactor::runReplay2(
+    TensorDomain* orig_td,
+    std::vector axes) {
+  int ndims = (int)orig_td->nDims();
+
+  // Adjust and check provided axes
+  std::transform(axes.begin(), axes.end(), axes.begin(), [ndims](int i) {
+    TORCH_CHECK(
+        i >= -ndims && i < ndims,
+        "Rfactor replay received an axis outside the number of dims in the tensor, acceptable inclusive range is ",
+        -ndims,
+        " to ",
+        ndims - 1);
+    return i < 0 ? i + ndims : i;
+  });
+
+  // remove duplicates, and put into a set for searching
+  std::set axes_set(axes.begin(), axes.end());
+
+  // Grab the axes in the rfactor, these were converted to iter domains in the
+  // producer of this domain, and will be reduced in this domain
+  std::unordered_set rfactor_axes(axes_set.size());
+  {
+    size_t i = 0;
+    for (auto id : orig_td->domain()) {
+      if (axes_set.find(i++) != axes_set.end())
+        rfactor_axes.emplace(id);
+    }
+  }
+
+  auto rfactor_root_vals = IterVisitor::getInputsTo(
+      std::vector(rfactor_axes.begin(), rfactor_axes.end()));
+
+  // Make sure they're all IterDomains.
+  TORCH_INTERNAL_ASSERT(
+      std::all_of(
+          rfactor_root_vals.begin(),
+          rfactor_root_vals.end(),
+          [](Val* v) {
+            return v->getValType().value() == ValType::IterDomain;
+          }),
+      "Found invalid input domain axes.");
+
+  // Put in a set to make searching easy
+  std::unordered_set rfactor_root_axes;
+  std::transform(
+      rfactor_root_vals.begin(),
+      rfactor_root_vals.end(),
+      std::inserter(rfactor_root_axes, rfactor_root_axes.end()),
+      [](Val* val) {
+        TORCH_INTERNAL_ASSERT(
+            val->getValType().value() == ValType::IterDomain,
+            "Invalid value type found in rfactor axes inputs.");
+        return static_cast(val);
+      });
+
+  // Replay all other root domains that are iter domains, as these will match in
+  // the domain we're creating
+  std::vector new_root;
+  std::unordered_map replay_root_map;
+  for (auto id : orig_td->rootDomain()) {
+    if (rfactor_root_axes.find(id) == rfactor_root_axes.end()) {
+      new_root.push_back(id->clone());
+      replay_root_map[id] = new_root.back();
+    }
+  }
+
+  ReplayTransformations rt(orig_td->domain(), replay_root_map, false);
+  auto replayed = rt.getReplay();
+
+  std::vector new_domain;
+
+  {
+    // Construct the new domain, and append rfactor axes to the new root domain
+    size_t i = 0;
+    for (auto id : orig_td->domain()) {
+      if (replayed.find(id) != replayed.end()) {
+        new_domain.push_back(replayed[id]);
+      } else if (axes_set.find(i) == axes_set.end()) {
+        IterDomain* new_id = id->clone();
+        new_domain.push_back(new_id);
+        new_root.push_back(new_id);
+      }
+      i++;
+    }
+  }
+
+  return new TensorDomain(new_root, new_domain);
+}

 } // namespace fuser
 } // namespace jit
diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.h b/torch/csrc/jit/codegen/cuda/transform_rfactor.h
index 64e3e6b8b0152..9fc2e15d39dba 100644
--- a/torch/csrc/jit/codegen/cuda/transform_rfactor.h
+++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.h
@@ -18,12 +18,9 @@ struct TORCH_CUDA_API TransformRFactor {
  public:
   // Create a copy of td, change its history by preserving axes so they appear
   // in the root domain
-  static TensorDomain* runReplay(TensorDomain*, std::vector axes) {
-    TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-  }
-  static TensorDomain* runReplay2(TensorDomain*, std::vector axes) {
-    TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
-  }
+  static TensorDomain* runReplay(TensorDomain*, std::vector axes);
+
+  static TensorDomain* runReplay2(TensorDomain*, std::vector axes);
 };

 } // namespace fuser
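To make the two rfactor replays concrete, here is a small self-contained C++ sketch of the axis bookkeeping they implement. The Axis struct and splitReduction are illustrative stand-ins, not part of this patch or of the fuser API; the first returned domain plays runReplay's role and the second runReplay2's, mirroring the T4/T1 example from transform_replay_rfactor.h:

#include <iostream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Toy stand-in for IterDomain: a name plus the two flags that matter here.
struct Axis {
  std::string name;
  bool reduction = false;
  bool rfactor = false;
};

// First domain (runReplay's role): axes picked for rfactor stay reductions,
// every other reduction becomes an iteration axis, and both get the rfactor
// mark. Second domain (runReplay2's role): picked axes disappear entirely,
// and the remaining reductions are reduced here.
std::pair<std::vector<Axis>, std::vector<Axis>> splitReduction(
    const std::vector<Axis>& domain,
    const std::unordered_set<int>& axes) {
  std::vector<Axis> first, second;
  for (int i = 0; i < (int)domain.size(); i++) {
    const Axis& a = domain[i];
    if (axes.count(i)) {
      first.push_back({a.name, true, true}); // reduced by the first kernel
      continue; // gone from the second domain
    }
    if (a.reduction) {
      first.push_back({a.name, false, true}); // iterated now, reduced later
      second.push_back(a); // reduced by the second kernel
    } else {
      first.push_back(a);
      second.push_back(a);
    }
  }
  return {first, second};
}

int main() {
  // T1[I0, R1o, R1i] with axis 1 (R1o) chosen for rfactor, mirroring the
  // example in transform_replay_rfactor.h after T1->split(1, factor).
  auto domains = splitReduction(
      {{"I0", false}, {"R1o", true}, {"R1i", true}}, {1});
  auto print = [](const std::vector<Axis>& dom) {
    for (const Axis& a : dom)
      std::cout << a.name << (a.reduction ? "[r]" : "")
                << (a.rfactor ? "[rf]" : "") << ' ';
    std::cout << '\n';
  };
  print(domains.first); // I0 R1o[r][rf] R1i[rf]
  print(domains.second); // I0 R1i[r]
}

Note how neither replay invents new transformations: the split between the two kernels is decided entirely by which reduction axes were named, which is why both runReplay and runReplay2 only re-mark or drop axes before replaying the recorded history.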