diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp index 06c2fcaf01547..44656b4df46f1 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.cpp @@ -416,6 +416,24 @@ std::shared_ptr MaxRootDomainInfoSpanningTree: return from_info; } +void SpanningTreePrinter::propagateTvPasC(TensorView* from, TensorView* to) { + stream_ << "propagateTvPasC" << std::endl; + stream_ << " from: " << from->toString() << std::endl; + stream_ << " to: " << to->toString() << std::endl; +} + +void SpanningTreePrinter::propagateTvCasP(TensorView* from, TensorView* to) { + stream_ << "propagateTvCasP" << std::endl; + stream_ << " from: " << from->toString() << std::endl; + stream_ << " to: " << to->toString() << std::endl; +} + +void SpanningTreePrinter::propagateTvSibling(TensorView* from, TensorView* to) { + stream_ << "propagateTvSibling" << std::endl; + stream_ << " from: " << from->toString() << std::endl; + stream_ << " to: " << to->toString() << std::endl; +} + } // namespace cuda } // namespace fuser } // namespace jit diff --git a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h index db32aaef6d23c..5a3ac6d46f479 100644 --- a/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h +++ b/torch/csrc/jit/codegen/cuda/maxinfo_propagator.h @@ -232,6 +232,17 @@ class TORCH_CUDA_CU_API MaxRootDomainInfoSpanningTree selector) {} }; +class TORCH_CUDA_CU_API SpanningTreePrinter + : public MaxInfoSpanningTree::Propagator { + std::ostream& stream_; + + public: + virtual void propagateTvPasC(TensorView* from, TensorView* to) override; + virtual void propagateTvCasP(TensorView* from, TensorView* to) override; + virtual void propagateTvSibling(TensorView* from, TensorView* to) override; + SpanningTreePrinter(std::ostream& stream = std::cout) : stream_(stream) {} +}; + } // namespace cuda } // namespace fuser } // namespace jit