
[ANSOR] Auto-scheduler tutorial for GPU and necessary refactor/fix #6512

Merged (13 commits) on Sep 19, 2020
15 changes: 15 additions & 0 deletions docs/api/python/auto_scheduler.rst
@@ -31,5 +31,20 @@ tvm.auto_scheduler.auto_schedule

.. autofunction:: tvm.auto_scheduler.auto_schedule.auto_schedule

tvm.auto_scheduler.workload_registry
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: tvm.auto_scheduler.workload_registry.register_workload


tvm.auto_scheduler.measure
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.auto_scheduler.measure

.. autoclass:: tvm.auto_scheduler.measure.LocalRPCMeasureContext

.. autoclass:: tvm.auto_scheduler.measure.LocalRunner

.. autoclass:: tvm.auto_scheduler.measure.LocalBuilder

.. autoclass:: tvm.auto_scheduler.measure.RPCRunner
3 changes: 2 additions & 1 deletion include/tvm/auto_scheduler/search_policy.h
@@ -65,6 +65,7 @@
#include <tvm/auto_scheduler/search_task.h>
#include <tvm/node/node.h>

+ #include <string>
#include <unordered_set>
#include <vector>

@@ -191,7 +192,7 @@ class SearchPolicyNode : public Object {
* We store the string format of a state for redundancy check. This is used to make sure a
* measured state will never be measured again.
*/
- std::unordered_set<String> measured_states_set_;
+ std::unordered_set<std::string> measured_states_set_;
/*! \brief The array of already measured states.
* The good states can be used as the initial population in evolutionary search. */
std::vector<State> measured_states_vector_;
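The `measured_states_set_` member above stores the string form of each measured state so that no state is ever measured twice, while `measured_states_vector_` keeps the states themselves for reuse as an initial population. A minimal Python sketch of that dedup-plus-archive idea (hypothetical class, standing in for the C++ members):

```python
class MeasureDeduper:
    """Skip states whose string form has already been measured."""

    def __init__(self):
        self._seen = set()    # plays the role of measured_states_set_
        self._states = []     # plays the role of measured_states_vector_

    def try_add(self, state_repr):
        """Return True and record the state if it is new, else False."""
        if state_repr in self._seen:
            return False      # redundant: already measured
        self._seen.add(state_repr)
        self._states.append(state_repr)
        return True
```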
14 changes: 7 additions & 7 deletions python/tvm/auto_scheduler/measure.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

"""
r"""
Distributed measurement infrastructure to measure the runtime costs of tensor programs.

These functions are responsible for building the tvm module, uploading it to
@@ -25,8 +25,8 @@
A builder builds the executable binary files and a runner runs the binary files to
get the measurement results. The flow of data structures is

-                `ProgramBuilder`                 `ProgramRunner`
- `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult`
+                `ProgramBuilder`                 `ProgramRunner`
+ `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult`

We implement these in Python to utilize Python's multiprocessing and error handling.
"""
@@ -222,7 +222,7 @@ class LocalRunner(ProgramRunner):
where the first "1" is warm up and will be discarded.
The returned result contains `repeat` costs,
each of which is an average of `number` costs.
- min_repeat_ms : int = 0
+ min_repeat_ms : int = 100
The minimum duration of one `repeat` in milliseconds.
By default, one `repeat` contains `number` runs. If this parameter is set,
the parameter `number` will be dynamically adjusted to meet the
@@ -244,7 +244,7 @@ def __init__(
timeout=10,
number=3,
repeat=1,
- min_repeat_ms=0,
+ min_repeat_ms=100,
cooldown_interval=0.0,
enable_cpu_cache_flush=False,
):
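The docstring above says `number` is dynamically adjusted so that one `repeat` lasts at least `min_repeat_ms`. A rough sketch of that adjustment rule (illustrative only, not the runner's actual implementation):

```python
import math


def adjust_number(avg_cost_ms, number, min_repeat_ms=100):
    """Scale `number` up until number * avg_cost_ms >= min_repeat_ms."""
    if avg_cost_ms * number < min_repeat_ms:
        # run the kernel more times so one repeat is long enough to time reliably
        number = int(math.ceil(min_repeat_ms / avg_cost_ms))
    return number
```

For a kernel averaging 2 ms, `adjust_number(2, 3)` scales `number` from 3 up to 50 so one repeat covers the 100 ms floor; a kernel already slow enough keeps its original `number`.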
@@ -289,7 +289,7 @@ class RPCRunner(ProgramRunner):
where the first "1" is warm up and will be discarded.
The returned result contains `repeat` costs,
each of which is an average of `number` costs.
- min_repeat_ms : int = 0
+ min_repeat_ms : int = 100
The minimum duration of one `repeat` in milliseconds.
By default, one `repeat` contains `number` runs. If this parameter is set,
the parameter `number` will be dynamically adjusted to meet the
@@ -316,7 +316,7 @@ def __init__(
timeout=10,
number=3,
repeat=1,
- min_repeat_ms=0,
+ min_repeat_ms=100,
cooldown_interval=0.0,
enable_cpu_cache_flush=False,
):
6 changes: 3 additions & 3 deletions python/tvm/auto_scheduler/search_policy.py
@@ -91,7 +91,7 @@ class SketchPolicy(SearchPolicy):
----------
task : SearchTask
The SearchTask for the computation declaration.
- schedule_cost_model : CostModel = RandomModel()
+ program_cost_model : CostModel = RandomModel()
The cost model to estimate the complete schedules.
params : Optional[Dict[str, Any]]
Parameters of the search policy.
@@ -129,7 +129,7 @@ class SketchPolicy(SearchPolicy):
def __init__(
self,
task,
- schedule_cost_model=RandomModel(),
+ program_cost_model=RandomModel(),
params=None,
seed=None,
verbose=1,
@@ -145,7 +145,7 @@ def __init__(
self.__init_handle_by_constructor__(
_ffi_api.SketchPolicy,
task,
- schedule_cost_model,
+ program_cost_model,
params,
seed or random.randint(1, 1 << 30),
verbose,
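The renamed `program_cost_model` argument takes a cost model such as `RandomModel`, which scores candidate programs during search. A minimal stand-in showing the role a cost model plays (hypothetical interface, not tvm's actual `CostModel` API):

```python
import random


class RandomCostModel:
    """Assigns random scores; a trained model would rank candidates instead."""

    def predict(self, states):
        return [random.random() for _ in states]


def pick_best(states, model):
    """Return the candidate the cost model scores highest."""
    scores = model.predict(states)
    # higher score = more promising candidate schedule
    return max(zip(states, scores), key=lambda pair: pair[1])[0]
```

A random model is only useful as a baseline or for the first search round; the search loop normally retrains a learned model on measured results and queries it the same way.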
16 changes: 9 additions & 7 deletions python/tvm/auto_scheduler/workload_registry.py
@@ -55,13 +55,15 @@ def register_workload(func_name, f=None, override=False):

Examples
--------
- @auto_scheduler.register_workload
- def matmul(N, M, K):
-     A = te.placeholder((N, K), name='A')
-     B = te.placeholder((K, M), name='B')
-     k = te.reduce_axis((0, K), name='k')
-     C = te.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C')
-     return [A, B, C]
+ .. code-block:: python
+
+     @auto_scheduler.register_workload
+     def matmul(N, M, K):
+         A = te.placeholder((N, K), name='A')
+         B = te.placeholder((K, M), name='B')
+         k = te.reduce_axis((0, K), name='k')
+         C = te.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+         return [A, B, C]
"""
global WORKLOAD_FUNC_REGISTRY

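`register_workload` records the compute function in `WORKLOAD_FUNC_REGISTRY` so a search task can later rebuild the computation from its name and arguments. A simplified name-based sketch of such a registry, including the `override` flag (illustrative; the real one also accepts a bare function as `func_name` and handles serialization):

```python
WORKLOAD_FUNC_REGISTRY = {}


def register_workload(func_name, f=None, override=False):
    """Register `f` under `func_name`, or return a decorator if `f` is None."""

    def register(myf):
        if func_name in WORKLOAD_FUNC_REGISTRY and not override:
            raise RuntimeError("%s has been registered already" % func_name)
        WORKLOAD_FUNC_REGISTRY[func_name] = myf
        return myf

    if f is None:
        return register  # used as @register_workload("name")
    return register(f)
```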
6 changes: 5 additions & 1 deletion python/tvm/micro/session.py
@@ -22,10 +22,14 @@

from .._ffi import get_global_func
from ..contrib import graph_runtime
- from .base import _rpc_connect
from ..rpc import RPCSession
from .transport import TransportLogger

+ try:
+     from .base import _rpc_connect
+ except ImportError:
+     raise ImportError("micro tvm is not enabled. Set USE_MICRO to ON in config.cmake")


class Session:
"""MicroTVM Device Session
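The change above wraps the `_rpc_connect` import so a build without the native micro runtime produces an actionable configuration hint instead of a bare `ImportError`. The same guarded-import pattern in isolation (generic helper, hypothetical names):

```python
def load_micro_backend(import_fn):
    """Import an optional backend, re-raising with a configuration hint."""
    try:
        return import_fn()
    except ImportError as err:
        # point the user at the build flag instead of the raw module name
        raise ImportError(
            "micro tvm is not enabled. Set USE_MICRO to ON in config.cmake"
        ) from err
```

Chaining with `from err` keeps the original traceback, so the underlying missing module is still visible when debugging.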