From a358d8de1cc34000d80701bfb8e78cbbc759adde Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 18 Mar 2021 12:26:42 -0700 Subject: [PATCH] Move back the previous version of getAllValsBetween The new version was introduced at PR #729. It is now renamed to getAllValsBetween2. Three Python tests are failing with the new version, so temporarily move back to the previous version. --- test/cpp/jit/test_gpu.cpp | 11 +++-- torch/csrc/jit/codegen/cuda/iter_visitor.cpp | 48 +++++++++++++++++-- torch/csrc/jit/codegen/cuda/iter_visitor.h | 8 +++- .../jit/codegen/cuda/scheduler_registry.cpp | 0 4 files changed, 58 insertions(+), 9 deletions(-) create mode 100644 torch/csrc/jit/codegen/cuda/scheduler_registry.cpp diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index 0375d3e49cef..edff04b20112 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -13899,7 +13899,7 @@ TEST(NVFuserTest, FusionIssue728_CUDA) { // tv2 -> tv6 auto all_vals_under_tv3 = - DependencyCheck::getAllValsBetween({tv3}, fusion.outputs()); + DependencyCheck::getAllValsBetween2({tv3}, fusion.outputs()); std::unordered_set included_tensors({tv3, tv4, tv5}); for (auto tv : included_tensors) { TORCH_CHECK( @@ -13920,16 +13920,17 @@ TEST(NVFuserTest, FusionIssue728_CUDA) { } } - auto no_dependency = DependencyCheck::getAllValsBetween({}, fusion.outputs()); + auto no_dependency = + DependencyCheck::getAllValsBetween2({}, fusion.outputs()); TORCH_CHECK(no_dependency.empty(), "No val should be returned"); - auto no_dep_path = DependencyCheck::getAllValsBetween({tv0, tv1}, {tv6}); + auto no_dep_path = DependencyCheck::getAllValsBetween2({tv0, tv1}, {tv6}); TORCH_CHECK(no_dep_path.empty(), "No val should be returned"); - auto no_dep_path2 = DependencyCheck::getAllValsBetween({tv2}, {tv5}); + auto no_dep_path2 = DependencyCheck::getAllValsBetween2({tv2}, {tv5}); TORCH_CHECK(no_dep_path2.empty(), "No val should be returned"); - auto just_tv3 = 
DependencyCheck::getAllValsBetween({tv3}, {tv3}); + auto just_tv3 = DependencyCheck::getAllValsBetween2({tv3}, {tv3}); TORCH_CHECK( just_tv3.size() == 1 && *(just_tv3.begin()) == tv3, "Only tv3 should be included"); diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp index 0ff70445f0bb..931c021ddc2d 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp @@ -364,6 +364,42 @@ namespace { // Looks for and returns all values in between dependencies and vals, including // them. struct Dependencies : public IterVisitor { + std::unordered_set dependencies_; + std::unordered_set vals_; + + std::vector next(Val* v) override { + if (dependencies_.find(v) != dependencies_.end()) + return std::vector(); + return IterVisitor::next(v); + } + + void handle(Val* val) override { + vals_.emplace(val); + } + + Dependencies( + std::unordered_set _dependencies, + const std::vector& of) + : dependencies_(std::move(_dependencies)) { + traverseFrom(of[0]->fusion(), of, false); + }; + + public: + static std::unordered_set getAllVals( + const std::unordered_set& dependencies, + const std::vector& of) { + if (of.empty()) { + return std::unordered_set(); + } + + Dependencies deps(dependencies, of); + return deps.vals_; + } +}; + +// Looks for and returns all values in between dependencies and vals, including +// them. +struct Dependencies2 : public IterVisitor { private: //! 
A given set of dependency Vals const std::unordered_set dependencies_; @@ -410,7 +446,7 @@ struct Dependencies : public IterVisitor { } } - Dependencies( + Dependencies2( std::unordered_set _dependencies, const std::vector& of) : dependencies_(std::move(_dependencies)) { @@ -425,7 +461,7 @@ struct Dependencies : public IterVisitor { return {}; } - Dependencies deps(dependencies, of); + Dependencies2 deps(dependencies, of); return deps.vals_; } }; @@ -631,12 +667,18 @@ std::deque> DependencyCheck::getAllUseChains(Val* producer) { return DependencyChains::getAllUseChains(producer); } -std::vector DependencyCheck::getAllValsBetween( +std::unordered_set DependencyCheck::getAllValsBetween( const std::unordered_set& dependencies, const std::vector& of) { return Dependencies::getAllVals(dependencies, of); } +std::vector DependencyCheck::getAllValsBetween2( + const std::unordered_set& dependencies, + const std::vector& of) { + return Dependencies2::getAllVals(dependencies, of); +} + std::unordered_set DependencyCheck::getAllOutputsOf( const std::unordered_set& of) { if (of.empty()) { diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.h b/torch/csrc/jit/codegen/cuda/iter_visitor.h index 490b5b4179ea..4fd0984b49c0 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.h +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.h @@ -199,7 +199,13 @@ class TORCH_CUDA_CU_API DependencyCheck { // Grab all values that exist between and including provided // vals. Returned values are topologicaly ordered. - static std::vector getAllValsBetween( + static std::unordered_set getAllValsBetween( + const std::unordered_set& dependencies, + const std::vector& of); + + // Grab all values that exist between and including provided + // vals. Returned values are topologically ordered. 
+ static std::vector getAllValsBetween2( const std::unordered_set& dependencies, const std::vector& of); diff --git a/torch/csrc/jit/codegen/cuda/scheduler_registry.cpp b/torch/csrc/jit/codegen/cuda/scheduler_registry.cpp new file mode 100644 index 000000000000..e69de29bb2d1