Skip to content

Commit

Permalink
[ANSOR] Auto-scheduler tutorial for GPU and necessary refactor/fix (apache#6512)
Browse files Browse the repository at this point in the history

* add gpu tutorial

* refactor mutation in evolutionary search

* update

* update double matmul

* fix lint

* add double matmul test

* fix mutate compute location

* fix sketch search policy

* fix lint

* update

* address comments

* fix PruneInvalidStates
  • Loading branch information
merrymercy authored and trevor-m committed Oct 19, 2020
1 parent e1abea4 commit 766bd54
Show file tree
Hide file tree
Showing 19 changed files with 684 additions and 365 deletions.
15 changes: 15 additions & 0 deletions docs/api/python/auto_scheduler.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,20 @@ tvm.auto_scheduler.auto_schedule

.. autofunction:: tvm.auto_scheduler.auto_schedule.auto_schedule

tvm.auto_scheduler.workload_registry
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: tvm.auto_scheduler.workload_registry.register_workload


tvm.auto_scheduler.measure
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.auto_scheduler.measure

.. autoclass:: tvm.auto_scheduler.measure.LocalRPCMeasureContext

.. autoclass:: tvm.auto_scheduler.measure.LocalRunner

.. autoclass:: tvm.auto_scheduler.measure.LocalBuilder

.. autoclass:: tvm.auto_scheduler.measure.RPCRunner
3 changes: 2 additions & 1 deletion include/tvm/auto_scheduler/search_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
#include <tvm/auto_scheduler/search_task.h>
#include <tvm/node/node.h>

#include <string>
#include <unordered_set>
#include <vector>

Expand Down Expand Up @@ -191,7 +192,7 @@ class SearchPolicyNode : public Object {
* We store the string format of a state for redundancy check. This is used to make sure a
* measured state will never be measured again.
*/
std::unordered_set<String> measured_states_set_;
std::unordered_set<std::string> measured_states_set_;
/*! \brief The array of already measured states.
* The good states can be used as the initial population in evolutionary search. */
std::vector<State> measured_states_vector_;
Expand Down
14 changes: 7 additions & 7 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

"""
r"""
Distributed measurement infrastructure to measure the runtime costs of tensor programs.
These functions are responsible for building the tvm module, uploading it to
Expand All @@ -25,8 +25,8 @@
A builder builds the executable binary files and a runner runs the binary files to
get the measurement results. The flow of data structures is
`ProgramBuilder` `ProgramRunner`
`MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult`
. `ProgramBuilder` `ProgramRunner`
`MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult`
We implement these in python to utilize python's multiprocessing and error handling.
"""
Expand Down Expand Up @@ -222,7 +222,7 @@ class LocalRunner(ProgramRunner):
where the first "1" is warm up and will be discarded.
The returned result contains `repeat` costs,
each of which is an average of `number` costs.
min_repeat_ms : int = 0
min_repeat_ms : int = 100
The minimum duration of one `repeat` in milliseconds.
By default, one `repeat` contains `number` runs. If this parameter is set,
the parameters `number` will be dynamically adjusted to meet the
Expand All @@ -244,7 +244,7 @@ def __init__(
timeout=10,
number=3,
repeat=1,
min_repeat_ms=0,
min_repeat_ms=100,
cooldown_interval=0.0,
enable_cpu_cache_flush=False,
):
Expand Down Expand Up @@ -289,7 +289,7 @@ class RPCRunner(ProgramRunner):
where the first "1" is warm up and will be discarded.
The returned result contains `repeat` costs,
each of which is an average of `number` costs.
min_repeat_ms : int = 0
min_repeat_ms : int = 100
The minimum duration of one `repeat` in milliseconds.
By default, one `repeat` contains `number` runs. If this parameter is set,
the parameters `number` will be dynamically adjusted to meet the
Expand All @@ -316,7 +316,7 @@ def __init__(
timeout=10,
number=3,
repeat=1,
min_repeat_ms=0,
min_repeat_ms=100,
cooldown_interval=0.0,
enable_cpu_cache_flush=False,
):
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/auto_scheduler/search_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class SketchPolicy(SearchPolicy):
----------
task : SearchTask
The SearchTask for the computation declaration.
schedule_cost_model : CostModel = RandomModel()
program_cost_model : CostModel = RandomModel()
The cost model to estimate the complete schedules.
params : Optional[Dict[str, Any]]
Parameters of the search policy.
Expand Down Expand Up @@ -129,7 +129,7 @@ class SketchPolicy(SearchPolicy):
def __init__(
self,
task,
schedule_cost_model=RandomModel(),
program_cost_model=RandomModel(),
params=None,
seed=None,
verbose=1,
Expand All @@ -145,7 +145,7 @@ def __init__(
self.__init_handle_by_constructor__(
_ffi_api.SketchPolicy,
task,
schedule_cost_model,
program_cost_model,
params,
seed or random.randint(1, 1 << 30),
verbose,
Expand Down
16 changes: 9 additions & 7 deletions python/tvm/auto_scheduler/workload_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,15 @@ def register_workload(func_name, f=None, override=False):
Examples
--------
@auto_scheduler.register_workload
def matmul(N, M, K):
A = te.placeholder((N, K), name='A')
B = te.placeholder((K, M), name='B')
k = te.reduce_axis((0, K), name='k')
C = te.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C')
return [A, B, C]
.. code-block:: python
@auto_scheduler.register_workload
def matmul(N, M, K):
A = te.placeholder((N, K), name='A')
B = te.placeholder((K, M), name='B')
k = te.reduce_axis((0, K), name='k')
C = te.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C')
return [A, B, C]
"""
global WORKLOAD_FUNC_REGISTRY

Expand Down
6 changes: 5 additions & 1 deletion python/tvm/micro/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@

from .._ffi import get_global_func
from ..contrib import graph_runtime
from .base import _rpc_connect
from ..rpc import RPCSession
from .transport import TransportLogger

try:
from .base import _rpc_connect
except ImportError:
raise ImportError("micro tvm is not enabled. Set USE_MICRO to ON in config.cmake")


class Session:
"""MicroTVM Device Session
Expand Down
Loading

0 comments on commit 766bd54

Please sign in to comment.