diff --git a/include/tvm/ir/si_builder.h b/include/tvm/ir/si_builder.h index 57ce4563d7193..ab5f2d450fe47 100644 --- a/include/tvm/ir/si_builder.h +++ b/include/tvm/ir/si_builder.h @@ -34,8 +34,8 @@ namespace tvm { /*! - * \brief SIBuilder provides helper APIs for filling spans, - * particularly useful for one-to-many, many-to-one and many-to-many pass transformations. + * \brief Source Information Builder, SIBuilder provides helper APIs for filling spans, + * particularly useful for one-to-many, many-to-one and many-to-many IR transformations. */ class SIBuilder { public: @@ -68,11 +68,11 @@ class SIBuilder { SIBuilder& operator=(const SIBuilder&) = delete; /*! - * \brief create new source info based on the given span or subgraph. + * \brief build a span of source information, which is based on the given span or subgraph. * - * \return The given span, or reconstructed span from subgraph. + * \return the built span */ - Span CreateSpan() const; + Span Build() const; /*! * \brief Recursively fill all span of exprs in subgraph from entry until inputs. diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index f52eb97704a35..21a5ed6576756 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -71,7 +71,10 @@ def __init__(self, source_name, line, end_line, column, end_column): @register_object("SequentialSpan") class SequentialSpan(Object): - """Specifies a location in a source program. + """A sequence of source spans + + This span is specific for an expression, which is from multiple expressions + after an IR transform. Parameters ---------- diff --git a/src/ir/si_builder.cc b/src/ir/si_builder.cc index e149c29001283..4c4bca4560719 100644 --- a/src/ir/si_builder.cc +++ b/src/ir/si_builder.cc @@ -33,11 +33,12 @@ using RelayExprSet = std::unordered_set; using StmtSet = std::unordered_set; -class RelayCollapse : public relay::ExprVisitor { +class RelayCollectSpans : public relay::ExprVisitor { public: - explicit RelayCollapse(const RelayExprSet& inputs = {}) : inputs_(inputs) {} + explicit RelayCollectSpans(const RelayExprSet& inputs = {}) : inputs_(inputs) {} - Span Collapse(const relay::Expr& entry); + // From entry to inputs, recursively collect spans. The spans of inputs are included. + Span CollectSpans(const relay::Expr& entry); void VisitExpr(const relay::Expr& expr) final; @@ -46,7 +47,7 @@ class RelayCollapse : public relay::ExprVisitor { const RelayExprSet& inputs_; }; -void RelayCollapse::VisitExpr(const relay::Expr& expr) { +void RelayCollectSpans::VisitExpr(const relay::Expr& expr) { if (visit_counter_.count(expr.get())) { return; } @@ -61,7 +62,7 @@ void RelayCollapse::VisitExpr(const relay::Expr& expr) { relay::ExprVisitor::VisitExpr(expr); } -Span RelayCollapse::Collapse(const relay::Expr& entry) { +Span RelayCollectSpans::CollectSpans(const relay::Expr& entry) { VisitExpr(entry); return SequentialSpan(spans_); } @@ -71,6 +72,7 @@ class RelayRecursivelyFill : public relay::ExprMutator { explicit RelayRecursivelyFill(const Span& span, const RelayExprSet& inputs = {}) : span_(span), inputs_(inputs) {} + // From entry until inputs, recursively fill spans into expressions. Inputs are not filled. void Fill(const relay::Expr& entry); relay::Expr VisitExpr(const relay::Expr& expr) final; @@ -94,9 +96,9 @@ relay::Expr RelayRecursivelyFill::VisitExpr(const relay::Expr& expr) { void RelayRecursivelyFill::Fill(const relay::Expr& entry) { Mutate(entry); } -class TirCollapse : public tir::StmtExprVisitor { +class TirCollectSpans : public tir::StmtExprVisitor { public: - explicit TirCollapse(const PrimExprSet& expr_inputs = {}, const StmtSet& stmt_inputs = {}) + explicit TirCollectSpans(const PrimExprSet& expr_inputs = {}, const StmtSet& stmt_inputs = {}) : expr_inputs_(expr_inputs), stmt_inputs_(stmt_inputs) {} void VisitExpr(const PrimExpr& expr) final; @@ -105,8 +107,10 @@ class TirCollapse : public tir::StmtExprVisitor { bool IsInput(const PrimExpr& expr); bool IsInput(const tir::Stmt& stmt); - Span Collapse(const PrimExpr& expr); - Span Collapse(const tir::Stmt& stmt); + // From entry to inputs, recursively collect spans. The spans of inputs are included. + Span CollectSpans(const PrimExpr& expr); + // From entry to inputs, recursively collect spans. The spans of inputs are included. + Span CollectSpans(const tir::Stmt& stmt); private: Array spans_; @@ -115,25 +119,25 @@ class TirCollapse : public tir::StmtExprVisitor { const StmtSet& stmt_inputs_; }; -Span TirCollapse::Collapse(const PrimExpr& expr) { +Span TirCollectSpans::CollectSpans(const PrimExpr& expr) { operator()(expr); return SequentialSpan(spans_); } -Span TirCollapse::Collapse(const tir::Stmt& stmt) { +Span TirCollectSpans::CollectSpans(const tir::Stmt& stmt) { operator()(stmt); return SequentialSpan(spans_); } -bool TirCollapse::IsInput(const PrimExpr& expr) { +bool TirCollectSpans::IsInput(const PrimExpr& expr) { return expr_inputs_.find(expr) != expr_inputs_.end(); } -bool TirCollapse::IsInput(const tir::Stmt& stmt) { +bool TirCollectSpans::IsInput(const tir::Stmt& stmt) { return stmt_inputs_.find(stmt) != stmt_inputs_.end(); } -void TirCollapse::VisitExpr(const PrimExpr& expr) { +void TirCollectSpans::VisitExpr(const PrimExpr& expr) { if (visit_counter_.count(expr.get())) { return; } @@ -148,7 +152,7 @@ void TirCollapse::VisitExpr(const PrimExpr& expr) { StmtExprVisitor::VisitExpr(expr); } -void TirCollapse::VisitStmt(const tir::Stmt& stmt) { +void TirCollectSpans::VisitStmt(const tir::Stmt& stmt) { if (visit_counter_.count(stmt.get())) { return; } @@ -169,7 +173,9 @@ class TirRecursivelyFill : public tir::StmtExprMutator { const StmtSet& stmt_inputs = {}) : span_(span), expr_inputs_(expr_inputs), stmt_inputs_(stmt_inputs) {} + // From entry until inputs, recursively fill spans into expressions. Inputs are not filled. tir::Stmt Fill(const tir::Stmt& s) { return operator()(s); } + // From entry until inputs, recursively fill spans into expressions. Inputs are not filled. PrimExpr Fill(const PrimExpr& e) { return operator()(e); } bool IsInput(const PrimExpr& expr); @@ -209,20 +215,20 @@ PrimExpr TirRecursivelyFill::VisitExpr(const PrimExpr& expr) { } struct SIBuilder::Impl { - virtual Span CreateSpan() const = 0; - virtual void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const = 0; - virtual void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const = 0; - virtual void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const = 0; - virtual void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const = 0; - virtual void CollapseSpan(const relay::Expr& entry, const RelayExprSet& inputs) = 0; - virtual void CollapseSpan(const PrimExpr& entry, const PrimExprSet& inputs) = 0; - virtual void CollapseSpan(const tir::Stmt& entry, const PrimExprSet& inputs) = 0; - virtual void CollapseSpan(const tir::Stmt& entry, const StmtSet& inputs) = 0; + virtual Span Build() const { return Span(); } + virtual void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const {}; + virtual void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const {}; + virtual void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const {}; + virtual void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const {}; + virtual void CollectSpansSpan(const relay::Expr& entry, const RelayExprSet& inputs) {}; + virtual void CollectSpansSpan(const PrimExpr& entry, const PrimExprSet& inputs) {}; + virtual void CollectSpansSpan(const tir::Stmt& entry, const PrimExprSet& inputs) {}; + virtual void CollectSpansSpan(const tir::Stmt& entry, const StmtSet& inputs) {}; }; SIBuilder::~SIBuilder() = default; -Span SIBuilder::CreateSpan() const { return impl_->CreateSpan(); } +Span SIBuilder::Build() const { return impl_->Build(); } template <> void SIBuilder::RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const { @@ -243,54 +249,32 @@ void SIBuilder::RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& input } std::unique_ptr SIBuilder::CreateImpl(const Span& span) { - struct NullImpl : public SIBuilder::Impl { - Span CreateSpan() const final { return Span(); } - - void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const final{}; - void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const final{}; - void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const final{}; - void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const final{}; - void CollapseSpan(const relay::Expr& entry, const RelayExprSet& inputs) final{}; - void CollapseSpan(const PrimExpr& entry, const PrimExprSet& inputs) final{}; - void CollapseSpan(const tir::Stmt& entry, const PrimExprSet& inputs) final{}; - void CollapseSpan(const tir::Stmt& entry, const StmtSet& inputs) final{}; - }; - struct Impl : public SIBuilder::Impl { explicit Impl(const Span& span) : span_(span) {} - - Span CreateSpan() const final { return span_; } - + Span Build() const final { return span_; } void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const final { - RelayRecursivelyFill(CreateSpan(), inputs).Fill(entry); + RelayRecursivelyFill(Build(), inputs).Fill(entry); } - void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const final { - TirRecursivelyFill(CreateSpan(), inputs).Fill(entry); + TirRecursivelyFill(Build(), inputs).Fill(entry); } - void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const final { - TirRecursivelyFill(CreateSpan(), inputs).Fill(entry); + TirRecursivelyFill(Build(), inputs).Fill(entry); } - void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const final { - TirRecursivelyFill(CreateSpan(), {}, inputs).Fill(entry); + TirRecursivelyFill(Build(), {}, inputs).Fill(entry); } - - void CollapseSpan(const relay::Expr& entry, const RelayExprSet& inputs) final { - span_ = RelayCollapse(inputs).Collapse(entry); + void CollectSpansSpan(const relay::Expr& entry, const RelayExprSet& inputs) final { + span_ = RelayCollectSpans(inputs).CollectSpans(entry); } - - void CollapseSpan(const PrimExpr& entry, const PrimExprSet& inputs) final { - span_ = TirCollapse(inputs).Collapse(entry); + void CollectSpansSpan(const PrimExpr& entry, const PrimExprSet& inputs) final { + span_ = TirCollectSpans(inputs).CollectSpans(entry); } - - void CollapseSpan(const tir::Stmt& entry, const PrimExprSet& inputs) final { - span_ = TirCollapse(inputs).Collapse(entry); + void CollectSpansSpan(const tir::Stmt& entry, const PrimExprSet& inputs) final { + span_ = TirCollectSpans(inputs).CollectSpans(entry); } - - void CollapseSpan(const tir::Stmt& entry, const StmtSet& inputs) final { - span_ = TirCollapse({}, inputs).Collapse(entry); + void CollectSpansSpan(const tir::Stmt& entry, const StmtSet& inputs) final { + span_ = TirCollectSpans({}, inputs).CollectSpans(entry); } private: @@ -305,7 +289,7 @@ std::unique_ptr SIBuilder::CreateImpl(const Span& span) { return std::make_unique(span); } - return std::make_unique(); + return std::make_unique(); } SIBuilder::SIBuilder(const Span& span) : impl_(CreateImpl(span)) {} @@ -316,23 +300,23 @@ SIBuilder::SIBuilder(const std::initializer_list& init) template <> SIBuilder::SIBuilder(const relay::Expr& expr, const Array& inputs) : impl_(CreateImpl(Span())) { - impl_->CollapseSpan(expr, RelayExprSet(inputs.begin(), inputs.end())); + impl_->CollectSpansSpan(expr, RelayExprSet(inputs.begin(), inputs.end())); } template <> SIBuilder::SIBuilder(const PrimExpr& expr, const Array& inputs) : impl_(CreateImpl(Span())) { - impl_->CollapseSpan(expr, PrimExprSet(inputs.begin(), inputs.end())); + impl_->CollectSpansSpan(expr, PrimExprSet(inputs.begin(), inputs.end())); } SIBuilder::SIBuilder(const tir::Stmt& s, const Array& inputs) : impl_(CreateImpl(Span())) { - impl_->CollapseSpan(s, PrimExprSet(inputs.begin(), inputs.end())); + impl_->CollectSpansSpan(s, PrimExprSet(inputs.begin(), inputs.end())); } SIBuilder::SIBuilder(const tir::Stmt& s, const Array& inputs) : impl_(CreateImpl(Span())) { - impl_->CollapseSpan(s, StmtSet(inputs.begin(), inputs.end())); + impl_->CollectSpansSpan(s, StmtSet(inputs.begin(), inputs.end())); } // Register build pipeline related options diff --git a/tests/cpp/si_builder_test.cc b/tests/cpp/si_builder_test.cc index 4bbd1acd83932..f65debaa6b178 100644 --- a/tests/cpp/si_builder_test.cc +++ b/tests/cpp/si_builder_test.cc @@ -103,7 +103,7 @@ TEST(SIBuilder, CreateSapn) { Span span_1 = _CreateSpan("first"); { SIBuilder si_builder(span_1); - EXPECT_EQ(span_1, si_builder.CreateSpan()); + EXPECT_EQ(span_1, si_builder.Build()); } Span span_2 = _CreateSpan("second"); @@ -114,9 +114,9 @@ TEST(SIBuilder, CreateSapn) { SIBuilder si_builder_2({span_1, span_2}); SIBuilder si_builder_3{span_1, span_2}; - Span created_span_1 = si_builder_1.CreateSpan(); - Span created_span_2 = si_builder_2.CreateSpan(); - Span created_span_3 = si_builder_3.CreateSpan(); + Span created_span_1 = si_builder_1.Build(); + Span created_span_2 = si_builder_2.Build(); + Span created_span_3 = si_builder_3.Build(); auto created_seq_span_1 = created_span_1.as(); auto created_seq_span_2 = created_span_2.as(); @@ -140,7 +140,7 @@ TEST(SIBuilder, DisableSIBuilder) { Span span_1 = _CreateSpan("first"); { SIBuilder si_builder(span_1); - EXPECT_NE(span_1, si_builder.CreateSpan()); + EXPECT_NE(span_1, si_builder.Build()); } } @@ -179,7 +179,7 @@ TEST(SIBuilder, RelayRecursivelyFill) { checker.Check(z, expected_z); } -TEST(SIBuilder, RelayCollapse) { +TEST(SIBuilder, RelayCollectSpans) { using namespace tvm; auto pass_ctx = transform::PassContext::Create(); pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); @@ -206,7 +206,7 @@ TEST(SIBuilder, RelayCollapse) { relay::Expr z = relay::Call(add_op, {y, x}, tvm::Attrs(), {}, z_node_span); SIBuilder si_builder(z, {a}); - Span created_span = si_builder.CreateSpan(); + Span created_span = si_builder.Build(); auto created_seq_span = created_span.as(); EXPECT_EQ(created_seq_span->spans.size(), 4); for (std::size_t i = 0; i != created_seq_span->spans.size(); i++) { @@ -214,7 +214,7 @@ TEST(SIBuilder, RelayCollapse) { } } -TEST(SIBuilder, TirCollapsePrimExpr) { +TEST(SIBuilder, TirCollectSpansPrimExpr) { using namespace tvm; auto pass_ctx = transform::PassContext::Create(); pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); @@ -241,7 +241,7 @@ TEST(SIBuilder, TirCollapsePrimExpr) { z->span = z_node_span; SIBuilder si_builder(z, {x}); - Span created_span = si_builder.CreateSpan(); + Span created_span = si_builder.Build(); auto created_seq_span = created_span.as(); EXPECT_EQ(created_seq_span->spans.size(), 4); @@ -250,7 +250,7 @@ TEST(SIBuilder, TirCollapsePrimExpr) { } } -TEST(SIBuilder, TirCollapseStmtWithPrimInput) { +TEST(SIBuilder, TirCollectSpansStmtWithPrimInput) { using namespace tvm; auto pass_ctx = transform::PassContext::Create(); pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); @@ -274,7 +274,7 @@ TEST(SIBuilder, TirCollapseStmtWithPrimInput) { auto stmt = fmaketest(); stmt->span = stmt_node_span; SIBuilder si_builder(stmt, {x}); - Span created_span = si_builder.CreateSpan(); + Span created_span = si_builder.Build(); auto created_seq_span = created_span.as(); EXPECT_EQ(created_seq_span->spans.size(), 3); @@ -283,7 +283,7 @@ TEST(SIBuilder, TirCollapseStmtWithPrimInput) { } } -TEST(SIBuilder, TirCollapseStmtWithStmtInput) { +TEST(SIBuilder, TirCollectSpansStmtWithStmtInput) { using namespace tvm; auto pass_ctx = transform::PassContext::Create(); pass_ctx->config.Set("ir.enable_si_builder", Bool(true)); @@ -300,7 +300,7 @@ TEST(SIBuilder, TirCollapseStmtWithStmtInput) { tir::Block block({}, {}, {}, "block", body, init, Array(), Array(), Map(), block_node_span); SIBuilder si_builder(block, {init}); - Span created_span = si_builder.CreateSpan(); + Span created_span = si_builder.Build(); auto created_seq_span = created_span.as(); EXPECT_EQ(created_seq_span->spans.size(), 3);