From e5e4f342ebcf95815a5f03cc96329534170ff048 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Tue, 11 Oct 2022 08:28:04 -0400 Subject: [PATCH 1/2] Support more complex view patterns in pointwise ops. --- .../csrc/jit/codegen/cuda/compute_at_map.cpp | 159 +++++++++++++++++- torch/csrc/jit/codegen/cuda/compute_at_map.h | 58 ++++++- .../jit/codegen/cuda/scheduler/pointwise.cpp | 33 +--- .../jit/codegen/cuda/scheduler/registry.cpp | 138 ++++++++++++++- .../csrc/jit/codegen/cuda/scheduler/utils.cpp | 84 ++++++++- torch/csrc/jit/codegen/cuda/scheduler/utils.h | 5 + 6 files changed, 438 insertions(+), 39 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index 7f3de6687eb3a..5c2e98da55c61 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -165,6 +165,14 @@ bool IterDomainGraph::exprsMap( return true; } +// Given first and second Exprs "match" +// Expr type matches +// IterDomain's in the inputs and outputs exact match, (including argument +// position positions) +// Paramters like Split's factor "match" (exact match on integers could be +// better, as today it will just check it's the same symbol or evaluated to +// the same constant. However, we know all the extents of all the +// IterDomain's that exact map with eachother are the same value. void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { if (first == nullptr || second == nullptr) { return; @@ -194,7 +202,37 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { namespace { -// Returns a pair of mapped IDs +// Returns the first pair of id's in ids detected to match eachother on the +// permissive map of the ID graph. TODO: what this is really looking for is if +// there's any overlapping between the iter domains in the provided set. +// +// i.e. if we have: +// tv0 = arange(6).view({3, 2}) +// tv1 = tv0[3, 2].t() +// tv2 = tv0[3, 2].view({2, 3}) +// tv3 = tv1 + tv2 +// +// Then we can see this overlap in the tv3 expression as: +// +// tv0 = { {0, 1, 2}, +// {3, 4, 5} } +// +// tv1 = { {0, 3}, +// {1, 4}, +// {2, 5} } +// +// tv2 = { {0, 1}, +// {2, 3}, +// {4, 5} } +// +// The elements in tv1 {3, 1, 4, 2}, map respectively to the elements in tv2 {1, +// 2, 3, 4}. The reason this is so important is it means that generating tv3 is +// no longer a trivially parallelizable problem (if we include the dag all the +// way to tv0). So tv0's axes cannot be inlined across both the tv0 and tv1 +// path. This breaks some assumptions we have today in schedulers that will +// assume tv2 can be trivially inlined/parallelized. Instead we'd need to take +// into consideration the effective communication going on here, so that we pull +// multiple values of tv0 to compute tv3. 
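// A sketch of how the detection below proceeds (pseudo-code; the exact
// permissive-map query is an assumption for illustration, not quoted from
// this change):
//
//   for (auto i : c10::irange(ids.size())) {
//     for (auto j : c10::irange(i + 1, ids.size())) {
//       if (/* ids[i] and ids[j] belong to the same PERMISSIVE disjoint set
//              of id_graph */) {
//         return std::make_pair(ids[i], ids[j]);
//       }
//     }
//   }
//   return c10::nullopt;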
c10::optional> detectMappablePair( const std::vector& ids, const IterDomainGraph& id_graph) { @@ -637,6 +675,7 @@ ComputeAtMap::ComputeAtMap(Fusion* fusion) void ComputeAtMap::build(Fusion* fusion) { trivial_reduction_info_.build(fusion); buildConcreteIds(); + buildUniqueExactExprMaps(); } void ComputeAtMap::validateAndPropagatePType() { @@ -1052,6 +1091,124 @@ void ComputeAtMap::buildConcreteIds() { } } +bool ComputeAtMap::areExactExprs(Expr* expr_1, Expr* expr_2) { + if (expr_1->getExprType() != expr_2->getExprType()) { + return false; + } + + if (expr_1->isA()) { + auto swizzle_1 = expr_1->as(); + auto swizzle_2 = expr_2->as(); + if (swizzle_1->swizzleType() != swizzle_2->swizzleType() || + swizzle_1->swizzleMode() != swizzle_2->swizzleMode()) { + return false; + } + } + + TORCH_INTERNAL_ASSERT( + expr_1->inputs().size() == expr_2->inputs().size() && + expr_1->outputs().size() == expr_2->outputs().size(), + "Expr traversal doesn't support variable number of inputs and outputs."); + + for (auto input_i : c10::irange(expr_1->inputs().size())) { + if (expr_1->inputs()[input_i]->isA() && + !areMapped( + expr_1->inputs()[input_i]->as(), + expr_2->inputs()[input_i]->as(), + IdMappingMode::EXACT)) { + // Inputs don't exact map in the right order + return false; + } + } + + for (auto output_i : c10::irange(expr_1->outputs().size())) { + if (expr_1->outputs()[output_i]->isA() && + !areMapped( + expr_1->outputs()[output_i]->as(), + expr_2->outputs()[output_i]->as(), + IdMappingMode::EXACT)) { + // Outputs don't exact map in the right order + return false; + } + } + // Expr's are almost exact mapped transforms + return true; +} + +void ComputeAtMap::buildUniqueExactExprMaps() { + // Start by building definitions + for (const auto& disjoint_set_shared_ptr : + id_graph_.exactNodes().disjointSets()) { + std::vector definitions; + + // N^2 in number of unique transformations, this might be better to do + // when generating the map. 
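    // For example, if two tensor views were reshaped by equivalent splits:
    //
    //   tv1: root [I0, I1] -> rfactor [I0, I1/4, 4]   (Split I1 by 4)
    //   tv2: root [J0, J1] -> rfactor [J0, J1/4, 4]   (Split J1 by 4, with J1
    //                                                  exact mapped to I1)
    //
    // the two Split expressions exact map on inputs, outputs, and factor, so
    // only one of them is recorded as a unique definition of the
    // {I1/4, J1/4} (and {4, 4}) exact sets.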
+ IterDomain* concrete_id = nullptr; + for (auto id : disjoint_set_shared_ptr->vector()) { + if (concrete_id == nullptr) { + concrete_id = getConcreteMappedID(id, IdMappingMode::EXACT); + } + + if (id->definition() != nullptr) { + bool match = false; + for (auto recorded_def : definitions) { + if (areExactExprs(id->definition(), recorded_def)) { + match = true; + break; + } + } + if (!match) { + definitions.push_back(id->definition()); + } + } + } + unique_exact_definitions_[disjoint_set_shared_ptr] = definitions; + } + + // Use definitions to build uses + for (const auto& disjoint_set_shared_ptr : + id_graph_.exactNodes().disjointSets()) { + auto definition_it = + unique_exact_definitions_.find(disjoint_set_shared_ptr); + + if (definition_it == unique_exact_definitions_.end()) { + continue; + } + + const auto& definitions = definition_it->second; + + for (auto definition : definitions) { + auto inp_ids = ir_utils::filterByType(definition->inputs()); + for (auto inp : inp_ids) { + auto inp_disjoint_set_shared_ptr = + disjointSetOf(inp, IdMappingMode::EXACT); + // Initialize uses entry + if (unique_exact_uses_.find(inp_disjoint_set_shared_ptr) == + unique_exact_uses_.end()) { + unique_exact_uses_[inp_disjoint_set_shared_ptr] = {}; + } + + auto& uses = unique_exact_uses_.at(inp_disjoint_set_shared_ptr); + + bool already_added = false; + for (auto other_use : uses) { + if (areExactExprs(definition, other_use)) { + already_added = true; + break; + } + } + if (already_added) { + continue; + } + + if (!already_added) { + uses.push_back(definition); + } + } + } + } +} + IterDomain* ComputeAtMap::getConcreteMappedID( IterDomain* id, IdMappingMode mode) const { diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h index 5ea92dff16447..c16fba9474497 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.h +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.h @@ -177,6 +177,31 @@ class TORCH_CUDA_CU_API ComputeAtMap { //! guarenteed to return iter domains in the same disjoint set. IterDomain* getConcreteMappedID(IterDomain* id, IdMappingMode mode) const; + //! Returns a list of expressions that produce the iter domains of all exact + //! mapped id's to 'id'. Expressions that are the same exact transformations + //! are deduplicated in the returned expressions. + std::vector uniqueExactDefinitions(IterDomain* id) const { + auto disjoint_set = disjointSetOf(id, IdMappingMode::EXACT); + auto unique_exact_definition_it = + unique_exact_definitions_.find(disjoint_set); + if (unique_exact_definition_it == unique_exact_definitions_.end()) { + return {}; + } + return unique_exact_definition_it->second; + } + + //! Returns a list of expressions that *use* the iter domains of all exact + //! mapped id's to 'id'. Expressions that are the same exact transformations + //! are deduplicated in the returned expressions. 
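  //! For example, given tv1 = view(tv0, {I0, I1}, {I0*I1}), the Merge that
  //! produces I0*I1 appears (once) in uniqueExactUses(I0) and in
  //! uniqueExactUses(I1), and the same Merge appears in
  //! uniqueExactDefinitions(I0*I1), even if several exactly mapped copies of
  //! that Merge exist elsewhere in the fusion.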
+ std::vector uniqueExactUses(IterDomain* id) const { + auto disjoint_set = disjointSetOf(id, IdMappingMode::EXACT); + auto unique_exact_use_it = unique_exact_uses_.find(disjoint_set); + if (unique_exact_use_it == unique_exact_uses_.end()) { + return {}; + } + return unique_exact_use_it->second; + } + // Prints mapping information, forwards to an internal IterDomainGraph std::string toString() const; @@ -211,6 +236,16 @@ class TORCH_CUDA_CU_API ComputeAtMap { DoubleBufferLoopStage double_buffer_loop_stage = DoubleBufferLoopStage::NotApplicable) const; + // Returns if expr_1 and expr_2 have exact mapped IterDomains in + // inputs/outputs (order matters) and if the expressions have matching + // parameters. + bool areExactExprs(Expr* expr_1, Expr* expr_2); + + // Produce the disjoint set containing provided id with mapping mode. + const std::shared_ptr>& disjointSetOf( + IterDomain* id, + IdMappingMode mode) const; + private: // Build id_graph_ void build(Fusion* fusion); @@ -220,10 +255,8 @@ class TORCH_CUDA_CU_API ComputeAtMap { IterDomain* computeConcreteId(IterDomain* id, IdMappingMode mode); void buildConcreteIds(); - // Produce the disjoint set containing provided id with mapping mode. - const std::shared_ptr>& disjointSetOf( - IterDomain* id, - IdMappingMode mode) const; + // Relies on concrete_id_cache_, buildConcreteIds() must be run before this. + void buildUniqueExactExprMaps(); // Should be built once and never modified again. IterDomainGraph id_graph_; @@ -239,6 +272,23 @@ class TORCH_CUDA_CU_API ComputeAtMap { IterDomain*> concrete_id_cache_; + // Unique expressions operating on exact disjoint set. For each IterDomain in + // each exact disjoint set will log its definition in the std::vector. + // If another expression is already in the set where inputs and outputs + // exactly match with the expression to add along with the other parameters of + // the transformation (like split's factor, or swizzles types) then the + // expression will not be added as it would be an "duplicate" transformation. + std::unordered_map< + std::shared_ptr>, + std::vector> + unique_exact_definitions_; + + // Same as unique_exact_definitions_ but for uses instead of definitions + std::unordered_map< + std::shared_ptr>, + std::vector> + unique_exact_uses_; + //! Allocated Loop index variable through the CA map. //! only valid for disjoint sets on the loop ca map. std::unordered_map*, Val*> diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index b40e6fbf7cf7a..481181b554c1f 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -485,36 +485,9 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { int rhs_i = -1; int lhs_i = -1; - auto view_ops = ir_utils::getViewOps(fusion); - - /* - * If there's no path from reference through producer paths only to a view, - * e.g.: input - * / \ - * view reference - * / - * output - * - * we need to propagate the view transformations to the reference tv before - * scheduling the reference tv. Since view ops have to be identical, if any - * path from reference tv through producers goes through a view, all paths - * from reference tv's to views should be through producers. 
- */ - bool needs_view_prop = - view_ops.size() > 0 && - !std::any_of( - view_ops.begin(), view_ops.end(), [&reference_tv](ViewOp* view) { - return DependencyCheck::isDependencyOf(view->out(), reference_tv) || - view->out()->sameAs(reference_tv); - }); - - if (needs_view_prop) { - auto first_view_op = *view_ops.begin(); - - // Propagate the view transformations - TransformPropagator propagator(first_view_op->out()); - MaxRootDomainInfoSpanningTree spanning_tree(first_view_op->out()); - spanning_tree.traverse(&propagator); + if (ir_utils::getViewOps(fusion).size() > 0) { + ComputeAtMap ca_map(fusion); + scheduler_utils::propagateViewTransforms(fusion, ca_map); // Reorder reference_tv after propagating the view operation. This will // reorder for better merging. diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp index 644a4d9e7ebd8..85aec46190de4 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp @@ -446,6 +446,139 @@ bool isConnectedFusionGraph(Fusion* fusion) { return true; } +// Returns if a fusion cannot transformed into a consistent format since we +// can't transform forward through view operations, for exmaple: +// +// tv0[I0, I1, I2] +// tv1[I0*I1, I2] = view(tv0) +// tv2[I0, I1*I2] = view(tv0) +// +// If we start transform propagation at either tv1 or tv2, it would require +// "replaying forward" through the other. If we started at tv1 we'd have to be +// able to take tv2[I0, I1*I2] and transform it to [I0*I1, I2], however this +// would "undo" the view transformation which we do not support today. +// +// Returns true if a scenario like above is found in the fusion. +bool requiresForwardViewReplay(Fusion* fusion, ComputeAtMap& ca_map) { + // Track the uses of the rfactor domains in the fusion. If an rfactor domain + // is used in more than one way it means the above situation is being + // encountered. + // + // tv1 root: [I0rf, I1rf, I2] -> rfactor [I0*I1rf, I2] + // tv1 root: [I0, I1rf, I2rf] -> rfactor [I0, I1*I2rf] + // + // Here we can see I1rf is used in two view transformations, one to I0*I1rf, + // and the other to I1*I2rf. + + // Track the transformation each exact disjoint rfactor set is used in. If + // more than one is detected we can't support transforming the fusion into a + // consistent format. + std::unordered_map>, Expr*> + unique_exact_uses_; + + // Don't check compute uses directly, as IterDomain->uses() isn't protected + // from going outside the TensorViews between registered inputs and outputs of + // the fusion. If there are view operations defined in the fusion container + // (because of how segmentation works) but not between registered input and + // outputs, that could be picked up as inconsistent view transformations. + // + // It would be unlikely this would be picked up as a conflict as we check + // which definitions were registered in the compute at map for matching + // transformations. However, we may want to support scheduling after + // transformations which could map to those views not on the input->output + // path. + + // Look through all definitions associated with producing rfactor outputs. + // Mark those as an active use of the rfactor, if two are detected, return + // true. + for (const auto& disjoint_set_shared_ptr : + ca_map.idGraph().exactNodes().disjointSets()) { + // Make sure there's at least one rfactor domain in the set, otherwise we + // don't need to check anything from this set. 
+ if (!std::any_of( + disjoint_set_shared_ptr->vector().begin(), + disjoint_set_shared_ptr->vector().end(), + [](IterDomain* id) { return id->isRFactorProduct(); })) { + continue; + } + + // Grab all the unique definitions detected to consume the iter domains in + // this set + auto unique_defs = + ca_map.uniqueExactDefinitions(disjoint_set_shared_ptr->back()); + + // Iterate through the all the rfactor iter domains + for (auto id_rfactor_product : disjoint_set_shared_ptr->vector()) { + if (!id_rfactor_product->isRFactorProduct()) { + continue; + } + + // Grab the rfactor definition + auto rfactor_def = id_rfactor_product->definition(); + + if (rfactor_def == nullptr) { + // Guard segfault if there isn't a definition for this iter domain + continue; + } + + // If one output of the expression is an rfactor ID all of them should be + auto def_outs = + ir_utils::filterByType(rfactor_def->outputs()); + TORCH_INTERNAL_ASSERT( + std::all_of( + def_outs.begin(), + def_outs.end(), + [](IterDomain* id) { return id->isRFactorProduct(); }), + "This function does not support outputs of transformations with mismatching rfactor flags. ", + "If one output is rfactor all should be rfactor."); + + // There could be a transformation where the inputs are rfactor + // dimensions, but outputs are not, just ignore those expressions as + // they're scheduling based after of the view definitions. + auto def_inps = ir_utils::filterByType(rfactor_def->inputs()); + if (!std::all_of(def_inps.begin(), def_inps.end(), [](IterDomain* id) { + return id->isRFactorProduct(); + })) { + continue; + } + + // Check which definition in the unique exact definition set this + // definition matches to: + for (auto unique_def : unique_defs) { + if (ca_map.areExactExprs(rfactor_def, unique_def)) { + // Check if we already have an expression that consumes an + // equivalent of any of the input rfactor domains. If so and it's + // not the already registered transformation, return false + for (auto inp : def_inps) { + auto inp_disjoint_set = + ca_map.disjointSetOf(inp, IdMappingMode::EXACT); + // Initialize the use entry for this set (if it doesn't already + // exist) + if (unique_exact_uses_.find(inp_disjoint_set) == + unique_exact_uses_.end()) { + unique_exact_uses_[inp_disjoint_set] = nullptr; + } + + if (unique_exact_uses_.at(inp_disjoint_set) == nullptr) { + // If expression is null pointer register this unique_def + unique_exact_uses_[inp_disjoint_set] = unique_def; + } else if (!ca_map.areExactExprs( + unique_exact_uses_[inp_disjoint_set], unique_def)) { + // Two transformations that don't match on matching rfactor + // domains found, return true. + return true; + } + } + // Expression already mapped, stop trying to match expressions + break; + } + } + } + } + // No inconsistent rfactor uses found, we can safely transform this graph. 
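  // As a concrete illustration (mirroring the example at the top of this
  // function; helper names follow the test utilities and the sizes are made
  // up for the sketch):
  //
  //   auto tv0 = makeConcreteTensor({2, 3, 4});   // [I0, I1, I2]
  //   auto tv1 = view(tv0, {2, 3, 4}, {6, 4});    // merges I0 with I1
  //   auto tv2 = view(tv0, {2, 3, 4}, {2, 12});   // merges I1 with I2
  //
  // I1's exact set ends up with two non-matching uses (a Merge with I0 and a
  // Merge with I2), so the loop above returns true. If tv2 instead used the
  // same {6, 4} view as tv1, the two Merges would exact map and the fusion
  // would still be considered transformable.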
+ return false; +} + } // namespace void SchedulerRuntimeInfo::initialize( @@ -1226,9 +1359,8 @@ class PointWiseScheduler : public SchedulerEntry { return false; } - if (!scheduler_utils::allMatchingViews(fusion) && - SchedulerTopologyChecker::hasViewNotBeforeRef( - fusion, {getReferenceTensorView(fusion)})) { + ComputeAtMap ca_map(fusion); + if (requiresForwardViewReplay(fusion, ca_map)) { scheduler_debug_utils::canScheduleRejectReason( ScheduleHeuristic::PointWise, "Unsupported view fusion."); return false; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp index d985da926354b..91552818cf697 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp @@ -2,7 +2,6 @@ #include #include -#include #include #include #include @@ -2206,6 +2205,7 @@ DisjointSets disjointViewSets(Fusion* fusion) { return disjoint_view_ids; } +// TODO: Remove bool allMatchingViews(Fusion* fusion) { // Start from the exact iter domain graph of the fusion IterDomainGraph id_graph(fusion); @@ -2391,6 +2391,88 @@ std::unordered_map domainReorderAsRfactorMap(TensorView* tv) { return old2new; } +void propagateViewTransforms(Fusion* fusion, const ComputeAtMap& ca_map) { + std::unordered_set>> + transformed_disjoint_sets; + + // If iter domains are involved in any transformation from root domains to + // rfactor domains they should be considered "contaminated". + for (auto tv : ir_utils::allTvs(fusion)) { + for (auto expr : StmtSort::getExprsBetween( + fusion, + {tv->getRootDomain().begin(), tv->getRootDomain().end()}, + {tv->getMaybeRFactorDomain().begin(), + tv->getMaybeRFactorDomain().end()})) { + for (auto id : ir_utils::filterByType(expr->inputs())) { + transformed_disjoint_sets.emplace( + ca_map.disjointSetOf(id, IdMappingMode::EXACT)); + } + } + } + + std::unordered_set terminating_rfactor_dims; + for (const auto& disjoint_set_shared_ptr : + ca_map.idGraph().exactNodes().disjointSets()) { + if (std::none_of( + disjoint_set_shared_ptr->vector().begin(), + disjoint_set_shared_ptr->vector().end(), + [](IterDomain* id) { return id->isRFactorProduct(); })) { + continue; + } + if (transformed_disjoint_sets.find(disjoint_set_shared_ptr) != + transformed_disjoint_sets.end()) { + // Disjoint set was transformed for view, ignore it + continue; + } + for (auto id : disjoint_set_shared_ptr->vector()) { + terminating_rfactor_dims.emplace(id); + } + } + + // If iter domains are involved in any transformation from root domains to + // rfactor domains they should be considered "contaminated". + for (auto tv : ir_utils::allTvs(fusion)) { + if (!tv->hasRFactor()) { + continue; + } + + std::unordered_map old2new; + // Make sure rfactor dims we need are in domain, and reorder them in domain + // so they're consecutive starting from the left of domain. TODO: We could + // improve this so that if there's transformations replayed after the + // rfactor dims we could try and pull those through the fusion instead of + // enforcing rfactor dims are in domain. 
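    // For example, if tv's terminating rfactor ids are R1 and R3 (in that
    // order in the rfactor domain) and the active domain is [I0, R1, I2, R3],
    // the loop below builds old2new = {1 -> 0, 3 -> 1}; tv->reorder(old2new)
    // then yields [R1, R3, I0, I2], and the propagation further down replays
    // the first old2new.size() = 2 dimensions onto the rest of the fusion.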
+ for (auto rfactor_id : tv->getMaybeRFactorDomain()) { + if (terminating_rfactor_dims.find(rfactor_id) != + terminating_rfactor_dims.end()) { + auto find_it = std::find( + tv->domain()->domain().begin(), + tv->domain()->domain().end(), + rfactor_id); + TORCH_INTERNAL_ASSERT( + find_it != tv->domain()->domain().end(), + "Require ", + rfactor_id, + " is in the active domain of ", + tv->toString(), + " for view propagation."); + auto old_pos = std::distance(tv->domain()->domain().begin(), find_it); + + old2new[old_pos] = old2new.size(); + } + } + + if (old2new.empty()) { + continue; + } + + // Propagate the view transformations + tv->reorder(old2new); + //! Propagate current transformations on from_tv to all graphs + transformPropagateToAllFrom(tv, old2new.size()); + } +} + } // namespace scheduler_utils } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.h b/torch/csrc/jit/codegen/cuda/scheduler/utils.h index 373a879f740d5..7d83cfa268875 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -536,6 +537,10 @@ TORCH_CUDA_CU_API bool breakIsDisjoint(std::vector group_ids, int pos); TORCH_CUDA_CU_API std::unordered_map domainReorderAsRfactorMap( TensorView* tv); +// Assumes view's are consistent as detected by +// registery.cpp::requiresForwardViewReplay returning false +void propagateViewTransforms(Fusion* fusion, const ComputeAtMap& ca_map); + } // namespace scheduler_utils } // namespace cuda } // namespace fuser From 31a63fca001d1dc81f424d273509ae8f372f20b5 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Tue, 11 Oct 2022 17:04:37 -0400 Subject: [PATCH 2/2] Almost, but hitting some issues. --- .../codegen/cuda/scheduler/normalization.cpp | 44 +++- .../jit/codegen/cuda/scheduler/pointwise.cpp | 1 + .../jit/codegen/cuda/scheduler/reduction.cpp | 11 + .../jit/codegen/cuda/scheduler/registry.cpp | 210 ++++++++++++++---- .../csrc/jit/codegen/cuda/scheduler/utils.cpp | 21 -- torch/csrc/jit/codegen/cuda/scheduler/utils.h | 35 +++ .../jit/codegen/cuda/test/test_gpu_view.cpp | 94 ++++++++ 7 files changed, 351 insertions(+), 65 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp index 459974b8d2884..9b7f19c243150 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp @@ -768,6 +768,18 @@ std::shared_ptr persistentHeuristic( const int64_t max_persistent_buffer_size, size_t vectorize_factor, bool project_persistent_buffers) { + std::cout << "Within heuristics" << std::endl; + std::cout << "\n" + << total_reduction_numel << "\n" // 945 + << total_iteration_numel << "\n" // 42 + << inner_most_dimension_numel << "\n" // 63 + << fastest_dim_reduction << "\n" // 1 + << n_tensor_inputs << "\n" // 2 + << max_input_dtype_size << "\n" // 4 + << max_persistent_buffer_size << "\n" // 0 + << vectorize_factor << "\n" // 1 + << project_persistent_buffers << std::endl; // 1 + std::shared_ptr rparams; if (fastest_dim_reduction) { rparams = innerPersistentHeuristic( @@ -788,6 +800,7 @@ std::shared_ptr persistentHeuristic( vectorize_factor); } rparams->project_persistent_buffers = project_persistent_buffers; +std::cout<<"Finished within heuristics"< getPersistentHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, HeuristicSummary* data_cache) { + std::cout<<"???"< getPersistentHeuristics( 
scheduler_utils::getProperties(fusion, runtime_info, first_red_tv); // Grab persistent buffer sizes + fusion->printMath(); auto persistent_buffer_size_info = scheduler_utils::persistentBufferSize( fusion, runtime_info, persistent_buffer_info, data_cache); // If projected persistent buffers are smaller, they will be used. - auto max_persistent_size = std::min( - persistent_buffer_size_info.persistent_buffer_size, - persistent_buffer_size_info.projected_persistent_buffer_size); + // TODO: projected buffers coming up as 0 in NVFuserTest.FusionViewMagicSchedule7_CUDA + auto max_persistent_size = ir_utils::getViewOps(fusion).size() > 0 + ? persistent_buffer_size_info.persistent_buffer_size + : std::min( + persistent_buffer_size_info.persistent_buffer_size, + persistent_buffer_size_info.projected_persistent_buffer_size); + + std::cout << persistent_buffer_size_info.persistent_buffer_size << " :: " + << persistent_buffer_size_info.projected_persistent_buffer_size + << std::endl; // Figure out if we want to projet persistent buffers to the inputs for // exmaple if we have an input tensor t0 that's fp16: @@ -908,6 +930,7 @@ TORCH_CUDA_CU_API std::shared_ptr getPersistentHeuristics( vectorize_factor = 1; } +std::cout<<"A"< getPersistentHeuristics( (int)(first_red_tv->nDims() - properties.inner_most_dimension_ndims), vectorize_factor); +std::cout<<"B"< 0) { + ComputeAtMap ca_map(fusion); + // Propagate view transforms through the graph, expecially the reference. + scheduler_utils::propagateViewTransforms(fusion, ca_map); + + // Reorder reference_tv after propagating the view operation. This will + // reorder for better merging. + reduction_tv->reorder( + scheduler_utils::domainReorderAsRfactorMap(reduction_tv)); + fusion->printMath(); + } + auto dim_analysis = scheduler_utils::canonicalDimReduction( fusion, reduction_tv, rparams.fastest_dim && rparams.schedule_3D); bool has_iter_axis = dim_analysis.first; @@ -1028,6 +1064,8 @@ TORCH_CUDA_CU_API void schedulePersistentKernel( reduction_tvs, cached_inputs, cached_outputs); + +fusion->printMath(); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp index 481181b554c1f..de5f1da14fae1 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp @@ -487,6 +487,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) { if (ir_utils::getViewOps(fusion).size() > 0) { ComputeAtMap ca_map(fusion); + // Propagate view transforms through the graph, expecially the reference. scheduler_utils::propagateViewTransforms(fusion, ca_map); // Reorder reference_tv after propagating the view operation. This will diff --git a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp index 3037f8469dad4..80f396c2a0128 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp @@ -1014,6 +1014,17 @@ void scheduleReduction(Fusion* fusion, const ReductionParams& rparams) { // changes registry needs to change. auto reduction_tv = reduction_tvs[0]; + if (ir_utils::getViewOps(fusion).size() > 0) { + ComputeAtMap ca_map(fusion); + // Propagate view transforms through the graph, expecially the reference. + scheduler_utils::propagateViewTransforms(fusion, ca_map); + + // Reorder reference_tv after propagating the view operation. This will + // reorder for better merging. 
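    // For example, after view propagation reduction_tv's leaf domain might be
    // [R2*R3, I0*I1] while its rfactor domain is [I0*I1, R2*R3];
    // domainReorderAsRfactorMap() would then return {0 -> 1, 1 -> 0} so the
    // reorder below restores rfactor order before canonicalDimReduction()
    // merges neighboring iteration and reduction dims.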
+ reduction_tv->reorder( + scheduler_utils::domainReorderAsRfactorMap(reduction_tv)); + } + auto dim_analysis = scheduler_utils::canonicalDimReduction( fusion, reduction_tv, rparams.fastest_dim && rparams.schedule_3D); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp index 85aec46190de4..9817ada6eaa8a 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp @@ -579,6 +579,128 @@ bool requiresForwardViewReplay(Fusion* fusion, ComputeAtMap& ca_map) { return false; } +// Returns if view intefers with how we want to treat the reference, being at +// least a 2D reduction schedule but maybe a 3D reduction schedule. +bool reductionInterferingView( + Fusion* fusion, + const ComputeAtMap& ca_map, + TensorView* reduction_reference) { + // Make sure the view doesn't interfere with how we'll want to schedule + // it. If we might want to do a 3D scheduler make sure views are disjoint + // based on what the 3D scheduler's merges would be. + + // Utility to take dimensions out of the vector that we've already + // processed or don't want to process. + auto remove_dims = + [](const std::vector& dims, + std::unordered_set to_remove) { + std::vector dims_removed; + std::copy_if( + dims.begin(), + dims.end(), + std::back_inserter(dims_removed), + [&](IterDomain* id) { + return to_remove.find(id) == to_remove.end(); + }); + return dims_removed; + }; + + // Remove trivial reduction dimensions + auto mapped_to_trivial_reduction = scheduler_utils::getTrivialReductionMap(fusion); + + std::vector dims = remove_dims( + reduction_reference->getMaybeRFactorDomain(), mapped_to_trivial_reduction); + + // The disjoint groups we need for this scheduler + std::vector> groups; + + // Do this three times as we could have a 3D scheduler at maximum + for (auto dimension : c10::irange(3)) { + // Tracker for this group + std::vector current_dims; + + // Tracker of what we've already processed to remove from dims + std::unordered_set processed; + + for (auto i : c10::irange(dims.size())) { + auto dim_i = dims.size() - i - 1; + if (dims[dim_i]->isReduction() != dims[dims.size() - 1]->isReduction()) { + if (dimension == 0) { + // First dimension must be contiguous merges + break; + } else { + // Other dimensions can be non contiguous merges + continue; + } + } + current_dims.push_back(dims[dim_i]); + processed.emplace(dims[dim_i]); + } + + // Don't add empty group (would happen if it's a 2D scheduler not 3D) + if (current_dims.size() > 0) { + groups.push_back(current_dims); + dims = remove_dims(dims, processed); + } + } + + TORCH_INTERNAL_ASSERT( + dims.empty(), "Error processing ", dims, " in registry.cpp."); + + // Make sure groups are disjoint based on view + + auto disjoint_view_sets = scheduler_utils::disjointViewSets(fusion); + auto disjoint_set_information = scheduler_utils::getDisjointViewSetsOf( + fusion, reduction_reference, disjoint_view_sets); + + // Convert id's in groups to disjoint_set_ids of disjoint_set_information + std::vector> disjoint_groups; + + for (auto group : groups) { + std::vector disjoint_id_sets; + for (auto id : group) { + auto find_it = std::find( + reduction_reference->getMaybeRFactorDomain().begin(), + reduction_reference->getMaybeRFactorDomain().end(), + id); + TORCH_INTERNAL_ASSERT( + find_it != reduction_reference->getMaybeRFactorDomain().end(), + "Issue with view analysis on reduction like schedule, with reference: ", + reduction_reference->toString()); + auto 
rfactor_pos = std::distance( + reduction_reference->getMaybeRFactorDomain().begin(), find_it); + TORCH_INTERNAL_ASSERT( + rfactor_pos < disjoint_set_information.disjoint_set_ids.size(), + "Error computing disjoint group on the rfactor domain of ", + reduction_reference->toString()); + disjoint_id_sets.push_back( + disjoint_set_information.disjoint_set_ids[rfactor_pos]); + } + disjoint_groups.push_back(disjoint_id_sets); + } + + // Make sure there's no intersection between the groups, otherwise view + // will interfere with the schedule. TODO: Make this better complexity, + // since it should be relatively small int vectors of a small total nDims, + // not too worried about it now. + + for (auto first_dim_i : c10::irange(disjoint_groups.size())) { + for (auto second_dim_i = first_dim_i + 1; second_dim_i < disjoint_groups.size(); + ++second_dim_i) { + auto first_group = disjoint_groups[first_dim_i]; + auto second_group = disjoint_groups[second_dim_i]; + for (auto first_disjoint_id : first_group) { + for (auto second_disjoint_id : second_group) { + if (first_disjoint_id == second_disjoint_id) { + return true; + } + } + } + } + } + return false; +} + } // namespace void SchedulerRuntimeInfo::initialize( @@ -1113,15 +1235,6 @@ class ReductionScheduler : public SchedulerEntry { //! Check if the reduction heuristics apply in given fusion static bool canScheduleCompileTime(Fusion* fusion) { - // Temporarily disallow view in reduction scheduler - // TODO Add more testing before enabling - auto view_tvs = scheduler_utils::getViewTVs(fusion); - if (view_tvs.size() > 0) { - scheduler_debug_utils::canScheduleRejectReason( - ScheduleHeuristic::Reduction, "No support for view op"); - return false; - } - // Needs at least one non-trivial reduction to consider. if (ir_utils::getReductionOps(fusion, true /* ignore_trivial */).empty()) { scheduler_debug_utils::canScheduleRejectReason( @@ -1151,15 +1264,24 @@ class ReductionScheduler : public SchedulerEntry { return false; } - // Persistent scheduler simply uses reduction_tvs[0] as the reference, if - // that changes, this needs to be changed. Second check here may be overly - // conservative. - if (SchedulerTopologyChecker::hasViewNotBeforeRef( - fusion, {reduction_tvs[0]}) || - !scheduler_utils::allMatchingViews(fusion)) { - scheduler_debug_utils::canScheduleRejectReason( - ScheduleHeuristic::Reduction, "Unsupported view fusion."); - return false; + if (ir_utils::getViewOps(fusion).size() > 0) { + ComputeAtMap ca_map(fusion); + if (requiresForwardViewReplay(fusion, ca_map)) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Reduction, + "Fusion requires view being reversible."); + return false; + } + + // Reduction scheduler simply uses reduction_tvs[0] as the reference, if + // that changes, this needs to be changed. 
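      // For example, with a reference tv[I0, I1, R2] (R2 reduced), a view on
      // the producer path that merges I1 with R2 places an iteration group
      // and the reduction group in the same disjoint view set, so
      // reductionInterferingView() returns true and the fusion is rejected
      // here; a view that only merges I0 with I1 keeps the groups disjoint
      // and is accepted.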
+ if (reductionInterferingView( + fusion, ca_map, reduction_tvs[0])) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Reduction, + "View may interfere with reduction scheduling."); + return false; + } } // Make sure reduction axes are consistent through the fusion @@ -1359,11 +1481,13 @@ class PointWiseScheduler : public SchedulerEntry { return false; } - ComputeAtMap ca_map(fusion); - if (requiresForwardViewReplay(fusion, ca_map)) { - scheduler_debug_utils::canScheduleRejectReason( - ScheduleHeuristic::PointWise, "Unsupported view fusion."); - return false; + if (ir_utils::getViewOps(fusion).size() > 0) { + ComputeAtMap ca_map(fusion); + if (requiresForwardViewReplay(fusion, ca_map)) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::PointWise, "Fusion requires view being reversible."); + return false; + } } auto reduction_ops = @@ -1444,13 +1568,6 @@ class PersistentKernelScheduler : public SchedulerEntry { auto reduction_ops = ir_utils::getReductionOps(fusion, false /* ignore_trivial */); - auto view_tvs = scheduler_utils::getViewTVs(fusion); - if (view_tvs.size() > 0) { - scheduler_debug_utils::canScheduleRejectReason( - ScheduleHeuristic::Persistent, "no support for view"); - return false; - } - if (hasNonUniqueBcast(fusion)) { scheduler_debug_utils::canScheduleRejectReason( ScheduleHeuristic::Persistent, @@ -1468,14 +1585,23 @@ class PersistentKernelScheduler : public SchedulerEntry { return false; } - // Persistent scheduler simply uses reduction_tvs[0] as the reference, if - // that changes, this needs to be changed. Second check here may be overly - // conservative. - if (SchedulerTopologyChecker::hasViewNotBeforeRef( - fusion, {reduction_tvs[0]}) || - !scheduler_utils::allMatchingViews(fusion)) { - scheduler_debug_utils::canScheduleRejectReason( - ScheduleHeuristic::Persistent, "Unsupported view fusion."); + if (ir_utils::getViewOps(fusion).size() > 0) { + ComputeAtMap ca_map(fusion); + if (requiresForwardViewReplay(fusion, ca_map)) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Persistent, + "Fusion requires view being reversible."); + return false; + } + + // Persistent scheduler simply uses reduction_tvs[0] as the reference, if + // that changes, this needs to be changed. + if (reductionInterferingView(fusion, ca_map, reduction_tvs[0])) { + scheduler_debug_utils::canScheduleRejectReason( + ScheduleHeuristic::Persistent, + "View may interfere with normalization scheduling."); + return false; + } } if (findTransposeOps(fusion).size() > 0) { @@ -1547,7 +1673,7 @@ class PersistentKernelScheduler : public SchedulerEntry { "unsupported post reduction normalization"); return false; } - +std::cout<<"End compile time"<( data_cache, [&fusion]() { @@ -1633,7 +1759,7 @@ class PersistentKernelScheduler : public SchedulerEntry { return false; } - +std::cout<<"End runtime check"< getInputsOutputsWithInnerDim( return vectorizable_tensors; } -namespace { -// Holder return struct for the below function. -struct DisjointViewSetInfo { - // const* to the disjoint set in disjoint_view_set passed in to - // getDisjointViewSetsOf each iterdomain in the rfactor of ref is mapped to. - // - // WARNING: these pointers are relative to the disjoint_view_set reference - // passed into getDisjointViewSetsOf it's the user's responsibillity to - // maintain the lifetime of that reference to match this vector. 
- std::vector*> disjoint_sets_of_ref; - - // Unique ID associated to the disjoint view group the rfactor id belongs to - // in disjoint_sets_of_ref. It's straight forward to map from - // disjoint_sets_of_ref to the vector, but not the other way around. - std::vector disjoint_set_ids; - - // TensorView reference the above vectors are relative to. - TensorView* ref; -}; - // Returns disjoint view sets mapped onto the given reference. Returns a pair // of vectors of size rfactorDomain of reference. Vector of // VectorOfUniqueEntries returns a const* to the disjoint set in @@ -1495,7 +1475,6 @@ DisjointViewSetInfo getDisjointViewSetsOf( return info; } -} // namespace BroadcastMultipleInformation getBroadcastMultiples( TensorView* reference_tv, diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.h b/torch/csrc/jit/codegen/cuda/scheduler/utils.h index 7d83cfa268875..ac234209f9441 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.h +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.h @@ -290,6 +290,41 @@ std::vector getInputsOutputsWithInnerDim( bool inner_only, bool vectorize_pass); +// Holder return struct for the below function. +struct DisjointViewSetInfo { + // const* to the disjoint set in disjoint_view_set passed in to + // getDisjointViewSetsOf each iterdomain in the rfactor of ref is mapped to. + // + // WARNING: these pointers are relative to the disjoint_view_set reference + // passed into getDisjointViewSetsOf it's the user's responsibillity to + // maintain the lifetime of that reference to match this vector. + std::vector*> disjoint_sets_of_ref; + + // Unique ID associated to the disjoint view group the rfactor id belongs to + // in disjoint_sets_of_ref. It's straight forward to map from + // disjoint_sets_of_ref to the vector, but not the other way around. + std::vector disjoint_set_ids; + + // TensorView reference the above vectors are relative to. + TensorView* ref; +}; + +// Returns disjoint view sets mapped onto the given reference. Returns a pair +// of vectors of size rfactorDomain of reference. Vector of +// VectorOfUniqueEntries returns a const* to the disjoint set in +// disjoint_view_set the iterdomain is mapped to. Integer vector represents +// which disjoint view group the rfactor id belongs to. It's straight forward +// to map from the former to the latter, but not the latter to former. +// +// Since we return a const* to entries in disjoint_view_set, it must be passed +// in as a reference. Algorithm is N^2 based on number of dims in reference, +// but generating the disjoint view set is likely the limiter on perf of this +// function. +DisjointViewSetInfo getDisjointViewSetsOf( + Fusion* fusion, + TensorView* of, + DisjointSets& disjoint_view_set); + // Structure to hold byte multiples for break points. I.e. 
if we have the // tensors: // T0[I0, I1] float diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp index 1ed73d3256bcf..bb30c12a9b05c 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp @@ -1724,6 +1724,100 @@ TEST_F(NVFuserTest, FusionViewMagicSchedule5_CUDA) { testValidate(&fusion, cg_outputs, {t0, t3}, {t6}, __LINE__, __FILE__); } +// View with 3D reduction scheduling +TEST_F(NVFuserTest, FusionViewMagicSchedule6_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int v = 3, w = 5, x = 42, y = 7, z = 9; + + auto tv0 = makeConcreteTensor({w, v, x, y, z}); + fusion.addInput(tv0); + auto tv1 = sin(tv0); + auto tv2 = view(tv1, {w, v, x, y, z}, {v*w, x, y*z}); + + auto tv3 = makeConcreteTensor({v, w, x, z, y}); + fusion.addInput(tv3); + auto tv4 = cos(tv3); + auto tv5 = view(tv4, {v, w, x, z, y}, {v*w, x, y*z}); + + auto tv6 = add(tv2, tv5); + auto tv7 = sum(tv6, {0, 2}); + fusion.addOutput(tv7); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({w, v, x, y, z}, options); + auto t1 = sin(t0); + auto t2 = at::native::view(t1, {v*w, x, y*z}); + at::Tensor t3 = at::randn({v, w, x, z, y}, options); + auto t4 = cos(t3); + auto t5 = at::native::view(t4, {v*w, x, y*z}); + auto t7 = add(t2, t5).sum(2).sum(0); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + // Collect the heuristic params + executor_cache.profile(true); + auto cg_outputs = executor_cache.runFusionWithInputs({t0, t3}); + + TORCH_CHECK(!executor_cache.getMostRecentKernelRuntime()->isSegmented()); + TORCH_CHECK(executor_cache.getMostRecentExecutorInfo() + .params->isA()); + + testValidate(&fusion, cg_outputs, {t0, t3}, {t7}, __LINE__, __FILE__); +} + +// View with 3D normalization scheduling +TEST_F(NVFuserTest, FusionViewMagicSchedule7_CUDA) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + int v = 3, w = 5, x = 42, y = 7, z = 9; + + auto tv0 = makeConcreteTensor({w, v, x, y, z}); + fusion.addInput(tv0); + auto tv1 = sin(tv0); + auto tv2 = view(tv1, {w, v, x, y, z}, {v*w, x, y*z}); + + auto tv3 = makeConcreteTensor({v, w, x, z, y}); + fusion.addInput(tv3); + auto tv4 = cos(tv3); + auto tv5 = view(tv4, {v, w, x, z, y}, {v*w, x, y*z}); + + auto tv6 = add(tv2, tv5); + auto tv7 = sum(tv6, {0, 2}); + auto tv8 = broadcast(tv7, {true, false, true}); + auto tv9 = add(tv6, tv8); + fusion.addOutput(tv9); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({w, v, x, y, z}, options); + auto t1 = sin(t0); + auto t2 = at::native::view(t1, {v*w, x, y*z}); + // This might trigger transpose kernel. 
+ at::Tensor t3 = at::randn({v, w, x, z, y}, options); + auto t4 = cos(t3); + auto t5 = at::native::view(t4, {v*w, x, y*z}); + auto t6 = add(t2, t5); + auto t7 = t6.sum(2).sum(0); + auto t8 = t7.unsqueeze(-1).unsqueeze(0); + auto t9 = t6 + t8; + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + // Collect the heuristic params + executor_cache.profile(true); + auto cg_outputs = executor_cache.runFusionWithInputs({t0, t3}); + + TORCH_CHECK(!executor_cache.getMostRecentKernelRuntime()->isSegmented()); + TORCH_CHECK(executor_cache.getMostRecentExecutorInfo() + .params->isA()); + + testValidate(&fusion, cg_outputs, {t0, t3}, {t9}, __LINE__, __FILE__); +} + // Make sure different views that are consumed by the reference are segmented // into a single kernel. TEST_F(NVFuserTest, FusionViewMapping_CUDA) {