diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 489604d79cf4..5a15d228803d 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -50,7 +50,7 @@ def add_run_parser(subparsers): # like 'webgpu', etc (@leandron) parser.add_argument( "--device", - choices=["cpu", "cuda", "cl", "metal", "vulkan"], + choices=["cpu", "cuda", "cl", "metal", "vulkan", "rocm"], default="cpu", help="target device to run the compiled module. Defaults to 'cpu'", ) @@ -394,6 +394,8 @@ def run_module( dev = session.metal() elif device == "vulkan": dev = session.vulkan() + elif device == "rocm": + dev = session.rocm() else: assert device == "cpu" dev = session.cpu() diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index d8199c4c93a6..a9834391ed88 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -217,6 +217,10 @@ def metal(self, dev_id=0): """Construct Metal device.""" return self.device(8, dev_id) + def rocm(self, dev_id=0): + """Construct ROCm device.""" + return self.device(10, dev_id) + def ext_dev(self, dev_id=0): """Construct extension device.""" return self.device(12, dev_id)