diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 8306cb173e0a..d60a222ac265 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -19,7 +19,7 @@ /*! * \file tvm/tir/analysis.h - * \brief Analysis utilitie and passes for TIR. + * \brief Analysis utilities and passes for TIR. */ #ifndef TVM_TIR_ANALYSIS_H_ #define TVM_TIR_ANALYSIS_H_ @@ -220,6 +220,15 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func, */ TVM_DLL Map> DetectBufferAccessLCA(const PrimFunc& func); +/*! + * \brief Verify if the given TIR is well-formed. The verification includes: + * - Check that expressions do not contain vars that are defined outside the block. + * \param func The PrimFunc to be verified. + * \param assert_mode Whether to raise an error when the function is not well-formed. + * \return Whether it is a well-formed TIR function. + */ +TVM_DLL bool VerifyWellFormed(const PrimFunc& func, bool assert_mode = true); + // Pass variants of verification analysis // directly throws RuntimeError when verification fails. namespace transform { diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 7fc73ef4c436..13674daa2413 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -300,3 +300,23 @@ def apply_prim_func_arg_and_result_memory_constraints( return _ffi_api.ApplyPrimFuncArgAndResultMemoryConstraints( # type: ignore # pylint: disable=no-member func, relay_func_type, arg_and_result_memory_scopes ) + + +def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool: + """Verify if the given TIR is well-formed. The verification includes: + - Check that expressions do not contain vars that are defined outside the block. + + Parameters + ---------- + func: tvm.tir.PrimFunc + The function to be verified. + + assert_mode: bool + Whether to raise an error when the function is not well-formed. 
+ + Returns + ------- + result: bool + Whether it is a well-formed TIR function. + """ + return _ffi_api.VerifyWellFormed(func, assert_mode) # type: ignore # pylint: disable=no-member diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc new file mode 100644 index 000000000000..878618fbe6fd --- /dev/null +++ b/src/tir/analysis/verify_well_formed.cc @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/verify_well_formed.cc + * \brief Check if schedulable TIR is well-formed. + */ + +#include +#include +#include + +#include "../ir/functor_common.h" + +namespace tvm { +namespace tir { + +/*! \brief Verify that no Expr inside the block contains: + * 1. loop vars outside the current block. + * 2. block vars of parent blocks. 
+ */ +class BlockVarAccessVerifier : public StmtExprVisitor { + public: + static bool Verify(const PrimFunc& func, bool assert_mode) { + BlockVarAccessVerifier verifier(assert_mode); + verifier(func->body); + return !verifier.has_error_; + } + + private: + explicit BlockVarAccessVerifier(bool assert_mode) : assert_mode_(assert_mode) {} + + void VisitStmt(const Stmt& stmt) final { + if (!has_error_) { + StmtExprVisitor::VisitStmt(stmt); + } + } + + void VisitExpr(const PrimExpr& expr) final { + if (!has_error_) { + StmtExprVisitor::VisitExpr(expr); + } + } + + void VisitExpr_(const VarNode* op) final { + auto it = loop_vars_.find(op); + if (it != loop_vars_.end() && it->second < cur_block_level_) { + has_error_ = true; + if (assert_mode_) { + report_error(op); + } + } + } + + void VisitStmt_(const ForNode* op) final { + ICHECK(loop_vars_.find(op->loop_var.get()) == loop_vars_.end()); + loop_vars_[op->loop_var.get()] = cur_block_level_; + StmtExprVisitor::VisitStmt_(op); + loop_vars_.erase(op->loop_var.get()); + } + + void VisitStmt_(const BlockNode* op) final { + // Do not check boundary if it's an opaque block. + cur_block_level_ += !op->iter_vars.empty(); + + // Step 0. Skip block iter var's domain + + // Step 1. Visit read/write regions + auto fvisit_buffer_region = [this](const BufferRegion& s) { + for (const auto& range : s->region) { + this->VisitExpr(range->min); + this->VisitExpr(range->extent); + } + }; + VisitArray(op->reads, fvisit_buffer_region); + VisitArray(op->writes, fvisit_buffer_region); + + // Step 2. Visit match buffers + VisitArray(op->match_buffers, + [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) { + fvisit_buffer_region(match_buffer_region->source); + }); + + // Step 3. 
Visit init and body + if (op->init.defined()) { + this->VisitStmt(op->init.value()); + } + this->VisitStmt(op->body); + + cur_block_level_ -= !op->iter_vars.empty(); + } + + private: + void report_error(const VarNode* var) { + // TODO(siyuan): use the error message from the parser. + LOG(FATAL) << "Well-formedness check failed: outside defined var " << var->name_hint + << " is used inside the current block."; + } + + /*! \brief The map from each enclosing loop var to the block level at which it was defined. */ + std::unordered_map loop_vars_; + /*! \brief Whether it's in assert mode. */ + bool assert_mode_; + /*! \brief Current nested block stack level. */ + size_t cur_block_level_{0}; + /*! \brief Whether an error has been found. */ + bool has_error_{false}; +}; + +bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) { + if (!BlockVarAccessVerifier::Verify(func, assert_mode)) { + return false; + } + // TODO(Siyuan): add more checks here. + return true; +} + +TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed").set_body_typed(VerifyWellFormed); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 3c11d2485332..dadabba48540 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -413,6 +413,7 @@ class StateCreator : private StmtVisitor { for (const auto& kv : n->mod->functions) { const BaseFunc& base_func = kv.second; if (const auto* func = base_func.as()) { + VerifyWellFormed(GetRef(func)); creator.VisitStmt(func->body); BlockInfoCollector::Collect(self, func->body); } diff --git a/tests/python/unittest/test_tir_analysis_verify_well_formed.py b/tests/python/unittest/test_tir_analysis_verify_well_formed.py new file mode 100644 index 000000000000..b3028a0148aa --- /dev/null +++ b/tests/python/unittest/test_tir_analysis_verify_well_formed.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm.script import tir as T + + +def test_pass_simple(): + @T.prim_func + def element_wise( + A: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], + ): + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + # It's an opaque block, so it can use outside variables + C[i, j] = B[i, j] * 2.0 + + assert tvm.tir.analysis.verify_well_formed(element_wise) + + +def test_fail_use_out_loop_var(): + @T.prim_func + def element_wise( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + ): + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + # we cannot use `i` since it's defined outside the block + B[vi, vj] = A[i, vj] * 2.0 + + assert not tvm.tir.analysis.verify_well_formed(element_wise, assert_mode=False) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_schedule_set_axis_separator.py b/tests/python/unittest/test_tir_schedule_set_axis_separator.py index 102b3d1cd710..9502da182926 100644 --- a/tests/python/unittest/test_tir_schedule_set_axis_separator.py +++ 
b/tests/python/unittest/test_tir_schedule_set_axis_separator.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring -import sys import pytest import tvm import tvm.testing @@ -76,12 +75,12 @@ def element_wise_subregion_match(A: T.Buffer[(128, 128), "float32"], C: T.Buffer for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - B_subregion0 = T.match_buffer(B[i, j], [], offset_factor=1) + B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1) B_subregion0[()] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): with T.block("C"): vi, vj = T.axis.remap("SS", [i, j]) - B_subregion1 = T.match_buffer(B[i, j], [], offset_factor=1) + B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1) C[vi, vj] = B_subregion1[()] + 1.0 @@ -92,12 +91,12 @@ def element_wise_subregion_match_set_axis_separator(A: T.Buffer[(128, 128), "flo for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - B_subregion0 = T.match_buffer(B[i, j], [], dtype="float32", offset_factor=1, axis_separators=[1]) + B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1]) B_subregion0[()] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): with T.block("C"): vi, vj = T.axis.remap("SS", [i, j]) - B_subregion1 = T.match_buffer(B[i, j], [], dtype="float32", offset_factor=1, axis_separators=[1]) + B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32", offset_factor=1, axis_separators=[1]) C[vi, vj] = B_subregion1[()] + T.float32(1) diff --git a/tests/python/unittest/test_tir_schedule_set_scope.py b/tests/python/unittest/test_tir_schedule_set_scope.py index b2e8479462eb..adac81e62946 100644 --- a/tests/python/unittest/test_tir_schedule_set_scope.py +++ b/tests/python/unittest/test_tir_schedule_set_scope.py @@ -17,6 +17,7 @@ # pylint: 
disable=missing-function-docstring,missing-module-docstring import pytest import tvm +import tvm.testing from tvm import tir from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip @@ -59,12 +60,12 @@ def element_wise_subregion_match(A: T.Buffer[(128, 128), "float32"], C: T.Buffer for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - B_subregion0 = T.match_buffer(B[i, j], [], offset_factor=1) + B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1) B_subregion0[()] = A[vi, vj] * 2.0 for i, j in T.grid(128, 128): with T.block("C"): vi, vj = T.axis.remap("SS", [i, j]) - B_subregion1 = T.match_buffer(B[i, j], [], offset_factor=1) + B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1) C[vi, vj] = B_subregion1[()] + 1.0 @@ -75,12 +76,12 @@ def element_wise_subregion_match_set_scope(A: T.Buffer[(128, 128), "float32"], C for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - B_subregion0_shared = T.match_buffer(B_shared[i, j], [], dtype="float32", scope="shared", offset_factor=1) + B_subregion0_shared = T.match_buffer(B_shared[vi, vj], [], dtype="float32", scope="shared", offset_factor=1) B_subregion0_shared[()] = A[vi, vj] * T.float32(2) for i, j in T.grid(128, 128): with T.block("C"): vi, vj = T.axis.remap("SS", [i, j]) - B_subregion1_shared = T.match_buffer(B_shared[i, j], [], dtype="float32", scope="shared", offset_factor=1) + B_subregion1_shared = T.match_buffer(B_shared[vi, vj], [], dtype="float32", scope="shared", offset_factor=1) C[vi, vj] = B_subregion1_shared[()] + T.float32(1) @@ -128,8 +129,4 @@ def test_set_scope_subregion(): if __name__ == "__main__": - test_set_scope() - test_set_scope_fail_on_output_buffer() - test_set_scope_fail_on_index_out_of_bound() - test_set_scope_fail_on_invalid_scope() - test_set_scope_subregion() + tvm.testing.main()