[Ansor][AutoTVM v2.0] Phase 1: Add cache_read/cache_write steps #6107

Merged: 12 commits merged on Jul 27, 2020
Changes from 2 commits
9 changes: 7 additions & 2 deletions python/tvm/auto_scheduler/compute_dag.py
@@ -126,11 +126,16 @@ def infer_bound_from_state(self, state):

Returns
-------
state : State
updated_state : State
The State with complete bound information.
"""
state_obj = state if isinstance(state, StateObject) else state.state_object
return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self)
updated_state = State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self)
# Copy the stage_id_map from the original state
if isinstance(state, State):
for k, v in state.stage_id_map.items():
updated_state.stage_id_map[k] = v
return updated_state
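As a hedged illustration of what the copied stage_id_map buys (using the hypothetical dag, state, and C names set up in the cache_read usage sketch further below), indexing the returned State by an op or tensor keeps working after bound inference:

bound_state = dag.infer_bound_from_state(state)
stage = bound_state[C]  # still resolves, because stage_id_map was carried over from `state`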

def __hash__(self):
# TODO(merrymercy): Implement this more carefully and move this to c++ as a member function
71 changes: 71 additions & 0 deletions python/tvm/auto_scheduler/loop_state.py
@@ -351,6 +351,72 @@ def compute_root(self, stage):
self.state_object = _ffi_api.StateComputeRoot(self.state_object,
self._resolve_stage_id(stage))

def cache_read(self, stage, scope_name, reader_stages):
""" Schedule primitive corresponds to te.schedule.cache_read.

Review thread:
Member: Can you explain what this step does?
merrymercy (Member, Jul 24, 2020): This step does the same thing as te.schedule.cache_read does. We choose to add a pointer to te.schedule.cache_read instead of copying the docstring from it.
Member: Maybe we can make the pointer more clear (e.g., say "see also te.schedule.cache_read").
jcf94 (Author, Jul 24, 2020): Added a pointer to te.schedule.cache_read, we may also add this to other steps later.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be cache-read, which can be specified by the integer index, Operation,
or output tensor of the stage.
scope_name : str
The scope name to be set for the newly added read stage.
reader_stages : List[Union[int, Operation, Tensor]]
The reader stages. Each item of the list can be specified by the integer index, Operation,
or output tensor of the stage.

Returns
-------
new_stage_op : Operation
The Operation of the newly added stage.

Notes
-----
The cache_read step will insert an extra stage into the original ComputeDAG (right after the
target stage).
"""
reader_stage_ids = [self._resolve_stage_id(i) for i in reader_stages]
self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object,
self._resolve_stage_id(stage),
scope_name, reader_stage_ids,
self.compute_dag)
# Adding a new stage changes the indexes of all ops behind it. To keep the original op keys
# usable, apply a stage id offset to stage_id_map so the entries still point to the right stages.
self._apply_stage_id_offset(int(new_stage_id))
self._update_stage_id_map()
return self.stages[int(new_stage_id)].op
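For readers new to this step, a minimal usage sketch against the Python API added here (the matmul workload and tensor names are illustrative, and get_init_state() is assumed to be the existing ComputeDAG helper; treat this as a sketch, not a test from the PR):

import tvm
from tvm import te, auto_scheduler

# Build a small matmul ComputeDAG.
A = te.placeholder((512, 512), name="A")
B = te.placeholder((512, 512), name="B")
k = te.reduce_axis((0, 512), name="k")
C = te.compute((512, 512), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

dag = auto_scheduler.ComputeDAG([A, B, C])
state = dag.get_init_state()

# Insert a cached copy of A in "shared" scope, read by stage C. The new read stage is
# placed right after A and its Operation is returned.
A_shared_op = state.cache_read(A, "shared", [C])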

def cache_write(self, stage, scope_name):
""" Schedule primitive corresponds to te.schedule.cache_write.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be cache-written, which can be specified by the integer index, Operation,
or output tensor of the stage.
scope_name : str
The scope name to be set for the newly added write stage.

Returns
-------
new_stage_op : Operation
The Operation of the newly added stage.

Notes
-----
The cache_write step will insert an extra stage into the original ComputeDAG (right before the
target stage).
This step will cache the writes of all output tensors of the target stage.
"""
self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object,
self._resolve_stage_id(stage),
scope_name, self.compute_dag)
# Adding a new stage changes the indexes of all ops behind it. To keep the original op keys
# usable, apply a stage id offset to stage_id_map so the entries still point to the right stages.
self._apply_stage_id_offset(int(new_stage_id))
self._update_stage_id_map()
return self.stages[int(new_stage_id)].op
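Continuing the sketch above under the same assumptions, cache_write adds the new stage in front of the target stage and caches all of its output tensors:

# Insert a write cache for C in "local" scope; the new stage lands right before C.
C_local_op = state.cache_write(C, "local")
# The original C stage now only writes the cached result out; tiling and other scheduling
# steps would typically be applied to the new cache stage instead.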

def copy(self):
""" Do deep copy of this State. """
state = State(self.state_object, self.compute_dag)
@@ -371,6 +437,11 @@ def _update_stage_id_map(self):
for index, stage in enumerate(self.stages):
self.stage_id_map[stage.op] = index

def _apply_stage_id_offset(self, start_id, offset=1):
for key, value in self.stage_id_map.items():
if value >= start_id:
self.stage_id_map[key] = value + offset
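A toy illustration of this helper with a plain dict and hypothetical op names: inserting a new stage at index 2 must shift every op previously mapped to index 2 or higher.

stage_id_map = {"A": 0, "B": 1, "C": 2, "D": 3}
start_id, offset = 2, 1
for key, value in stage_id_map.items():
    if value >= start_id:
        stage_id_map[key] = value + offset
# stage_id_map == {"A": 0, "B": 1, "C": 3, "D": 4}; index 2 is now free for the new stage.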

def __getitem__(self, key):
if isinstance(key, Tensor):
key = key.op
39 changes: 19 additions & 20 deletions src/auto_scheduler/compute_dag.cc
@@ -221,24 +221,6 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}

// Update the te::stage to tir::IterVar axis mapping
void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) {
if (auto pop = stage->op.as<te::ComputeOpNode>()) {
Array<IterVar> axes;
for (const auto& axis : pop->axis) {
axes.push_back(axis);
}
for (const auto& axis : pop->reduce_axis) {
axes.push_back(axis);
}
stage_to_axes->Set(stage, std::move(axes));
} else if (stage->op->IsInstance<te::PlaceholderOpNode>()) {
{} // do nothing on Placeholder
} else {
LOG(FATAL) << "Invalid op " << stage->op;
}
}

std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
const Array<Step>& transform_steps, Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes) const {
@@ -272,7 +254,7 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
// Apply the history steps to TVM schedule
// Call each step's ApplyToSchedule method
for (const auto& step : transform_steps) {
StepApplyToSchedule(step, stages, stage_to_axes);
StepApplyToSchedule(step, stages, stage_to_axes, &schedule);
}

return std::make_pair(schedule, operator->()->tensors);
@@ -316,7 +298,7 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
}
// Call each step's PrintAsPythonAPI method
for (const auto& step : transform_steps) {
ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes);
ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule);
}

return ss.str();
@@ -382,6 +364,23 @@ State ComputeDAG::InferBound(const State& state) const {
return ret_state;
}

ComputeDAG ComputeDAG::ReplayAndGetDAG(const Array<Step>& transform_steps) const {
te::Schedule sch;
Array<te::Tensor> old_tensors;
std::tie(sch, old_tensors) = ApplySteps(transform_steps);

Array<te::Tensor> new_tensors;
for (auto stage : sch->stages) {
if (stage->op->IsInstance<te::PlaceholderOpNode>() || stage->is_output) {
for (auto i = 0; i < stage->op->num_outputs(); ++i) {
new_tensors.push_back(stage->op.output(i));
}
}
}

return ComputeDAG(new_tensors);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ComputeDAGNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ComputeDAGNode*>(ref.get());
10 changes: 10 additions & 0 deletions src/auto_scheduler/compute_dag.h
@@ -114,6 +114,16 @@ class ComputeDAG : public ObjectRef {
*/
State InferBound(const State& state) const;

/*!
* \brief Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial
* ComputeDAG may not be up-to-date. This function replays the given transform steps from the
* initial state and returns an up-to-date ComputeDAG.
* \param steps The steps to be replayed. Usually we filter out the unused steps to speed up
* the replay process, since we only need to get the new ComputeDAG structure.
* \return The up-to-date ComputeDAG.
*/
ComputeDAG ReplayAndGetDAG(const Array<Step>& steps) const;

TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
};
58 changes: 58 additions & 0 deletions src/auto_scheduler/loop_state.cc
@@ -30,6 +30,7 @@

#include <utility>

#include "compute_dag.h"
#include "transform_step.h"
#include "utils.h"

Expand Down Expand Up @@ -151,6 +152,36 @@ void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) {
}
}

AttachMap AttachMap::ApplyStageIdOffset(int start_id, int offset) const {
AttachMap map = AttachMap(make_object<AttachMapNode>());
auto pmap = map.CopyOnWrite();
for (const auto& x : operator->()->stage_to_attach_iter) {
auto key = x.first;
if (key >= start_id) {
key += offset;
}
auto value = x.second;
if (value.first >= start_id) {
value.first += offset;
}
pmap->stage_to_attach_iter.insert(std::make_pair(key, value));
}
for (const auto& x : operator->()->iter_to_attached_stages) {
auto key = x.first;
if (key.first >= start_id) {
key.first += offset;
}
auto value = x.second;
for (auto& i : value) {
if (i >= start_id) {
i += offset;
}
}
pmap->iter_to_attached_stages.insert(std::make_pair(key, value));
}
return map;
}

/********** State **********/
State::State(const Array<te::Operation>& ops) {
auto node = make_object<StateNode>();
@@ -258,6 +289,19 @@ void State::compute_root(int stage_id) {
step->ApplyToState(this);
}

int State::cache_read(int stage_id, const String& scope_name,
const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
CacheReadStep step = CacheReadStep(stage_id, scope_name, reader_stage_ids);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this, dag);
}

int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag) {
CacheWriteStep step = CacheWriteStep(stage_id, scope_name);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this, dag);
}

void State::ApplySteps(const ComputeDAG& dag) {
CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages.";

@@ -430,6 +474,20 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot")
return state;
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheRead")
.set_body_typed([](State state, int stage_id, const String& scope_name,
const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
int res = state.cache_read(stage_id, scope_name, reader_stage_ids, dag);
return Array<ObjectRef>{state, Integer(res)};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite")
.set_body_typed([](State state, int stage_id, const String& scope_name,
const ComputeDAG& task_dag) {
int res = state.cache_write(stage_id, scope_name, task_dag);
return Array<ObjectRef>{state, Integer(res)};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) {
return std::equal_to<State>()(state1, state2);
});
59 changes: 47 additions & 12 deletions src/auto_scheduler/loop_state.h
@@ -181,11 +181,13 @@ class AttachMap : public ObjectRef {
* \param target_iter_id The index of iterator in target stage that this step will compute at to.
*/
void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id);

/*!
* \brief A public wrapper of `DeleteStageEntry` that deletes the entry of a specific stage.
* \param stage_id The index of the stage whose entry will be deleted.
*/
void DeleteStage(int stage_id);

/*!
* \brief Find the relations of original iterators in AttachMap, and update them with the new
* iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated.
@@ -195,6 +197,17 @@
void UpdateIters(const std::vector<IterKey>& original_iters,
const std::vector<IterKey>& new_iters);

/*!
* \brief Traverse the `stage_to_attach_iter` and `iter_to_attached_stages` maps and add an
* offset to the stage indexes that are greater than or equal to start_id. Used for steps that
* insert new stages into the ComputeDAG (e.g. the CacheRead/CacheWrite steps).
* \param start_id The index threshold. Stage indexes in the AttachMap that are greater than or
* equal to this threshold will have the extra offset added.
* \param offset The index offset to be added to the stage indexes.
* \return The updated AttachMap after applying the stage index offset.
*/
AttachMap ApplyStageIdOffset(int start_id, int offset = 1) const;

Review thread:
Contributor (on the \brief above, marked as resolved): You mean "and update offset"?

TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode);

@@ -225,6 +238,13 @@ class StateNode : public Object {
* operation.
*/
AttachMap attach_map;
/*!
* \brief The up-to-date ComputeDAG of this state, used by steps that may change the stage
* structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep, which will be added later).
* The default value is an empty ObjectRef, meaning no modification to the original DAG.
*/
ObjectRef current_compute_dag;

Review thread:
Member: Can you explain this better? Given the above methods it seems that current_compute_dag might in fact not be up-to-date, given that some scheduling steps modify the compute dag.
merrymercy (Member, Jul 24, 2020): This dag is always up-to-date. The comment in the above method says the "initial dag" may not be up-to-date. So we need to store a new up-to-date dag here.

/*!
* \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all
* tile sizes of the state is filled. Only concrete state can be apply to TVM schedule.
@@ -239,15 +259,6 @@

static constexpr const char* _type_key = "auto_scheduler.State";
TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object);

private:
/*!
* \brief The up-to-date ComputeDAG of this state, used for some steps that may change the
* stage structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep which Will be added
* later).
* The default value is an empty ObjectRef. (means no modification to the original DAG)
*/
ObjectRef current_compute_dag;
};

/*!
@@ -347,7 +358,7 @@ class State : public ObjectRef {

/*!
* \brief Schedule primitive corresponds to te.compute_at.
* \param stage_id The index of the stage to be reordered.
* \param stage_id The index of the stage to be computed at.
* \param target_stage_id The index of stage that this step will compute at to.
* \param target_iter The iterator in target stage that this step will compute at to.
* \note After compute_at, we need careful dependency analysis to compute the accurate bound
@@ -358,19 +369,43 @@
void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
/*!
* \brief Schedule primitive corresponds to te.compute_inline.
* \param stage_id The index of the stage to be reordered.
* \param stage_id The index of the stage to be computed inline.
*/
void compute_inline(int stage_id);
/*!
* \brief Schedule primitive corresponds to te.compute_root.
* \param stage_id The index of the stage to be reordered.
* \param stage_id The index of the stage to be computed at root.
* \note After compute_root, we need careful dependency analysis to compute the accurate bound
* information. However, it is relatively expensive and complicated, so we just fill "None" as
* bound for the newly created iterators.
* Call ComputeDAG::InferBound on the updated state to get the complete bound information.
*/
void compute_root(int stage_id);

/********** Step APIs adding new stages **********/

/*!
* \brief Schedule primitive corresponds to te.schedule.cache_read.
* \param stage_id The index of the stage to be cache-read.
* \param scope_name The scope name to be set for the newly added read stage.
* \param reader_stage_ids The indexes of reader stages.
* \param dag The original ComputeDAG of this state.
* \note The cache_read step will add an extra stage to the original ComputeDAG; an up-to-date
* ComputeDAG is stored in the State's `current_compute_dag`.
*/
int cache_read(int stage_id, const String& scope_name, const Array<Integer>& reader_stage_ids,
const ComputeDAG& dag);
/*!
* \brief Schedule primitive corresponds to te.schedule.cache_write.
* \param stage_id The index of the stage to be cache-written.
* \param scope_name The scope name to be set for the newly added write stage.
* \param dag The original ComputeDAG of this state.
* \note The cache_write step will add an extra stage to the original ComputeDAG; an up-to-date
* ComputeDAG is stored in the State's `current_compute_dag`.
* This step will cache the writes of all output tensors of the target stage.
*/
int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);

TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
};
Expand Down
Loading