Skip to content

Commit

Permalink
Add RPCRunner & OpenCL/CUDA test (apache#12)
Browse files Browse the repository at this point in the history
* Add RPCRunner & OpenCL search test

* Add CUDA search test

* Add RPCRunner test
  • Loading branch information
jcf94 authored and merrymercy committed Jun 20, 2020
1 parent 43d1530 commit f367d15
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/tvm/ansor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@
from .compute_dag import ComputeDAG
from .task import SearchTask, MetaTileRewritePolicy, TuneOption
from .task import auto_schedule
from .measure import MeasureInput, LocalBuilder, LocalRunner
from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner
from .cost_model import RandomModel
from .serialization import LogToFile, LogReader, best_measure_pair_in_file
22 changes: 22 additions & 0 deletions python/tvm/ansor/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,28 @@ def __init__(self,
_ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval)


@tvm._ffi.register_object("ansor.RPCRunner")
class RPCRunner(Runner):
    """Runner object whose handle is created by ``_ffi_api.RPCRunner``.

    After the handle is created, ``check_remote`` is called with the tracker
    related arguments; construction fails with ``RuntimeError`` when that
    check does not succeed.

    Parameters
    ----------
    key, host, port, priority
        Passed unchanged to both ``_ffi_api.RPCRunner`` and ``check_remote``.
        (Presumably the device key and the tracker address -- confirm
        against the RPC tracker documentation.)
    n_parallel, number, repeat, min_repeat_ms, cooldown_interval
        Passed unchanged to ``_ffi_api.RPCRunner``.
    timeout
        Passed to both ``_ffi_api.RPCRunner`` and ``check_remote``.

    Raises
    ------
    RuntimeError
        If ``check_remote`` returns a falsy value.
    """
    def __init__(self, key, host, port, priority=1,
                 n_parallel=1,
                 timeout=10,
                 number=3,
                 repeat=1,
                 min_repeat_ms=0,
                 cooldown_interval=0.0):
        # The FFI constructor is called with ``timeout`` ahead of
        # ``n_parallel``, which differs from the order of this signature.
        ffi_args = (key, host, port, priority, timeout, n_parallel,
                    number, repeat, min_repeat_ms, cooldown_interval)
        self.__init_handle_by_constructor__(_ffi_api.RPCRunner, *ffi_args)

        if not check_remote(key, host, port, priority, timeout):
            raise RuntimeError("Cannot get remote devices from the tracker. "
                               "Please check the status of tracker by "
                               "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
                               "and make sure you have free devices on the queue status.")
        logger.info("Get devices for measurement successfully!")


# NOTE(review): presumably the maximum number of characters kept from an
# error message; the code that reads this constant is outside this excerpt
# -- confirm before relying on it.
MAX_ERROR_MSG_LEN = 512


Expand Down
3 changes: 2 additions & 1 deletion python/tvm/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,8 @@ def __init__(self,
cmd = [sys.executable,
"-m", "tvm.exec.rpc_server",
"--host=%s" % host,
"--port=%s" % port]
"--port=%s" % port,
"--port-end=%s" % port_end]
if tracker_addr:
assert key
cmd += ["--tracker=%s:%d" % tracker_addr,
Expand Down
8 changes: 8 additions & 0 deletions src/ansor/measure.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,5 +368,13 @@ TVM_REGISTER_GLOBAL("ansor.LocalRunner")
cooldown_interval);
});

