diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index 0375d3e49cef..edff04b20112 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -13899,7 +13899,7 @@ TEST(NVFuserTest, FusionIssue728_CUDA) { // tv2 -> tv6 auto all_vals_under_tv3 = - DependencyCheck::getAllValsBetween({tv3}, fusion.outputs()); + DependencyCheck::getAllValsBetween2({tv3}, fusion.outputs()); std::unordered_set included_tensors({tv3, tv4, tv5}); for (auto tv : included_tensors) { TORCH_CHECK( @@ -13920,16 +13920,17 @@ TEST(NVFuserTest, FusionIssue728_CUDA) { } } - auto no_dependency = DependencyCheck::getAllValsBetween({}, fusion.outputs()); + auto no_dependency = + DependencyCheck::getAllValsBetween2({}, fusion.outputs()); TORCH_CHECK(no_dependency.empty(), "No val should be returned"); - auto no_dep_path = DependencyCheck::getAllValsBetween({tv0, tv1}, {tv6}); + auto no_dep_path = DependencyCheck::getAllValsBetween2({tv0, tv1}, {tv6}); TORCH_CHECK(no_dep_path.empty(), "No val should be returned"); - auto no_dep_path2 = DependencyCheck::getAllValsBetween({tv2}, {tv5}); + auto no_dep_path2 = DependencyCheck::getAllValsBetween2({tv2}, {tv5}); TORCH_CHECK(no_dep_path2.empty(), "No val should be returned"); - auto just_tv3 = DependencyCheck::getAllValsBetween({tv3}, {tv3}); + auto just_tv3 = DependencyCheck::getAllValsBetween2({tv3}, {tv3}); TORCH_CHECK( just_tv3.size() == 1 && *(just_tv3.begin()) == tv3, "Only tv3 should be included"); diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp index 0ff70445f0bb..931c021ddc2d 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp @@ -364,6 +364,42 @@ namespace { // Looks for and returns all values in between dependencies and vals, including // them. 
struct Dependencies : public IterVisitor { + std::unordered_set dependencies_; + std::unordered_set vals_; + + std::vector next(Val* v) override { + if (dependencies_.find(v) != dependencies_.end()) + return std::vector(); + return IterVisitor::next(v); + } + + void handle(Val* val) override { + vals_.emplace(val); + } + + Dependencies( + std::unordered_set _dependencies, + const std::vector& of) + : dependencies_(std::move(_dependencies)) { + traverseFrom(of[0]->fusion(), of, false); + }; + + public: + static std::unordered_set getAllVals( + const std::unordered_set& dependencies, + const std::vector& of) { + if (of.empty()) { + return std::unordered_set(); + } + + Dependencies deps(dependencies, of); + return deps.vals_; + } +}; + +// Looks for and returns all values in between dependencies and vals, including +// them. +struct Dependencies2 : public IterVisitor { private: //! A given set of dependency Vals const std::unordered_set dependencies_; @@ -410,7 +446,7 @@ struct Dependencies : public IterVisitor { } } - Dependencies( + Dependencies2( std::unordered_set _dependencies, const std::vector& of) : dependencies_(std::move(_dependencies)) { @@ -425,7 +461,7 @@ struct Dependencies : public IterVisitor { return {}; } - Dependencies deps(dependencies, of); + Dependencies2 deps(dependencies, of); return deps.vals_; } }; @@ -631,12 +667,18 @@ std::deque> DependencyCheck::getAllUseChains(Val* producer) { return DependencyChains::getAllUseChains(producer); } -std::vector DependencyCheck::getAllValsBetween( +std::unordered_set DependencyCheck::getAllValsBetween( const std::unordered_set& dependencies, const std::vector& of) { return Dependencies::getAllVals(dependencies, of); } +std::vector DependencyCheck::getAllValsBetween2( + const std::unordered_set& dependencies, + const std::vector& of) { + return Dependencies2::getAllVals(dependencies, of); +} + std::unordered_set DependencyCheck::getAllOutputsOf( const std::unordered_set& of) { if (of.empty()) { diff --git 
a/torch/csrc/jit/codegen/cuda/iter_visitor.h b/torch/csrc/jit/codegen/cuda/iter_visitor.h index 490b5b4179ea..4fd0984b49c0 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.h +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.h @@ -199,7 +199,13 @@ class TORCH_CUDA_CU_API DependencyCheck { // Grab all values that exist between and including provided // vals. Returned values are topologicaly ordered. - static std::vector getAllValsBetween( + static std::unordered_set getAllValsBetween( + const std::unordered_set& dependencies, + const std::vector& of); + + // Grab all values that exist between and including provided + // vals. Returned values are topologically ordered. + static std::vector getAllValsBetween2( const std::unordered_set& dependencies, const std::vector& of); diff --git a/torch/csrc/jit/codegen/cuda/scheduler_registry.cpp b/torch/csrc/jit/codegen/cuda/scheduler_registry.cpp new file mode 100644 index 000000000000..e69de29bb2d1