Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ansor][AutoTVM v2.0] Phase 1: Add annotation/compute_at/compute_root/compute_inline steps #6073

Merged
merged 15 commits into from
Jul 21, 2020
228 changes: 203 additions & 25 deletions python/tvm/auto_scheduler/loop_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,24 @@ class State:
-----
This is a wrapper class of StateObject to deal with copy-on-write property
"""

# Static trans table for thread bind
# This is used to transform the annotation name to C++ enum
ANNOTATION_TRANS_TABLE = {
"none": 0,
"unroll": 1,
"vectorize": 2,
"parallel": 3,
"vthread": 4,
"blockIdx.x": 5,
"threadIdx.x": 6,
"blockIdx.y": 7,
"threadIdx.y": 8,
"blockIdx.z": 9,
"threadIdx.z": 10,
"tensorize": 11
}

def __init__(self, state_object, dag):
self.state_object = state_object
self.compute_dag = dag
Expand All @@ -108,20 +126,140 @@ def stage_ops(self):
"""
return [stage.op for stage in self.stages]

def bind(self, stage, iterator, thread_name):
""" Schedule primitive corresponds to te.bind.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be binded, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be binded.
thread_name : str
The thread type to be binded. Candidates:
- vthread
- blockIdx.x
- threadIdx.x
- blockIdx.y
- threadIdx.y
- blockIdx.z
- threadIdx.z

Returns
-------
res_it : Iterator
The binded Iterator.
"""
if not thread_name in State.ANNOTATION_TRANS_TABLE.keys():
raise ValueError("Invalid thread_name: ", thread_name)

self.state_object, res = _ffi_api.StateBind(self.state_object,
self._resolve_stage_id(stage), iterator,
State.ANNOTATION_TRANS_TABLE[thread_name])
return res

def parallel(self, stage, iterator):
""" Schedule primitive corresponds to te.parallel.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be paralleled, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be paralleled.

Returns
-------
res_it : Iterator
The paralleled Iterator.
"""
self.state_object, res = _ffi_api.StateParallel(self.state_object,
self._resolve_stage_id(stage), iterator)
return res

def unroll(self, stage, iterator, max_unroll=None):
""" Schedule primitive corresponds to te.unroll.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be unrolled, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be unrolled.
max_unroll : Optional[int]
The max unroll limit. Iterator with extent larger than this limit will be skipped.

Returns
-------
res_it : Iterator
The unrolled Iterator.
"""
self.state_object, res = _ffi_api.StateUnroll(self.state_object,
self._resolve_stage_id(stage), iterator,
max_unroll if max_unroll else -1)
return res

def vectorize(self, stage, iterator):
""" Schedule primitive corresponds to te.vectorize.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be vectorized, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be vectorized.

Returns
-------
res_it : Iterator
The vectorized Iterator.
"""
self.state_object, res = _ffi_api.StateVectorize(self.state_object,
self._resolve_stage_id(stage), iterator)
return res

def fuse(self, stage, iters):
""" Schedule primitive corresponds to te.fuse.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be fused, which can be specified by the integer index, Operation,
or output tensor of the stage.
iters : List[Iterator]
The iterators to be fused.

Returns
-------
res_it : Iterator
The fused Iterator.

Notes
-----
If the iterators to be fused have stages attached at them(by compute_at), the fused
result will become the new attach point.
"""
self.state_object, res = _ffi_api.StateFuse(self.state_object,
self._resolve_stage_id(stage), iters)
return res

def reorder(self, stage, order):
""" Schedule primitive corresponds to te.reorder.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be reordered, can be a Stage order index, Stage operation or stage
output tensor.
The Stage to be reordered, which can be specified by the integer index, Operation,
or output tensor of the stage.
order : List[Iterator]
Iterators in the expected order.
"""
stage_id = self._resolve_stage_id(stage)

self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order)
self.state_object = _ffi_api.StateReorder(self.state_object, self._resolve_stage_id(stage),
order)

def split(self, stage, iterator, lengths, inner_to_outer=True):
""" Schedule primitive corresponds to te.split.
Expand All @@ -132,8 +270,8 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be split, can be a Stage order index, Stage operation or stage
output tensor.
The Stage to be split, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to be split.
lengths: List[int]
Expand All @@ -144,34 +282,74 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
Returns
-------
res_its : List[Iterator]
The splitted new Iterators
"""
stage_id = self._resolve_stage_id(stage)
The splitted new Iterators.

