Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions test/cpp/jit/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,73 @@ TEST(NVFuserTest, KernelExprEvalBindings_CUDA) {
checkIntValue(evaluator, d, -2);
}

// Test name-to-node lookup in the Fusion IR
TEST(NVFuserTest, FusionValueLookup_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeSymbolicTensor(2);
fusion.addInput(tv0);

auto scalar = new Double(-1.0);
auto tv1 = mul(tv0, scalar);
auto tv2 = add(tv0, new Double(3.0));
auto tv3 = mul(tv0, new Double(2.0));
auto tv4 = add(tv2, tv1);
auto tv5 = add(tv4, tv3);
auto tv6 = add(tv0, tv3);

fusion.addOutput(tv5);
fusion.addOutput(tv6);

// using the value's val type
ASSERT_EQ(fusion.lookupValue(*tv0->getValType(), tv0->name()), tv0);
ASSERT_EQ(fusion.lookupValue(*scalar->getValType(), scalar->name()), scalar);

// explicit ValType
ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv1->name()), tv1);
ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv2->name()), tv2);
ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv3->name()), tv3);
ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv4->name()), tv4);
ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv5->name()), tv5);
ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv6->name()), tv6);

// misses
ASSERT_NE(fusion.lookupValue(ValType::Scalar, tv0->name()), tv0);
ASSERT_NE(fusion.lookupValue(ValType::TensorView, tv1->name()), tv0);

// non-existent names
ASSERT_EQ(fusion.lookupValue(ValType::Scalar, 12345), nullptr);
ASSERT_EQ(fusion.lookupValue(ValType::TensorView, 12345), nullptr);

Fusion copy(fusion);

auto copy_tv1 = copy.lookupValue(ValType::TensorView, tv1->name());
auto copy_tv2 = copy.lookupValue(ValType::TensorView, tv2->name());
auto copy_tv3 = copy.lookupValue(ValType::TensorView, tv3->name());
auto copy_tv4 = copy.lookupValue(ValType::TensorView, tv4->name());
auto copy_tv5 = copy.lookupValue(ValType::TensorView, tv5->name());
auto copy_tv6 = copy.lookupValue(ValType::TensorView, tv6->name());

swap(fusion, copy);

ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv1->name()), copy_tv1);
ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv2->name()), copy_tv2);
ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv3->name()), copy_tv3);
ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv4->name()), copy_tv4);
ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv5->name()), copy_tv5);
ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv6->name()), copy_tv6);

fusion.clear();

ASSERT_EQ(copy.lookupValue(ValType::TensorView, tv1->name()), tv1);
ASSERT_EQ(copy.lookupValue(ValType::TensorView, tv2->name()), tv2);
ASSERT_EQ(copy.lookupValue(ValType::TensorView, tv3->name()), tv3);
ASSERT_EQ(copy.lookupValue(ValType::TensorView, tv4->name()), tv4);
ASSERT_EQ(copy.lookupValue(ValType::TensorView, tv5->name()), tv5);
ASSERT_EQ(copy.lookupValue(ValType::TensorView, tv6->name()), tv6);
}

TEST(NVFuserTest, FusionClear_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
Expand Down
22 changes: 21 additions & 1 deletion torch/csrc/jit/codegen/cuda/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ void swap(Fusion& a, Fusion& b) noexcept {
swap(a.expr_set_, b.expr_set_);
swap(a.val_deque_, b.val_deque_);

swap(a.lookup_index_, b.lookup_index_);

swap(a.val_type_name_map_, b.val_type_name_map_);
swap(a.expr_name_counter_, b.expr_name_counter_);

Expand Down Expand Up @@ -96,6 +98,13 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
}

to->lookup_index_ = from->lookup_index_;
for (auto& index_kv : to->lookup_index_) {
for (auto& kv : index_kv.second) {
kv.second = ir_cloner.clone(kv.second);
}
}

to->val_type_name_map_ = from->val_type_name_map_;
to->expr_name_counter_ = from->expr_name_counter_;

Expand Down Expand Up @@ -146,6 +155,8 @@ void Fusion::clear() noexcept {
val_deque_.clear();
expr_set_.clear();

lookup_index_.clear();

for (auto& kv : val_type_name_map_) {
kv.second = 0;
}
Expand Down Expand Up @@ -378,7 +389,10 @@ StmtNameType Fusion::registerVal(Val* val) {

val_set_.emplace(val);
val_deque_.push_back(val);
return getValName(*(val->getValType()));
const auto vtype = *val->getValType();
const auto name = getValName(vtype);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A quick question just for clarification. When copying from Fusion 1 to Fusion 2, does the copying of the Vals need to proceed in the same order they are registered in Fusion 1 so that the name lookup will work?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question: when cloning a fusion, the names will be copied directly w/o going through this registration path.

TORCH_INTERNAL_ASSERT(lookup_index_[vtype].insert({name, val}).second);
return name;
}

StmtNameType Fusion::registerExpr(Expr* expr) {
Expand Down Expand Up @@ -431,6 +445,12 @@ StmtNameType Fusion::registerStatement(Statement* stmt) {
return kInvalidStmName;
}

Val* Fusion::lookupValue(ValType vtype, StmtNameType name) const {
const auto& index = lookup_index_.at(vtype);
const auto it = index.find(name);
return it != index.end() ? it->second : nullptr;
}

void Fusion::resetTvUses() {
// getExprs only uses definition, so even if we've modified uses already to
// remove dead exprs, this could reinsert them. getExprs is also boundeds by
Expand Down
9 changes: 8 additions & 1 deletion torch/csrc/jit/codegen/cuda/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class TORCH_CUDA_CU_API Fusion final {

~Fusion();

friend void swap(Fusion& a, Fusion& b) noexcept;
TORCH_CUDA_CU_API friend void swap(Fusion& a, Fusion& b) noexcept;

void clear() noexcept;

Expand Down Expand Up @@ -116,6 +116,9 @@ class TORCH_CUDA_CU_API Fusion final {
//! Replace output with another value
void replaceOutput(Val* output, Val* replacement);

//! Lookup the value node with the specified type and name
Val* lookupValue(ValType vtype, StmtNameType name) const;

//! Clear Expr's from TV uses that are not required to produce outputs from
//! inputs
void resetTvUses();
Expand Down Expand Up @@ -216,6 +219,10 @@ class TORCH_CUDA_CU_API Fusion final {
std::deque<Val*> val_deque_;
std::unordered_set<Expr*> expr_set_;

// name-to-node lookup indexes
std::unordered_map<ValType, std::unordered_map<StmtNameType, Val*>>
lookup_index_;

// Values names counters
std::unordered_map<ValType, StmtNameType, TypeHash> val_type_name_map_;

Expand Down