diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp
index 600c001e5dba..091519d1cc0d 100644
--- a/test/cpp/jit/test_gpu.cpp
+++ b/test/cpp/jit/test_gpu.cpp
@@ -456,6 +456,77 @@ TEST(NVFuserTest, KernelExprEvalBindings_CUDA) {
   checkIntValue(evaluator, d, -2);
 }
 
+// Test name-to-node lookup in the Fusion IR
+TEST(NVFuserTest, FusionValueLookup_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(2);
+  fusion.addInput(tv0);
+
+  auto scalar = new Double(-1.0);
+  auto tv1 = mul(tv0, scalar);
+  auto tv2 = add(tv0, new Double(3.0));
+  auto tv3 = mul(tv0, new Double(2.0));
+  auto tv4 = add(tv2, tv1);
+  auto tv5 = add(tv4, tv3);
+  auto tv6 = add(tv0, tv3);
+
+  fusion.addOutput(tv5);
+  fusion.addOutput(tv6);
+
+  // using the value's val type
+  ASSERT_EQ(fusion.lookupValue(*tv0->getValType(), tv0->name()), tv0);
+  ASSERT_EQ(fusion.lookupValue(*scalar->getValType(), scalar->name()), scalar);
+
+  // explicit ValType
+  ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv1->name()), tv1);
+  ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv2->name()), tv2);
+  ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv3->name()), tv3);
+  ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv4->name()), tv4);
+  ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv5->name()), tv5);
+  ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv6->name()), tv6);
+
+  // misses
+  ASSERT_NE(fusion.lookupValue(ValType::Scalar, tv0->name()), tv0);
+  ASSERT_NE(fusion.lookupValue(ValType::TensorView, tv1->name()), tv0);
+
+  // non-existent names
+  ASSERT_EQ(fusion.lookupValue(ValType::Scalar, 12345), nullptr);
+  ASSERT_EQ(fusion.lookupValue(ValType::TensorView, 12345), nullptr);
+
+  Fusion copy(fusion);
+
+  auto copy_tv1 = copy.lookupValue(ValType::TensorView, tv1->name());
+  auto copy_tv2 = copy.lookupValue(ValType::TensorView, tv2->name());
+  auto copy_tv3 = copy.lookupValue(ValType::TensorView, tv3->name());
+  auto copy_tv4 = copy.lookupValue(ValType::TensorView, tv4->name());
+  auto copy_tv5 = copy.lookupValue(ValType::TensorView, tv5->name());
+  auto copy_tv6 = copy.lookupValue(ValType::TensorView, tv6->name());
+
+  swap(fusion, copy);
+
+  ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv1->name()), copy_tv1);
+  ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv2->name()), copy_tv2);
+  ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv3->name()), copy_tv3);
+  ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv4->name()), copy_tv4);
+  ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv5->name()), copy_tv5);
+  ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv6->name()), copy_tv6);
+
+  fusion.clear();
+
+  // the cleared fusion has no per-type index left: lookups miss, not throw
+  ASSERT_EQ(fusion.lookupValue(ValType::TensorView, tv1->name()), nullptr);
+  ASSERT_EQ(fusion.lookupValue(ValType::Scalar, scalar->name()), nullptr);
+
+  ASSERT_EQ(copy.lookupValue(ValType::TensorView, tv1->name()), tv1);
+  ASSERT_EQ(copy.lookupValue(ValType::TensorView, tv2->name()), tv2);
+  ASSERT_EQ(copy.lookupValue(ValType::TensorView, tv3->name()), tv3);
+  ASSERT_EQ(copy.lookupValue(ValType::TensorView, tv4->name()), tv4);
+  ASSERT_EQ(copy.lookupValue(ValType::TensorView, tv5->name()), tv5);
+  ASSERT_EQ(copy.lookupValue(ValType::TensorView, tv6->name()), tv6);
+}
+
 TEST(NVFuserTest, FusionClear_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp
index 192bed24a182..7183a8d65ac4 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.cpp
+++ b/torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -42,6 +42,8 @@ void swap(Fusion& a, Fusion& b) noexcept {
   swap(a.expr_set_, b.expr_set_);
   swap(a.val_deque_, b.val_deque_);
 
+  swap(a.lookup_index_, b.lookup_index_);
+
   swap(a.val_type_name_map_, b.val_type_name_map_);
   swap(a.expr_name_counter_, b.expr_name_counter_);
 
@@ -96,6 +98,13 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
     ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
   }
 
+  to->lookup_index_ = from->lookup_index_;
+  for (auto& index_kv : to->lookup_index_) {
+    for (auto& kv : index_kv.second) {
+      kv.second = ir_cloner.clone(kv.second);
+    }
+  }
+
   to->val_type_name_map_ = from->val_type_name_map_;
   to->expr_name_counter_ = from->expr_name_counter_;
 
@@ -146,6 +155,8 @@ void Fusion::clear() noexcept {
   val_deque_.clear();
   expr_set_.clear();
 
+  lookup_index_.clear();
+
   for (auto& kv : val_type_name_map_) {
     kv.second = 0;
   }
@@ -378,7 +389,10 @@ StmtNameType Fusion::registerVal(Val* val) {
   val_set_.emplace(val);
   val_deque_.push_back(val);
 
-  return getValName(*(val->getValType()));
+  const auto vtype = *val->getValType();
+  const auto name = getValName(vtype);
+  TORCH_INTERNAL_ASSERT(lookup_index_[vtype].insert({name, val}).second);
+  return name;
 }
 
 StmtNameType Fusion::registerExpr(Expr* expr) {
@@ -431,6 +445,19 @@ StmtNameType Fusion::registerStatement(Statement* stmt) {
   return kInvalidStmName;
 }
 
+Val* Fusion::lookupValue(ValType vtype, StmtNameType name) const {
+  // registerVal() creates the per-type index on first use, so it is missing
+  // for a type with no registered values (and after clear()). Report that as
+  // a miss rather than letting at() throw std::out_of_range.
+  const auto type_it = lookup_index_.find(vtype);
+  if (type_it == lookup_index_.end()) {
+    return nullptr;
+  }
+  const auto& index = type_it->second;
+  const auto name_it = index.find(name);
+  return name_it != index.end() ? name_it->second : nullptr;
+}
+
 void Fusion::resetTvUses() {
   // getExprs only uses definition, so even if we've modified uses already to
   // remove dead exprs, this could reinsert them. getExprs is also boundeds by
diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h
index b745130bb5d3..3680ff531e92 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.h
+++ b/torch/csrc/jit/codegen/cuda/fusion.h
@@ -85,7 +85,7 @@ class TORCH_CUDA_CU_API Fusion final {
 
   ~Fusion();
 
-  friend void swap(Fusion& a, Fusion& b) noexcept;
+  TORCH_CUDA_CU_API friend void swap(Fusion& a, Fusion& b) noexcept;
 
   void clear() noexcept;
 
@@ -116,6 +116,11 @@ class TORCH_CUDA_CU_API Fusion final {
   //! Replace output with another value
   void replaceOutput(Val* output, Val* replacement);
 
+  //! Lookup the value node with the specified type and name
+  //!
+  //! Returns nullptr if no value with that type and name is registered
+  Val* lookupValue(ValType vtype, StmtNameType name) const;
+
   //! Clear Expr's from TV uses that are not required to produce outputs from
   //! inputs
   void resetTvUses();
@@ -216,6 +221,10 @@ class TORCH_CUDA_CU_API Fusion final {
   std::deque val_deque_;
   std::unordered_set expr_set_;
 
+  // name-to-node lookup indexes
+  std::unordered_map<ValType, std::unordered_map<StmtNameType, Val*>>
+      lookup_index_;
+
   // Values names counters
   std::unordered_map val_type_name_map_;
 