Skip to content

Commit

Permalink
[TensorIR][M2a] Structural Error Reporting (apache#8121)
Browse files Browse the repository at this point in the history
This PR is part of the TensorIR upstreaming effort (apache#7527), stage M2a.

In this PR, we implemented ScheduleError, an error reporting mechanism for schedule primitives to report user-face error messages, with the functionality of rendering the TIR out in the TVM script syntax.

This set of APIs allows future improvement of error location rendering, e.g. more colorful rendering mechanisms like synr does.

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Tristan Konolige <tristan.konolige@gmail.com>

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Tristan Konolige <tristan.konolige@gmail.com>
  • Loading branch information
7 people authored and Trevor Morris committed Jun 17, 2021
1 parent 718f2e3 commit 39adf51
Show file tree
Hide file tree
Showing 11 changed files with 296 additions and 9 deletions.
14 changes: 13 additions & 1 deletion include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@
namespace tvm {
namespace tir {

/*! \brief The level of detailed error message rendering */
enum class ScheduleErrorRenderLevel : int32_t {
/*! \brief Render a detailed error message */
kDetail = 0,
/*! \brief Render the error in fast mode */
kFast = 1,
/*! \brief No error message at all */
kNone = 2,
};

/**************** Random variable: BlockRV ****************/

/*! \brief A random variable that evaluates to a TensorIR block */
Expand Down Expand Up @@ -209,13 +219,15 @@ class Schedule : public runtime::ObjectRef {
* \param mod The IRModule to be scheduled
* \param debug_mode Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
* \return The concrete schedule created
* \sa ScheduleDebugMask
* \note The checks performed includes:
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode);
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode,
ScheduleErrorRenderLevel error_render_level);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
};

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError

from . import schedule
from . import ir_builder
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@

from .block_scope import BlockScope, Dependency, DepKind, StmtSRef
from .state import ScheduleDebugMask, ScheduleState
from .schedule import LoopRV, BlockRV, ExprRV, RAND_VAR_TYPE, Schedule
from .schedule import LoopRV, BlockRV, ExprRV, RAND_VAR_TYPE, Schedule, ScheduleError
22 changes: 22 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import List, Optional, Union

from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object
from tvm.tir import Block, For, IntImm, PrimFunc, Var
Expand All @@ -27,6 +28,11 @@
from .state import ScheduleState, StmtSRef


@register_error
class ScheduleError(TVMError):
"""Error that happens during TensorIR scheduling."""


@_register_object("tir.LoopRV")
class LoopRV(Object):
"""A random variable that refers to a loop"""
Expand Down Expand Up @@ -57,10 +63,14 @@ class Schedule(Object):
Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html
"""

ERROR_RENDER_LEVEL = {"detail": 0, "fast": 1, "none": 2}

def __init__(
self,
func_or_mod: Union[PrimFunc, IRModule],
*,
debug_mode: Union[bool, int] = False,
error_render_level: str = "detail",
):
"""Construct a concrete TensorIR schedule from an IRModule or a PrimFunc
Expand All @@ -71,6 +81,11 @@ def __init__(
debug_mode : Union[bool, int]
Do extra correctness checking after the class creation and each time
scheduling primitive
error_render_level : str = "detail"
The level of error rendering. Choices: "detail", "fast", "none".
"detail": Render a detailed error message, with the TIR and error locations printed
"fast: Show a simple error message without rendering or string manipulation
"none": Do not show any error message.
Note
----------
Expand All @@ -85,10 +100,17 @@ def __init__(
debug_mode = 0
if not isinstance(debug_mode, int):
raise TypeError(f"`debug_mode` should be integer or boolean, but gets: {debug_mode}")
if error_render_level not in Schedule.ERROR_RENDER_LEVEL:
raise ValueError(
'error_render_level can be "detail", "fast", or "none", but got: '
+ f"{error_render_level}"
)
error_render_level = Schedule.ERROR_RENDER_LEVEL.get(error_render_level)
self.__init_handle_by_constructor__(
_ffi_api_schedule.ConcreteSchedule, # pylint: disable=no-member
func_or_mod,
debug_mode,
error_render_level,
)

########## Utilities ##########
Expand Down
71 changes: 67 additions & 4 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
namespace tvm {
namespace tir {

Schedule Schedule::Concrete(IRModule mod, int debug_mode) {
Schedule Schedule::Concrete(IRModule mod, int debug_mode,
ScheduleErrorRenderLevel error_render_level) {
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
n->state_ = ScheduleState(mod, debug_mode);
n->error_render_level_ = error_render_level;
n->symbol_table_ = {};
n->analyzer_ = std::make_unique<arith::Analyzer>();
return Schedule(std::move(n));
Expand Down Expand Up @@ -136,6 +138,7 @@ class ScheduleCopier {
scope->src2deps = Copy(old_info.scope->src2deps);
scope->dst2deps = Copy(old_info.scope->dst2deps);
scope->buffer_writers = Copy(old_info.scope->buffer_writers);
scope->stage_pipeline = old_info.scope->stage_pipeline;
new_info.scope = BlockScope(std::move(scope));
result[Copy(old_sref)] = std::move(new_info);
}
Expand Down Expand Up @@ -173,21 +176,81 @@ class ScheduleCopier {

void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symbol_table) const {
ScheduleCopier::Copy(this, new_state, new_symbol_table);
new_state->get()->DebugVerify();
}

Schedule ConcreteScheduleNode::Copy() const {
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
Copy(&n->state_, &n->symbol_table_);
n->error_render_level_ = this->error_render_level_;
this->Copy(&n->state_, &n->symbol_table_);
n->analyzer_ = std::make_unique<arith::Analyzer>();
return Schedule(std::move(n));
}

/*! \brief Macro that guards the beginning of each invocation of TensorIR schedule primitive */
#define TVM_TIR_SCHEDULE_BEGIN() try {
/*!
* \brief Macro that pairs with `TVM_TIR_SCHEDULE_BEGIN`, handling potential errors and error
* message rendering
* \param level An ScheduleErrorRenderLevel enum, level of error rendering
* \sa ScheduleErrorRenderLevel
*/
#define TVM_TIR_SCHEDULE_END(level) \
} \
catch (const ScheduleError& error) { \
if ((level) == ScheduleErrorRenderLevel::kDetail) { \
throw tvm::runtime::Error(error.RenderReport()); \
} else if ((level) == ScheduleErrorRenderLevel::kFast) { \
throw tvm::runtime::Error(error.FastErrorString()); \
} else if ((level) == ScheduleErrorRenderLevel::kNone) { \
throw tvm::runtime::Error("ScheduleError: (not rendered)"); \
} \
}

/******** Block/Loop relation ********/

BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) {
class NotSingleResult : public ScheduleError {
public:
explicit NotSingleResult(String name, IRModule mod, const Array<StmtSRef>& blocks)
: name_(name), mod_(mod), blocks_{} {
blocks_.reserve(blocks.size());
for (const StmtSRef& block_sref : blocks) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
blocks_.push_back(GetRef<Block>(block));
}
}

String primitive() const final { return "get-block"; }
IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {blocks_.begin(), blocks_.end()}; }

String DetailRenderTemplate() const final {
if (blocks_.empty()) {
return "Cannot find a block with the name: " + name_;
} else {
return "Found " + std::to_string(blocks_.size()) + " blocks with the name: " + name_;
}
}

String FastErrorString() const final {
if (blocks_.empty()) {
return "ScheduleError: Cannot find a block with the specified name";
} else {
return "ScheduleError: Found multiple blocks with the specified name";
}
}

String name_;
IRModule mod_;
Array<Block> blocks_;
};
Array<StmtSRef> blocks = tir::GetBlocks(this->state_, name, func_name);
CHECK_EQ(blocks.size(), 1) << "ValueError: There are " << blocks.size()
<< " blocks with the name: " << name;
if (blocks.size() != 1) {
TVM_TIR_SCHEDULE_BEGIN();
throw NotSingleResult(name, this->state_->mod, blocks);
TVM_TIR_SCHEDULE_END(this->error_render_level_);
}
return CreateRV<BlockRV>(blocks[0]);
}

Expand Down
3 changes: 3 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,16 @@ class ConcreteScheduleNode : public ScheduleNode {
protected:
/*! \brief The internal state of scheduling */
ScheduleState state_;
/*! \brief The level of error rendering */
ScheduleErrorRenderLevel error_render_level_;
/*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */
TSymbolTable symbol_table_;
/*! \brief A persistent stateless arithmetic analyzer. */
std::unique_ptr<arith::Analyzer> analyzer_;

public:
void VisitAttrs(tvm::AttrVisitor* v) {
// `error_render_level_` is not visited
// `state_` is not visited
// `symbol_table_` is not visited
// `analyzer_` is not visitied
Expand Down
55 changes: 55 additions & 0 deletions src/tir/schedule/error.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "./utils.h"

namespace tvm {
namespace tir {

String ScheduleError::RenderReport() const {
IRModule mod = this->mod();
std::ostringstream os;
os << "ScheduleError: An error occurred in the schedule primitive '" << this->primitive()
<< "'.\n\nThe IR is:\n"
<< AsTVMScript(mod);
Array<ObjectRef> locs = LocationsOfInterest();
int n_locs = locs.size();
std::vector<String> roi_names;
roi_names.reserve(n_locs);
if (n_locs > 0) {
os << "Regions of interest:\n";
for (const ObjectRef& obj : locs) {
String name = obj->GetTypeKey() + '#' + std::to_string(roi_names.size());
os << name << "\n" << obj;
roi_names.emplace_back(std::move(name));
}
os << "\n";
}
std::string msg = DetailRenderTemplate();
for (int i = 0; i < n_locs; ++i) {
std::string src = "{" + std::to_string(i) + "}";
for (size_t pos; (pos = msg.find(src)) != std::string::npos;) {
msg.replace(pos, src.length(), roi_names[i]);
}
}
os << "Error message: " << msg;
return os.str();
}

} // namespace tir
} // namespace tvm
60 changes: 60 additions & 0 deletions src/tir/schedule/error.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_TIR_SCHEDULE_ERROR_H_
#define TVM_TIR_SCHEDULE_ERROR_H_

#include <tvm/tir/schedule/state.h>

namespace tvm {
namespace tir {

/*! \brief Error that happens during TensorIR scheduling */
class ScheduleError : public tvm::runtime::Error {
public:
/*! \brief Base constructor */
ScheduleError() : tvm::runtime::Error("") {}
/*! \brief The error occurred in this scheduling primitive */
virtual String primitive() const = 0;
/*! \brief The error occurred in this IRModule */
virtual IRModule mod() const = 0;
/*! \brief The locations of interest that we want to point out */
virtual Array<ObjectRef> LocationsOfInterest() const = 0;
/*!
* \brief Returns an error string template for rendering, corresponds to the "detail" mode.
* \sa ScheduleErrorRenderLevel
* \note The template is a string, e.g.
* "Some error occurred on block {0} and loop {1} blah blah"
* And renderer will replace {0} and {1} according to the list provided LocationsOfInterest. Right
* now it only printed out all the locations in plain text, but in the future, we may want to mark
* the IR with underscores and attach names to each location of interest, like what synr does.
*/
virtual String DetailRenderTemplate() const = 0;
/*!
* \brief Returns an error string without needing to render, corresponds to the "fast" mode
* \sa ScheduleErrorRenderLevel
*/
virtual String FastErrorString() const = 0;
/*! \brief Render the ScheduleError with the template provided by `DetailRenderTemplate` */
String RenderReport() const;
};

} // namespace tir
} // namespace tvm

#endif // TVM_TIR_SCHEDULE_ERROR_H_
5 changes: 3 additions & 2 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") //
/**************** (FFI) Constructor ****************/

TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule")
.set_body_typed([](ObjectRef obj, int debug_mode) -> Schedule {
.set_body_typed([](ObjectRef obj, int debug_mode, int error_render_level) -> Schedule {
IRModule mod{nullptr};
if (const auto* func = obj.as<PrimFuncNode>()) {
mod = IRModule({{GlobalVar("main"), GetRef<BaseFunc>(func)}});
Expand All @@ -66,7 +66,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule")
LOG(FATAL) << "TypeError: Expects `IRModule` or `PrimFunc`, but gets: "
<< obj->GetTypeKey();
}
return Schedule::Concrete(mod, debug_mode);
return Schedule::Concrete(mod, debug_mode,
static_cast<ScheduleErrorRenderLevel>(error_render_level));
});

/******** (FFI) Lookup random variables ********/
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "../../printer/text_printer.h"
#include "../../runtime/thread_storage_scope.h"
#include "./analysis.h"
#include "./error.h"

namespace tvm {
namespace tir {
Expand Down
Loading

0 comments on commit 39adf51

Please sign in to comment.