diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index a4abe8c31f56..dec0e9842a37 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -47,10 +47,10 @@ def add_run_parser(subparsers): parser.set_defaults(func=drive_run) # TODO --device needs to be extended and tested to support other targets, - # like 'cl', 'webgpu', etc (@leandron) + # like 'webgpu', etc (@leandron) parser.add_argument( "--device", - choices=["cpu", "gpu"], + choices=["cpu", "gpu", "cl"], default="cpu", help="target device to run the compiled module. Defaults to 'cpu'", ) @@ -361,7 +361,13 @@ def run_module( # TODO expand to other supported devices, as listed in tvm.rpc.client (@leandron) logger.debug("device is %s", device) - ctx = session.cpu() if device == "cpu" else session.gpu() + if device == "gpu": + ctx = session.gpu() + elif device == "cl": + ctx = session.cl() + else: + assert device == "cpu" + ctx = session.cpu() if profile: logger.debug("creating runtime with profiling enabled")