[TE] Record primitives of Schedule for visualization #14168

Merged 3 commits on Mar 14, 2023
43 changes: 42 additions & 1 deletion include/tvm/te/schedule.h
@@ -62,8 +62,9 @@ class Stage : public ObjectRef {
/*!
* \brief create a new schedule for op.
* \param op The operator in the schedule
* \param sch The schedule to which the current stage belongs
*/
explicit Stage(Operation op);
explicit Stage(Operation op, const ScheduleNode* sch);
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
@@ -445,6 +446,26 @@ class Schedule : public ObjectRef {
using ContainerType = ScheduleNode;
};

/*!
* \brief Context helper to collect debug information of a Schedule.
*
* Attach With<ScheduleContext>(schedule_instance, primitive_name)
* inside the function body of a schedule primitive to collect a
* snapshot of the schedule status and the corresponding primitive name.
*/
class ScheduleContext {
private:
friend class With<ScheduleContext>;
ScheduleContext(const ScheduleNode* sch_node, String current_primitive_name);
void EnterWithScope();
void ExitWithScope();

/*! \brief Schedule instance that stores the information for debugging */
Schedule sch_;
/*! \brief Name of the primitive that has been applied to sch_ */
String current_primitive_name_;
};

/*!
* \brief The schedule relation between IterVars
* can be Split, Fuse.
@@ -546,6 +567,8 @@ class StageNode : public Object {
IterVar attach_ivar;
/*! \brief The stage this node attaches to */
Stage attach_stage;
/*! \brief The schedule the current stage is attached to */
const ScheduleNode* attach_sch;
/*! \brief The thread storage scope level of the stage */
std::string scope;
/*! \brief Whether this is an output stage */
@@ -615,12 +638,30 @@ class ScheduleNode : public Object {
* This is created on demand and can be invalidated.
*/
std::unordered_map<const Object*, Stage> op2stage_cache_;
/*!
* \brief List of all transformed schedules.
* Users can display the optimization strategy step by step via TEDD to check
* the order and effect of primitives. Set "te.keep_schedule_record" to true in
* the PassContext config to enable recording.
*/
Array<Schedule> schedule_record;
/*!
* \brief List of all applied primitive names.
*/
Array<String> primitive_record;
/*!
* \brief Flag indicating whether to keep the schedule record.
*/
Optional<Bool> keep_schedule_record;

void VisitAttrs(AttrVisitor* v) {
v->Visit("outputs", &outputs);
v->Visit("stages", &stages);
v->Visit("groups", &groups);
v->Visit("stage_map", &stage_map);
v->Visit("schedule_record", &schedule_record);
v->Visit("primitive_record", &primitive_record);
v->Visit("keep_schedule_record", &keep_schedule_record);
}

/*! \brief Initialize temp cache. */
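A minimal Python sketch of how the recording added to ScheduleNode above could be exercised, assuming the "te.keep_schedule_record" PassContext option named in the comment and that schedule_record / primitive_record are reachable from Python through VisitAttrs; the compute definition and primitive choices are illustrative only:

```python
import tvm
from tvm import te

# Illustrative compute definition (not part of this PR).
A = te.placeholder((1024,), name="A")
B = te.compute((1024,), lambda i: A[i] + 1.0, name="B")

# Assumption: recording is honored when the schedule is created and primitives
# are applied inside a PassContext that sets "te.keep_schedule_record" to True.
with tvm.transform.PassContext(config={"te.keep_schedule_record": True}):
    s = te.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=32)
    s[B].reorder(xo, xi)

# Each applied primitive should have appended a schedule snapshot and its name.
for snapshot, name in zip(s.schedule_record, s.primitive_record):
    print(name, "->", len(snapshot.stages), "stages")
```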
27 changes: 26 additions & 1 deletion python/tvm/contrib/tedd.py
@@ -78,6 +78,27 @@ def insert_dot_id(sch):
return sch


def itervar_equal(iv_a, iv_b):
"""A helper that checks whether two iteration variables are equal."""
# Compare the itervars field by field: a plain comparison (i.e. iv_a == iv_b)
# might fail after InferBound changes the domain bounds.
def _var_equal(v_a, v_b):
conditions = [
v_a.name == v_b.name,
v_a.dtype == v_b.dtype,
v_a.type_annotation == v_b.type_annotation,
]
return all(conditions)

conditions = [
_var_equal(iv_a.var, iv_b.var),
iv_a.iter_type == iv_b.iter_type,
iv_a.thread_tag == iv_b.thread_tag,
]
return all(conditions)


class ObjectManager:
"""A helper class tracking schedule objects, e.g. stage, IterVar,
relationship, and tensor, to their DOM path."""
@@ -88,6 +109,10 @@ def __init__(self, sch):
self.dict[stage] = [stage_idx]
for itervar_idx, itervar in enumerate(stage.all_iter_vars):
self.dict[itervar] = [stage_idx, itervar_idx]
# leaf itervars should also be mapped to the original itervar
for leaf_iv in stage.leaf_iter_vars:
if itervar_equal(leaf_iv, itervar):
self.dict[leaf_iv] = [stage_idx, itervar_idx]
for rel_idx, rel in enumerate(stage.relations):
self.dict[rel] = [stage_idx, rel_idx]
for tensor_idx in range(stage.op.num_outputs):
@@ -289,7 +314,7 @@ def encode_itervars(stage, range_map):

def get_leaf_itervar_index(itervar, leaf_iv):
for leaf_index, ivar in enumerate(leaf_iv):
if ivar == itervar:
if itervar_equal(ivar, itervar):
return leaf_index
return -1

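A hedged sketch of stepping through the recorded snapshots with TEDD, assuming a schedule `s` produced with recording enabled as in the previous sketch; tedd.viz_schedule_tree is an existing TEDD entry point, while the helper name and output layout here are illustrative only:

```python
import os

from tvm.contrib import tedd


def dump_schedule_steps(sch, out_dir="tedd_steps"):
    """Write one DOT file per recorded primitive so the schedule can be
    inspected step by step."""
    os.makedirs(out_dir, exist_ok=True)
    for idx, (snapshot, name) in enumerate(zip(sch.schedule_record, sch.primitive_record)):
        dot_path = os.path.join(out_dir, "%02d_%s.dot" % (idx, name))
        # TEDD's schedule-tree view typically expects a normalized schedule.
        tedd.viz_schedule_tree(snapshot.normalize(), dot_file_path=dot_path)
```

Calling dump_schedule_steps(s) after the first sketch would emit one DOT file per recorded step, with file names depending on how the PR records primitive names.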
13 changes: 11 additions & 2 deletions src/relay/backend/te_compiler.cc
@@ -700,7 +700,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
*/
Expr MakeLoweredCall(const BaseFunc& original_function, const GlobalVar& prim_fn_var,
Array<Expr> args, Span span, const Target& target,
const Map<GlobalVar, BaseFunc>& lowered_functions) {
const Map<GlobalVar, BaseFunc>& lowered_functions,
const te::Schedule& sch = {}) {
auto opt_compiler = original_function->GetAttr<String>(attr::kCompiler);

// Add some metadata on top of the *original function* and invoke the callback so it can
@@ -730,6 +731,10 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", prim_fn_var);
func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, target);
// Store the generated Schedule of the operator
if (sch.defined() && sch->keep_schedule_record) {
func_with_metadata = WithAttr(func_with_metadata, "schedule", sch);
}
this->process_fn_(func_with_metadata);
} else {
const auto* function_node = original_function.as<FunctionNode>();
@@ -738,6 +743,10 @@
func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", prim_fn_var);
func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, target);
// Store the generated Schedule of the operator
if (sch.defined() && sch->keep_schedule_record) {
func_with_metadata = WithAttr(func_with_metadata, "schedule", sch);
}
this->process_fn_(func_with_metadata);
}

@@ -926,7 +935,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
CachedFunc cfunc = compiler_->Lower(key);
ICHECK(cfunc.defined());
return MakeLoweredCall(primitive_func, cfunc->prim_fn_var, std::move(new_args),
call_node->span, target, cfunc->funcs->functions);
call_node->span, target, cfunc->funcs->functions, cfunc->schedule);
}
}

14 changes: 9 additions & 5 deletions src/te/schedule/schedule_dataflow_rewrite.cc
@@ -174,10 +174,12 @@ Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope,
Array<Stage>& stages = (*this)->stages;
Stage op_stage = operator[](tensor->op);
size_t pos = FindNodeRef(stages.GetArrayNode(), op_stage);
Stage cache_stage = Stage(cache->op);
cache_stage.set_scope(scope);
Stage cache_stage = Stage(cache->op, this->operator->());
ICHECK_LT(pos, stages.size());
stages.insert(stages.begin() + pos + 1, cache_stage);
// In order to obtain a correct copy in schedule_record, make sure the
// "set_scope" primitive is applied only after the stage has been added.
cache_stage.set_scope(scope);
(*this)->stage_map.Set(cache->op, cache_stage);
// Update group
cache_stage->group = op_stage->group;
@@ -266,10 +268,12 @@ Array<Tensor> ReplaceOriginalOp(Schedule sch, Stage orig_stage, const std::strin
// create schedule for new cached stage.
Array<Stage>& stages = sch->stages;
size_t pos = FindNodeRef(stages.GetArrayNode(), orig_stage);
Stage cache_stage = Stage(cache_op);
cache_stage.set_scope(scope);
Stage cache_stage = Stage(cache_op, sch.operator->());
ICHECK_LT(pos, stages.size());
stages.insert(stages.begin() + pos, cache_stage);
// In order to obtain a correct copy in schedule_record, make sure the
// "set_scope" primitive is applied only after the stage has been added.
cache_stage.set_scope(scope);
sch->stage_map.Set(cache_op, cache_stage);
// Update group
cache_stage->group = orig_stage->group;
@@ -892,7 +896,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int f
Operation factor_op(n);
Array<Stage>& stages = (*this)->stages;
size_t stage_pos = FindNodeRef(stages.GetArrayNode(), reduce_stage);
Stage factor_stage = Stage(factor_op);
Stage factor_stage = Stage(factor_op, this->operator->());
factor_stage->relations = rels;
ICHECK_LT(stage_pos, stages.size());
stages.insert(stages.begin() + stage_pos, factor_stage);
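The reordering in cache_read and ReplaceOriginalOp above matters because each primitive snapshots the schedule when it runs. A hedged Python illustration, reusing `s`, `A`, `B`, and the imports from the first sketch; the expected record contents are an assumption based on the comments in this diff:

```python
# Reusing the schedule s from the first sketch, with recording enabled.
with tvm.transform.PassContext(config={"te.keep_schedule_record": True}):
    AA = s.cache_read(A, "shared", [B])

# Assumption based on the comments above: cache_read inserts the new stage
# first and applies set_scope afterwards, so both steps land in the record
# with a consistent snapshot.
print([str(name) for name in s.primitive_record])
```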