// Register the global function "ansor.RPCRunner". The typed lambda takes ten
// arguments and forwards them, in the same order, to RPCRunnerNode::make,
// returning its result. Note that `timeout` precedes `n_parallel` here.
TVM_REGISTER_GLOBAL("ansor.RPCRunner")
.set_body_typed([](const std::string& key, const std::string& host, int port,
int priority, int timeout, int n_parallel, int number,
int repeat, int min_repeat_ms, double cooldown_interval) {
return RPCRunnerNode::make(key, host, port, priority, timeout, n_parallel,
number, repeat, min_repeat_ms, cooldown_interval);
});

} // namespace ansor
} // namespace tvm
1 change: 0 additions & 1 deletion src/ansor/search_policy/meta_tile_rewrite_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode {

SearchTask cur_task_; // The current task

friend class MetaTileRewritePolicyNodeTest; // Hack friend class for UT
protected:
// Pick states from best states and random states with eps-greedy policy
void PickStatesWithEpsGreedy(std::vector<MeasureInput>* inputs,
Expand Down
29 changes: 29 additions & 0 deletions tests/python/unittest/test_ansor_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import tvm
from tvm import ansor
from tvm.rpc.tracker import Tracker
from tvm.rpc.server import Server
import tempfile

from test_ansor_common import get_tiled_matmul
Expand Down Expand Up @@ -62,6 +64,33 @@ def test_measure_local_builder_runner():
assert mress[0].error_no == 0


def test_measure_local_builder_rpc_runner():
    """Build a tiled matmul with LocalBuilder and measure it via RPCRunner.

    A Tracker and a Server are started for the duration of the test. Both
    are terminated in a ``finally`` clause so that they do not outlive the
    test when a build/measure step raises or an assertion fails.
    """
    dag, s0 = get_tiled_matmul()

    tgt = tvm.target.create("llvm")
    task = ansor.SearchTask(dag, "test", tgt)

    minp = ansor.MeasureInput(task, s0)
    local_builder = ansor.LocalBuilder()
    host = '0.0.0.0'
    tracker = Tracker(host, port=9000, port_end=10000, silent=True)
    # Stays None if Server() itself raises, so cleanup can skip it.
    server = None
    try:
        device_key = '$local$device$%d' % tracker.port
        server = Server(host, port=tracker.port, port_end=10000,
                        key=device_key,
                        use_popen=True, silent=True,
                        tracker_addr=(tracker.host, tracker.port))
        rpc_runner = ansor.RPCRunner(device_key, host, tracker.port)

        bress = local_builder.build([minp])
        assert bress[0].error_no == 0
        mress = rpc_runner.run([minp], bress)
        assert mress[0].error_no == 0
    finally:
        # Same shutdown order as before: tracker first, then server.
        tracker.terminate()
        if server is not None:
            server.terminate()


if __name__ == "__main__":
    # Run each test of this module in order when executed as a script.
    for run_test in (test_serialization,
                     test_measure_local_builder_runner,
                     test_measure_local_builder_rpc_runner):
        run_test()
61 changes: 55 additions & 6 deletions tests/python/unittest/test_ansor_search_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,28 @@

import tvm
from tvm import ansor
from tvm.rpc.tracker import Tracker
from tvm.rpc.server import Server

from test_ansor_common import matmul_nkkm

def test_search_basic():
print("Test schedule search with the default search policy")
def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local'):
print("Test %s schedule search with the default search policy" % (target))

N = 128
A, B, C = matmul_nkkm(N, N, N)
dag = ansor.ComputeDAG([A, B, C])
tgt = tvm.target.create("llvm")
tgt = tvm.target.create(target)
task = ansor.SearchTask(dag, "test", tgt)

seed = 944563397
random.seed(seed)

with tempfile.NamedTemporaryFile() as fp:
log_file = fp.name

cost_model = ansor.RandomModel()
search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed)
tune_option = ansor.TuneOption(n_trials=2,
tune_option = ansor.TuneOption(n_trials=2, runner=runner,
callbacks=[ansor.LogToFile(log_file)])
state = ansor.auto_schedule(task, search_policy,
tune_option=tune_option)
Expand All @@ -60,7 +61,7 @@ def test_search_basic():
print(tvm.lower(sch, args, simple_mode=True))
mod = tvm.build(sch, args, tgt)

ctx = tvm.context("llvm", 0)
ctx = tvm.context(target, 0)
a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros((N, N), dtype=C.dtype), ctx)
Expand All @@ -75,7 +76,55 @@ def test_search_basic():
s0 = dag.infer_bound_from_state(state)
s1 = dag.infer_bound_from_state(inp.state)
assert s0 == s1
print()


def test_search_basic():
    """Run ``search_common`` with its default target and runner and a fixed seed."""
    fixed_seed = 944563397
    search_common(seed=fixed_seed)


def test_search_opencl():
    """Run ``search_common`` for the "opencl" target through an RPCRunner.

    The test is skipped (with a printed notice) when no OpenCL device 0
    exists. Otherwise a Tracker and a Server are started and are terminated
    in a ``finally`` clause, so they are cleaned up even if the search
    raises or an assertion inside it fails.
    """
    if not tvm.context("opencl", 0).exist:
        print("OpenCL device not found, skip this test.")
        return

    host = '0.0.0.0'
    tracker = Tracker(host, port=9000, port_end=10000, silent=True)
    # Stays None if Server() itself raises, so cleanup can skip it.
    server = None
    try:
        device_key = '$local$device$%d' % tracker.port
        server = Server(host, port=tracker.port, port_end=10000,
                        key=device_key,
                        use_popen=True, silent=True,
                        tracker_addr=(tracker.host, tracker.port))
        rpc_runner = ansor.RPCRunner(device_key, host, tracker.port)

        search_common("opencl", 380344973, rpc_runner)
    finally:
        # Same shutdown order as before: tracker first, then server.
        tracker.terminate()
        if server is not None:
            server.terminate()


def test_search_cuda():
    """Run ``search_common`` for the "cuda" target through an RPCRunner.

    The test is skipped (with a printed notice) when CUDA device 0 does not
    exist. Otherwise the CUDA target arch is set from the device's compute
    version, a Tracker and a Server are started, and both are terminated in
    a ``finally`` clause so they are cleaned up even if the search raises or
    an assertion inside it fails.
    """
    ctx = tvm.context("cuda", 0)
    if not ctx.exist:
        print("CUDA device not found, skip this test.")
        return

    # Build the arch string by dropping the dots from the compute version.
    cuda_arch = "sm_" + ctx.compute_version.replace('.', '')
    tvm.autotvm.measure.measure_methods.set_cuda_target_arch(cuda_arch)

    host = '0.0.0.0'
    tracker = Tracker(host, port=9000, port_end=10000, silent=True)
    # Stays None if Server() itself raises, so cleanup can skip it.
    server = None
    try:
        device_key = '$local$device$%d' % tracker.port
        server = Server(host, port=tracker.port, port_end=10000,
                        key=device_key,
                        use_popen=True, silent=True,
                        tracker_addr=(tracker.host, tracker.port))
        rpc_runner = ansor.RPCRunner(device_key, host, tracker.port)

        search_common("cuda", 903667810, rpc_runner)
    finally:
        # Same shutdown order as before: tracker first, then server.
        tracker.terminate()
        if server is not None:
            server.terminate()


if __name__ == "__main__":
    # Run each test of this module in order when executed as a script.
    for run_test in (test_search_basic,
                     test_search_opencl,
                     test_search_cuda):
        run_test()

0 comments on commit f367d15

Please sign in to comment.