diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index b85fdec8cba9..2aee2cb136b3 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -24,6 +24,16 @@ namespace tvm { namespace tir { +/*! \brief The level of detailed error message rendering */ +enum class ScheduleErrorRenderLevel : int32_t { + /*! \brief Render a detailed error message */ + kDetail = 0, + /*! \brief Render the error in fast mode */ + kFast = 1, + /*! \brief No error message at all */ + kNone = 2, +}; + /**************** Random variable: BlockRV ****************/ /*! \brief A random variable that evaluates to a TensorIR block */ @@ -209,13 +219,15 @@ class Schedule : public runtime::ObjectRef { * \param mod The IRModule to be scheduled * \param debug_mode Do extra correctness checking after the class creation * and each time after calling the Replace method. + * \param error_render_level The level of error rendering * \return The concrete schedule created * \sa ScheduleDebugMask * \note The checks performed includes: * 1) VerifySRefTree * 2) VerifyCachedFlags */ - TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode); + TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode, + ScheduleErrorRenderLevel error_render_level); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); }; diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index afe521a74361..eb200df0c599 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -48,7 +48,7 @@ from .op import comm_reducer, min, max, sum from .op import q_multiply_shift -from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule +from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError from . import schedule from . import ir_builder diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py index 5550a9e3c74f..ef1cab1fb663 100644 --- a/python/tvm/tir/schedule/__init__.py +++ b/python/tvm/tir/schedule/__init__.py @@ -19,4 +19,4 @@ from .block_scope import BlockScope, Dependency, DepKind, StmtSRef from .state import ScheduleDebugMask, ScheduleState -from .schedule import LoopRV, BlockRV, ExprRV, RAND_VAR_TYPE, Schedule +from .schedule import LoopRV, BlockRV, ExprRV, RAND_VAR_TYPE, Schedule, ScheduleError diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index f207fa274212..d420f7d32db0 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -19,6 +19,7 @@ from typing import List, Optional, Union from tvm._ffi import register_object as _register_object +from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr from tvm.runtime import Object from tvm.tir import Block, For, IntImm, PrimFunc, Var @@ -27,6 +28,11 @@ from .state import ScheduleState, StmtSRef +@register_error +class ScheduleError(TVMError): + """Error that happens during TensorIR scheduling.""" + + @_register_object("tir.LoopRV") class LoopRV(Object): """A random variable that refers to a loop""" @@ -57,10 +63,14 @@ class Schedule(Object): Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html """ + ERROR_RENDER_LEVEL = {"detail": 0, "fast": 1, "none": 2} + def __init__( self, func_or_mod: Union[PrimFunc, IRModule], + *, debug_mode: Union[bool, int] = False, + error_render_level: str = "detail", ): """Construct a concrete TensorIR schedule from an IRModule or a PrimFunc @@ -71,6 +81,11 @@ def __init__( debug_mode : Union[bool, int] Do extra correctness checking after the class creation and each time scheduling primitive + error_render_level : str = "detail" + The level of error rendering. Choices: "detail", "fast", "none". + "detail": Render a detailed error message, with the TIR and error locations printed + "fast: Show a simple error message without rendering or string manipulation + "none": Do not show any error message. Note ---------- @@ -85,10 +100,17 @@ def __init__( debug_mode = 0 if not isinstance(debug_mode, int): raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}") + if error_render_level not in Schedule.ERROR_RENDER_LEVEL: + raise ValueError( + 'error_render_level can be "detail", "fast", or "none", but got: ' + + f"{error_render_level}" + ) + error_render_level = Schedule.ERROR_RENDER_LEVEL.get(error_render_level) self.__init_handle_by_constructor__( _ffi_api_schedule.ConcreteSchedule, # pylint: disable=no-member func_or_mod, debug_mode, + error_render_level, ) ########## Utilities ########## diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index ef12f10fa924..60ab7920c37b 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -21,9 +21,11 @@ namespace tvm { namespace tir { -Schedule Schedule::Concrete(IRModule mod, int debug_mode) { +Schedule Schedule::Concrete(IRModule mod, int debug_mode, + ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mode); + n->error_render_level_ = error_render_level; n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); return Schedule(std::move(n)); @@ -136,6 +138,7 @@ class ScheduleCopier { scope->src2deps = Copy(old_info.scope->src2deps); scope->dst2deps = Copy(old_info.scope->dst2deps); scope->buffer_writers = Copy(old_info.scope->buffer_writers); + scope->stage_pipeline = old_info.scope->stage_pipeline; new_info.scope = BlockScope(std::move(scope)); result[Copy(old_sref)] = std::move(new_info); } @@ -173,21 +176,81 @@ class ScheduleCopier { void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symbol_table) const { ScheduleCopier::Copy(this, new_state, new_symbol_table); + new_state->get()->DebugVerify(); } Schedule ConcreteScheduleNode::Copy() const { ObjectPtr n = make_object(); - Copy(&n->state_, &n->symbol_table_); + n->error_render_level_ = this->error_render_level_; + this->Copy(&n->state_, &n->symbol_table_); n->analyzer_ = std::make_unique(); return Schedule(std::move(n)); } +/*! \brief Macro that guards the beginning of each invocation of TensorIR schedule primitive */ +#define TVM_TIR_SCHEDULE_BEGIN() try { +/*! + * \brief Macro that pairs with `TVM_TIR_SCHEDULE_BEGIN`, handling potential errors and error + * message rendering + * \param level An ScheduleErrorRenderLevel enum, level of error rendering + * \sa ScheduleErrorRenderLevel + */ +#define TVM_TIR_SCHEDULE_END(level) \ + } \ + catch (const ScheduleError& error) { \ + if ((level) == ScheduleErrorRenderLevel::kDetail) { \ + throw tvm::runtime::Error(error.RenderReport()); \ + } else if ((level) == ScheduleErrorRenderLevel::kFast) { \ + throw tvm::runtime::Error(error.FastErrorString()); \ + } else if ((level) == ScheduleErrorRenderLevel::kNone) { \ + throw tvm::runtime::Error("ScheduleError: (not rendered)"); \ + } \ + } + /******** Block/Loop relation ********/ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { + class NotSingleResult : public ScheduleError { + public: + explicit NotSingleResult(String name, IRModule mod, const Array& blocks) + : name_(name), mod_(mod), blocks_{} { + blocks_.reserve(blocks.size()); + for (const StmtSRef& block_sref : blocks) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + blocks_.push_back(GetRef(block)); + } + } + + String primitive() const final { return "get-block"; } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {blocks_.begin(), blocks_.end()}; } + + String DetailRenderTemplate() const final { + if (blocks_.empty()) { + return "Cannot find a block with the name: " + name_; + } else { + return "Found " + std::to_string(blocks_.size()) + " blocks with the name: " + name_; + } + } + + String FastErrorString() const final { + if (blocks_.empty()) { + return "ScheduleError: Cannot find a block with the specified name"; + } else { + return "ScheduleError: Found multiple blocks with the specified name"; + } + } + + String name_; + IRModule mod_; + Array blocks_; + }; Array blocks = tir::GetBlocks(this->state_, name, func_name); - CHECK_EQ(blocks.size(), 1) << "ValueError: There are " << blocks.size() - << " blocks with the name: " << name; + if (blocks.size() != 1) { + TVM_TIR_SCHEDULE_BEGIN(); + throw NotSingleResult(name, this->state_->mod, blocks); + TVM_TIR_SCHEDULE_END(this->error_render_level_); + } return CreateRV(blocks[0]); } diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 39eab1159db9..ab467cec9ee3 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -37,6 +37,8 @@ class ConcreteScheduleNode : public ScheduleNode { protected: /*! \brief The internal state of scheduling */ ScheduleState state_; + /*! \brief The level of error rendering */ + ScheduleErrorRenderLevel error_render_level_; /*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */ TSymbolTable symbol_table_; /*! \brief A persistent stateless arithmetic analyzer. */ @@ -44,6 +46,7 @@ class ConcreteScheduleNode : public ScheduleNode { public: void VisitAttrs(tvm::AttrVisitor* v) { + // `error_render_level_` is not visited // `state_` is not visited // `symbol_table_` is not visited // `analyzer_` is not visitied diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc new file mode 100644 index 000000000000..f64d4aeb984b --- /dev/null +++ b/src/tir/schedule/error.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace tir { + +String ScheduleError::RenderReport() const { + IRModule mod = this->mod(); + std::ostringstream os; + os << "ScheduleError: An error occurred in the schedule primitive '" << this->primitive() + << "'.\n\nThe IR is:\n" + << AsTVMScript(mod); + Array locs = LocationsOfInterest(); + int n_locs = locs.size(); + std::vector roi_names; + roi_names.reserve(n_locs); + if (n_locs > 0) { + os << "Regions of interest:\n"; + for (const ObjectRef& obj : locs) { + String name = obj->GetTypeKey() + '#' + std::to_string(roi_names.size()); + os << name << "\n" << obj; + roi_names.emplace_back(std::move(name)); + } + os << "\n"; + } + std::string msg = DetailRenderTemplate(); + for (int i = 0; i < n_locs; ++i) { + std::string src = "{" + std::to_string(i) + "}"; + for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { + msg.replace(pos, src.length(), roi_names[i]); + } + } + os << "Error message: " << msg; + return os.str(); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/error.h b/src/tir/schedule/error.h new file mode 100644 index 000000000000..1031672f0010 --- /dev/null +++ b/src/tir/schedule/error.h @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_ERROR_H_ +#define TVM_TIR_SCHEDULE_ERROR_H_ + +#include + +namespace tvm { +namespace tir { + +/*! \brief Error that happens during TensorIR scheduling */ +class ScheduleError : public tvm::runtime::Error { + public: + /*! \brief Base constructor */ + ScheduleError() : tvm::runtime::Error("") {} + /*! \brief The error occurred in this scheduling primitive */ + virtual String primitive() const = 0; + /*! \brief The error occurred in this IRModule */ + virtual IRModule mod() const = 0; + /*! \brief The locations of interest that we want to point out */ + virtual Array LocationsOfInterest() const = 0; + /*! + * \brief Returns an error string template for rendering, corresponds to the "detail" mode. + * \sa ScheduleErrorRenderLevel + * \note The template is a string, e.g. + * "Some error occurred on block {0} and loop {1} blah blah" + * And renderer will replace {0} and {1} according to the list provided LocationsOfInterest. Right + * now it only printed out all the locations in plain text, but in the future, we may want to mark + * the IR with underscores and attach names to each location of interest, like what synr does. + */ + virtual String DetailRenderTemplate() const = 0; + /*! + * \brief Returns an error string without needing to render, corresponds to the "fast" mode + * \sa ScheduleErrorRenderLevel + */ + virtual String FastErrorString() const = 0; + /*! \brief Render the ScheduleError with the template provided by `DetailRenderTemplate` */ + String RenderReport() const; +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_ERROR_H_ diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index b407b07e5312..a1a4f09a7525 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // /**************** (FFI) Constructor ****************/ TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") - .set_body_typed([](ObjectRef obj, int debug_mode) -> Schedule { + .set_body_typed([](ObjectRef obj, int debug_mode, int error_render_level) -> Schedule { IRModule mod{nullptr}; if (const auto* func = obj.as()) { mod = IRModule({{GlobalVar("main"), GetRef(func)}}); @@ -66,7 +66,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: " << obj->GetTypeKey(); } - return Schedule::Concrete(mod, debug_mode); + return Schedule::Concrete(mod, debug_mode, + static_cast(error_render_level)); }); /******** (FFI) Lookup random variables ********/ diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index b72fd8e05706..e7c73120c730 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -35,6 +35,7 @@ #include "../../printer/text_printer.h" #include "../../runtime/thread_storage_scope.h" #include "./analysis.h" +#include "./error.h" namespace tvm { namespace tir { diff --git a/tests/python/unittest/test_tir_schedule_error.py b/tests/python/unittest/test_tir_schedule_error.py new file mode 100644 index 000000000000..1fa658feabe3 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_error.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +from tvm import tir +from tvm.script import ty + + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "init") as [vi, vj]: + C[vi, vj] = tir.float32(0) + for k in range(0, 128): + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_tir_schedule_error_detail(): + sch = tir.Schedule(matmul, debug_mode=True, error_render_level="detail") + with pytest.raises(tir.ScheduleError) as excinfo: + sch.get_block("wrong_name") + (msg,) = excinfo.value.args + assert "Cannot find a block with the name: wrong_name" in msg + + +def test_tir_schedule_error_fast(): + sch = tir.Schedule(matmul, debug_mode=True, error_render_level="fast") + with pytest.raises(tir.ScheduleError) as excinfo: + sch.get_block("wrong_name") + (msg,) = excinfo.value.args + assert "Cannot find a block with the specified name" in msg + + +def test_tir_schedule_error_none(): + sch = tir.Schedule(matmul, debug_mode=True, error_render_level="none") + with pytest.raises(tir.ScheduleError) as excinfo: + sch.get_block("wrong_name") + (msg,) = excinfo.value.args + assert "(not rendered)" in msg + + +if __name__ == "__main__": + test_tir_schedule_error_detail() + test_tir_schedule_error_fast() + test_tir_schedule_error_none()