diff --git a/third_party/nvfuser/csrc/ir_cloner.h b/third_party/nvfuser/csrc/ir_cloner.h index 35016c3596e9..0e2ef3a58374 100644 --- a/third_party/nvfuser/csrc/ir_cloner.h +++ b/third_party/nvfuser/csrc/ir_cloner.h @@ -4,7 +4,9 @@ #include #include +#include #include +#include #include namespace nvfuser { @@ -28,6 +30,10 @@ class TORCH_CUDA_CU_API IrCloner { Statement* clone(const Statement* statement); + int64_t clone(int64_t x) { + return x; + } + template T* clone(const T* node) { return node ? clone(node->template as())->template as() @@ -35,9 +41,9 @@ class TORCH_CUDA_CU_API IrCloner { } template - std::vector clone(const std::vector& container) { + std::vector clone(const std::vector& container) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::vector copy; + std::vector copy; copy.reserve(container.size()); for (auto p : container) { copy.push_back(clone(p)); @@ -45,6 +51,23 @@ class TORCH_CUDA_CU_API IrCloner { return copy; } + template + std::unordered_set clone(const std::unordered_set& container) { + std::unordered_set copy; + copy.reserve(container.size()); + for (auto p : container) { + copy.insert(clone(p)); + } + return copy; + } + + template + std::tuple clone(const std::tuple& tup) { + return std::apply( + [this](auto&... x) { return std::make_tuple(clone(x)...); }, + tup); + } + IrContainer* container() const { return ir_container_; }