self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator, lengths,
inner_to_outer)
Notes
-----
If we do split on an iterator which has stages attached at it(by compute_at), the inner
most iterator of split results will become the new attach point.
"""
self.state_object, res = _ffi_api.StateSplit(self.state_object,
self._resolve_stage_id(stage),
iterator, lengths, inner_to_outer)
return res

def fuse(self, stage, iters):
""" Schedule primitive corresponds to te.fuse.
def compute_at(self, stage, target_stage, target_iter):
""" Schedule primitive corresponds to te.compute_at.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be fused, can be a Stage order index, Stage operation or stage
output tensor.
iters : List[Iterator]
The iterators to be fused
The Stage to be compute at, which can be specified by the integer index, Operation,
or output tensor of the stage.
target_stage : Union[int, Operation, Tensor]
The target stage of compute_at, which can be specified by the integer index, Operation,
or output tensor of the stage.
target_iter : Iterator
The target Iterator of compute_at.

Notes
-----
After compute_at, we need careful dependency analysis to compute the accurate bound
information. However, it is relatively expensive and complicated, so we just fill "None"
as bound for the newly created iterators.
Call ComputeDAG::InferBound on the returned state to get the complete bound information.
"""
self.state_object = _ffi_api.StateComputeAt(self.state_object,
self._resolve_stage_id(stage),
self._resolve_stage_id(target_stage),
target_iter)

Returns
-------
res_it : Iterator
The fused Iterator
def compute_inline(self, stage):
""" Schedule primitive corresponds to te.compute_inline.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be compute inlined, which can be specified by the integer index, Operation,
or output tensor of the stage.
"""
stage_id = self._resolve_stage_id(stage)
self.state_object = _ffi_api.StateComputeInline(self.state_object,
self._resolve_stage_id(stage))

self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters)
return res
def compute_root(self, stage):
""" Schedule primitive corresponds to te.compute_root.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be compute root, which can be specified by the integer index, Operation,
or output tensor of the stage.

Notes
-----
After compute_root, we need careful dependency analysis to compute the accurate bound
information. However, it is relatively expensive and complicated, so we just fill "None"
as bound for the newly created iterators.
Call ComputeDAG::InferBound on the returned state to get the complete bound information.
"""
self.state_object = _ffi_api.StateComputeRoot(self.state_object,
self._resolve_stage_id(stage))

def copy(self):
""" Do deep copy of this State. """
Expand Down
26 changes: 4 additions & 22 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,19 +270,9 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
}

// Apply the history steps to TVM schedule
// Call each step's ApplyToSchedule method
for (const auto& step : transform_steps) {
// Call each step's ApplyToSchedule method
// Note: some steps have extra parameters that must be passed and they may need different
// return value, so the ApplyToSchedule is not able to be merged to single interface
if (auto ps = step.as<ReorderStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<SplitStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<FuseStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else {
LOG(FATAL) << "Invalid Step";
}
StepApplyToSchedule(step, stages, stage_to_axes);
merrymercy marked this conversation as resolved.
Show resolved Hide resolved
}

return std::make_pair(schedule, operator->()->tensors);
Expand Down Expand Up @@ -326,15 +316,7 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
}
// Call each step's PrintAsPythonAPI method
for (const auto& step : transform_steps) {
if (auto ps = step.as<ReorderStepNode>()) {
ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
} else if (auto ps = step.as<SplitStepNode>()) {
ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
} else if (auto ps = step.as<FuseStepNode>()) {
ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
} else {
LOG(FATAL) << "Invalid Step";
}
ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes);
merrymercy marked this conversation as resolved.
Show resolved Hide resolved
}

return ss.str();
Expand All @@ -352,7 +334,7 @@ State ComputeDAG::InferBound(const State& state) const {
ret_state = operator->()->init_state;
pstate = ret_state.CopyOnWrite();
pstate->transform_steps = state->transform_steps;
ret_state.DoSteps(*this);
ret_state.ApplySteps(*this);
} else {
ret_state = state;
pstate = ret_state.CopyOnWrite();
Expand Down
Loading