Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions test/cpp/jit/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13899,7 +13899,7 @@ TEST(NVFuserTest, FusionIssue728_CUDA) {
// tv2 -> tv6

auto all_vals_under_tv3 =
DependencyCheck::getAllValsBetween({tv3}, fusion.outputs());
DependencyCheck::getAllValsBetween2({tv3}, fusion.outputs());
std::unordered_set<Val*> included_tensors({tv3, tv4, tv5});
for (auto tv : included_tensors) {
TORCH_CHECK(
Expand All @@ -13920,16 +13920,17 @@ TEST(NVFuserTest, FusionIssue728_CUDA) {
}
}

auto no_dependency = DependencyCheck::getAllValsBetween({}, fusion.outputs());
auto no_dependency =
DependencyCheck::getAllValsBetween2({}, fusion.outputs());
TORCH_CHECK(no_dependency.empty(), "No val should be returned");

auto no_dep_path = DependencyCheck::getAllValsBetween({tv0, tv1}, {tv6});
auto no_dep_path = DependencyCheck::getAllValsBetween2({tv0, tv1}, {tv6});
TORCH_CHECK(no_dep_path.empty(), "No val should be returned");

auto no_dep_path2 = DependencyCheck::getAllValsBetween({tv2}, {tv5});
auto no_dep_path2 = DependencyCheck::getAllValsBetween2({tv2}, {tv5});
TORCH_CHECK(no_dep_path2.empty(), "No val should be returned");

auto just_tv3 = DependencyCheck::getAllValsBetween({tv3}, {tv3});
auto just_tv3 = DependencyCheck::getAllValsBetween2({tv3}, {tv3});
TORCH_CHECK(
just_tv3.size() == 1 && *(just_tv3.begin()) == tv3,
"Only tv3 should be included");
Expand Down
48 changes: 45 additions & 3 deletions torch/csrc/jit/codegen/cuda/iter_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,42 @@ namespace {
// Looks for and returns all values in between dependencies and vals, including
// them.
struct Dependencies : public IterVisitor {
std::unordered_set<Val*> dependencies_;
std::unordered_set<Val*> vals_;

std::vector<Statement*> next(Val* v) override {
if (dependencies_.find(v) != dependencies_.end())
return std::vector<Statement*>();
return IterVisitor::next(v);
}

void handle(Val* val) override {
vals_.emplace(val);
}

Dependencies(
std::unordered_set<Val*> _dependencies,
const std::vector<Val*>& of)
: dependencies_(std::move(_dependencies)) {
traverseFrom(of[0]->fusion(), of, false);
};

public:
static std::unordered_set<Val*> getAllVals(
const std::unordered_set<Val*>& dependencies,
const std::vector<Val*>& of) {
if (of.empty()) {
return std::unordered_set<Val*>();
}

Dependencies deps(dependencies, of);
return deps.vals_;
}
};

// Looks for and returns all values in between dependencies and vals, including
// them.
struct Dependencies2 : public IterVisitor {
private:
//! A given set of dependency Vals
const std::unordered_set<Val*> dependencies_;
Expand Down Expand Up @@ -410,7 +446,7 @@ struct Dependencies : public IterVisitor {
}
}

Dependencies(
Dependencies2(
std::unordered_set<Val*> _dependencies,
const std::vector<Val*>& of)
: dependencies_(std::move(_dependencies)) {
Expand All @@ -425,7 +461,7 @@ struct Dependencies : public IterVisitor {
return {};
}

Dependencies deps(dependencies, of);
Dependencies2 deps(dependencies, of);
return deps.vals_;
}
};
Expand Down Expand Up @@ -631,12 +667,18 @@ std::deque<std::deque<Val*>> DependencyCheck::getAllUseChains(Val* producer) {
return DependencyChains::getAllUseChains(producer);
}

std::vector<Val*> DependencyCheck::getAllValsBetween(
// Forwards to Dependencies::getAllVals and returns its result unchanged.
// The result is an unordered set, so it carries no ordering.
std::unordered_set<Val*> DependencyCheck::getAllValsBetween(
const std::unordered_set<Val*>& dependencies,
const std::vector<Val*>& of) {
return Dependencies::getAllVals(dependencies, of);
}

// Forwards to Dependencies2::getAllVals and returns its result unchanged.
// Unlike getAllValsBetween, the result is a vector.
std::vector<Val*> DependencyCheck::getAllValsBetween2(
const std::unordered_set<Val*>& dependencies,
const std::vector<Val*>& of) {
return Dependencies2::getAllVals(dependencies, of);
}

std::unordered_set<Val*> DependencyCheck::getAllOutputsOf(
const std::unordered_set<Val*>& of) {
if (of.empty()) {
Expand Down
8 changes: 7 additions & 1 deletion torch/csrc/jit/codegen/cuda/iter_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,13 @@ class TORCH_CUDA_CU_API DependencyCheck {

// Grab all values that exist between and including provided
// vals. The result is an unordered set, so no ordering is implied.
static std::vector<Val*> getAllValsBetween(
static std::unordered_set<Val*> getAllValsBetween(
const std::unordered_set<Val*>& dependencies,
const std::vector<Val*>& of);

// Grab all values that exist between and including provided
// vals. Returned values are topologically ordered.
static std::vector<Val*> getAllValsBetween2(
const std::unordered_set<Val*>& dependencies,
const std::vector<Val*>& of);

Expand Down
Empty file.