From c4db712b85bf6ec79ac269fc412d52543806f857 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 14 May 2024 19:40:06 -0700 Subject: [PATCH] Add a Launchable class for directly running IREE modules on PyTorch devices. (#17) This uses much of the plumbing of the custom ops workflow but is geared for larger scale integrations. Whereas the custom op system is optimized for things more "kernel" sized, focusing on specialization and JIT compilation of variants, this workflow is geared towards integrating entire programs (either from VMFB or JIT compiled for the device in use on the fly) as a Torch callable. Usage for bring-your-own-VMFB: ``` launch = Launchable.from_vm_module(lambda device: VmModule.mmap(device.vm_instance, "foo.vmfb")) result = launch(tensor1, tensor2) ``` Usage for JIT compiling: ``` launch = Launchable.jit_compile(MLIR_ASM) result = launch(tensor1, tensor2) ``` In the first case, it is the caller's responsibility to produce a VMFB that is valid for the given device. In the JIT case, appropriate compiler flags and targeting information will be set based on the type of the device the input tensors are located on (i.e. if ROCM/CUDA, this will also properly differentiate between heterogenous devices on the system and compile a binary for each distinct target). Limitations: * The underlying mechanism currently uses the default stream for synchronization. It is TBI to plumb through more explicit support. * As a consequence of the above, we also are syncing the device after launch. * We are waiting for upstream PyTorch patches to land to get UUIDs from torch devices. Without this, enumeration order has to match, which is not guaranteed. Includes workarounds for: * https://github.com/iree-org/iree/issues/17402 * https://github.com/iree-org/iree/issues/17403 --------- Signed-off-by: Stella Laurenzo --- requirements.txt | 1 + shark_turbine/aot/params.py | 4 + shark_turbine/runtime/__init__.py | 1 + shark_turbine/runtime/device.py | 22 +- shark_turbine/runtime/launch.py | 248 +++++++++++++++++++++++ shark_turbine/runtime/op_reg/compiler.py | 3 + tests/runtime/launch_test.py | 88 ++++++++ 7 files changed, 365 insertions(+), 2 deletions(-) create mode 100644 shark_turbine/runtime/launch.py create mode 100644 tests/runtime/launch_test.py diff --git a/requirements.txt b/requirements.txt index bc4eaf0f..d37678f5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ # Build/test requirements. Jinja2==3.1.3 numpy==1.26.3 +parameterized==0.9.0 pytest==8.0.0 pytest-xdist==3.5.0 mypy==1.8.0 diff --git a/shark_turbine/aot/params.py b/shark_turbine/aot/params.py index 54858648..63bcc1b0 100644 --- a/shark_turbine/aot/params.py +++ b/shark_turbine/aot/params.py @@ -256,6 +256,10 @@ class ParameterArchiveBuilder: def __init__(self): self._index = ParameterIndex() + @property + def index(self) -> ParameterIndex: + return self._index + def save(self, file_path: Union[str, Path]): """Saves the archive.""" self._index.create_archive_file(str(file_path)) diff --git a/shark_turbine/runtime/__init__.py b/shark_turbine/runtime/__init__.py index 29434c26..f1db6fe5 100644 --- a/shark_turbine/runtime/__init__.py +++ b/shark_turbine/runtime/__init__.py @@ -5,4 +5,5 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .device import * +from .launch import * from . import op_reg diff --git a/shark_turbine/runtime/device.py b/shark_turbine/runtime/device.py index 40236601..d34f49a8 100644 --- a/shark_turbine/runtime/device.py +++ b/shark_turbine/runtime/device.py @@ -139,6 +139,7 @@ class Device: "_tx_timepoint", "_fence_capacity", "compile_target_flags", + "driver_id", "export_torch_tensor", "import_torch_tensor", "instance_cache_key", @@ -157,6 +158,9 @@ class Device: # a meta tensor that describes it. export_torch_tensor: Callable[[HalBufferView, torch.Tensor], torch.Tensor] + # Unique name of the IREE runtime driver associated with this device. + driver_id: str + # Cache key that uniquely identifies this device. instance_cache_key: str @@ -222,6 +226,7 @@ def _initialize(self): colon_pos = driver_id.find(":") if colon_pos >= 0: driver_id = driver_id[0:colon_pos] + self.driver_id = driver_id try: import_fn = TORCH_TENSOR_IMPORTERS[driver_id] export_fn = TORCH_TENSOR_EXPORTERS[driver_id] @@ -237,7 +242,10 @@ def _initialize(self): # TODO: The type cache key should actually be based on the driver id # and device characteristics hash. self.instance_cache_key = repr(d) - self.type_cache_key = driver_id + self._recompute_target_keys() + + def _recompute_target_keys(self): + self.type_cache_key = f"{self.driver_id}:{';'.join(self.compile_target_flags)}" @property def hal_device(self) -> HalDevice: @@ -477,17 +485,27 @@ def _create_cuda_device(torch_device: torch.device, props) -> Optional[Device]: device.compile_target_flags = device.compile_target_flags + ( f"--iree-hal-cuda-llvm-target-arch=sm_{props.major}{props.minor}", ) + device._recompute_target_keys() return device def _create_hip_device(torch_device: torch.device, props) -> Optional[Device]: # Note that the dlpack device type code for ROCM is 10. device = _create_cuda_like_device(torch_device, props, "hip", 10) + # The gcnArchName comes back like gfx90a:sramecc+:xnack- for a fully + # specified target. However the IREE target-chip flag only expects the + # prefix. See: https://github.com/iree-org/iree/issues/17402 + # This should be changed to tunnel through target information unmolested. + gcn_arch_name: str = props.gcnArchName + colon_pos = gcn_arch_name.find(":") + if colon_pos >= 0: + gcn_arch_name = gcn_arch_name[0:colon_pos] if device: - gcn_arch_name = props.gcnArchName + gcn_arch_name = gcn_arch_name device.compile_target_flags = device.compile_target_flags + ( f"--iree-rocm-target-chip={gcn_arch_name}", ) + device._recompute_target_keys() return device diff --git a/shark_turbine/runtime/launch.py b/shark_turbine/runtime/launch.py new file mode 100644 index 00000000..f4a59f6e --- /dev/null +++ b/shark_turbine/runtime/launch.py @@ -0,0 +1,248 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Any, Callable, Optional, Sequence, Tuple + +import torch +from torch import Tensor + +from iree.compiler.api import ( + Session, + Source, + Output, +) + +from iree.runtime import ( + create_io_parameters_module, + HalBufferView, + HalElementType, + ParameterProvider, + VmContext, + VmFunction, + VmModule, + VmRef, + VmVariantList, +) + +from ..support.logging import runtime_logger as logger + +from .device import ( + get_device_from_torch, + Device, +) + +__all__ = [ + "Launchable", +] + +_TargetBinary = Tuple[VmContext, VmFunction] +_Loader = Callable[[Device], _TargetBinary] + + +class Launchable: + """Facilities for launching a compiled program (VMFB) on an attached device. + + Like the eager custom-op executor, this follows the usual PyTorch rules + whereby the device that input tensors reside on dictates where the launch + happens. Unlike that flow, this does not include any notion of jitting + or caching. It also has APIs for using parameters, etc. + + You must manage all compilation/target settings yourself and you merely + assert that a given binary is appropriate for launch on a device type. + This has various limitations. + """ + + def __init__(self, loader: Optional[_Loader]): + self._loader = loader + # Map of Device.type_cache_key -> _TargetBinary for a resolved binary. + self._target_binaries: dict[str, _TargetBinary] = {} + + @staticmethod + def jit_compile( + source: Any, + *, + parameter_providers: Sequence[ParameterProvider] = (), + entry_point: str = "main", + ) -> "Launchable": + return Launchable.from_vm_module( + _jit_callback(source), + parameter_providers=parameter_providers, + entry_point=entry_point, + ) + + @staticmethod + def from_vm_module( + vm_module_callback: Callable[[Device], VmModule], + *, + parameter_providers: Sequence[ParameterProvider] = (), + entry_point: str = "main", + ) -> "Launchable": + def loader(device: Device) -> _TargetBinary: + vm_instance = device.vm_instance + modules = [device.create_hal_module()] + if parameter_providers: + modules.append( + create_io_parameters_module(vm_instance, *parameter_providers) + ) + main_module = vm_module_callback(device) + modules.append(main_module) + vm_context = VmContext(vm_instance, modules) + main_function = main_module.lookup_function(entry_point) + return vm_context, main_function + + return Launchable(loader) + + def preload(self, device: torch.device): + """Pre-loads (or JIT compiles) for the given torch.device.""" + turbine_device = get_device_from_torch(device) + self._resolve_target_binary(turbine_device) + + def _resolve_target_binary(self, turbine_device: Device) -> _TargetBinary: + device_key = turbine_device.type_cache_key + existing = self._target_binaries.get(device_key) + if existing is not None: + logger.debug("Launching cached binary for %s", device_key) + return existing + + # Try the user loader. + loader = self._loader + if loader is not None: + loaded = loader(turbine_device) + if loaded is not None: + logger.debug("Cached new binary for %s", device_key) + self._target_binaries[device_key] = loaded + return loaded + raise NotImplementedError( + f"Could not load a target binary for device {turbine_device}" + ) + + def __call__(self, *args, device: Optional[torch.device] = None): + turbine_device: Optional[Device] = ( + None if device is None else get_device_from_torch(device) + ) + arg_list = VmVariantList(len(args)) + # Scan args for tensors and infer device. + for arg in args: + if isinstance(arg, Tensor): + # For pre-compiled launchables, there is no support for anything + # but contiguous layouts. + if not arg.is_contiguous(): + arg = arg.contiguous() + tensor_device = arg.device + if device is None: + device = tensor_device + else: + if tensor_device != device: + raise RuntimeError( + f"Cannot launch with tensors from multiple devices: " + f"{tensor_device} vs {device}" + ) + if turbine_device is None: + turbine_device = get_device_from_torch(tensor_device) + # Since we know we are on the same device, we can use the unsafe + # import_torch_tensor. + arg_list.push_ref(turbine_device.import_torch_tensor(arg)) + elif isinstance(arg, int): + arg_list.push_int(arg) + elif isinstance(arg, float): + arg_list.push_float(arg) + + # Having at least one tensor arg is a pre-requisite for normal operation + if device is None or turbine_device is None: + raise RuntimeError( + f"Cannot invoke Launchable {self} without any Tensor args or an explicit device=" + ) + + vm_context, vm_function = self._resolve_target_binary(turbine_device) + ret_list = VmVariantList(1) + vm_context.invoke(vm_function, arg_list, ret_list) + torch_results = [] + for i in range(len(ret_list)): + result = ret_list.get_variant(i) + if isinstance(result, VmRef): + buffer_view = result.deref(HalBufferView, True) + if buffer_view is not None: + torch_results.append( + _export_torch_tensor(buffer_view, turbine_device) + ) + + arity = len(torch_results) + if arity == 1: + return torch_results[0] + elif arity == 0: + return None + else: + return torch_results + + +def _jit_callback(program_source: Any) -> Callable[[Device], VmModule]: + session = Session() + if isinstance(program_source, Source): + ... + elif isinstance(program_source, str): + source = Source.wrap_buffer(session, program_source.encode()) + else: + source = Source.wrap_buffer(session, program_source) + + def callback(device: Device): + session.set_flags(*device.compile_target_flags) + inv = session.invocation() + output = Output.open_membuffer() + inv.enable_console_diagnostics() + inv.parse_source(source) + if not inv.execute(): + # TODO: Capture diagnostics and report. + raise RuntimeError(f"JIT compilation failed. See diagnostics.") + inv.output_vm_bytecode(output) + mapped_memory = output.map_memory() + vm_instance = device.vm_instance + # TODO: VmModule.wrap_buffer would be better here, but it is still + # unreliable capturing mapped memory from the compiler. + # See: https://github.com/iree-org/iree/issues/17403 + return VmModule.copy_buffer(vm_instance, mapped_memory) + + return callback + + +def _export_torch_tensor(bv: HalBufferView, turbine_device: Device) -> Tensor: + # Usually in the custom op flow, we have strong metadata about the results. + # But since the whole purpose of this is for interfacing a blackbox, we + # just infer from IREE type -> torch type. This may be lossy on dtypes + # that are not an exact match, and the user is expected to bitcast. + dtype = _INFERRED_ELEMENT_TYPE_TO_DTYPE[bv.element_type] + if dtype is None: + raise NotImplementedError( + f"HalBufferView.element_type({bv.element_type}) has no mapping to dtype" + ) + meta_tensor = torch.empty(bv.shape, dtype=dtype, device="meta") + return turbine_device.export_torch_tensor(bv, meta_tensor) + + +# This is a relatively special purpose mapping. We usually don't go this +# way because it is lossy: IREE's types are "fundamental" and lacking +# signed/unsigned distinctions at this layer, so we do the best we can. +# If this becomes less special purpose, move it to conversions.py +_INFERRED_ELEMENT_TYPE_TO_DTYPE: dict[HalElementType, torch.dtype] = { + HalElementType.BFLOAT_16: torch.bfloat16, + HalElementType.BOOL_8: torch.bool, + HalElementType.COMPLEX_64: torch.complex64, + HalElementType.COMPLEX_128: torch.complex128, + HalElementType.FLOAT_16: torch.float16, + HalElementType.FLOAT_32: torch.float32, + HalElementType.FLOAT_64: torch.float64, + HalElementType.INT_8: torch.int8, + HalElementType.INT_16: torch.int16, + HalElementType.INT_32: torch.int32, + HalElementType.INT_64: torch.int64, + HalElementType.SINT_8: torch.int8, + HalElementType.SINT_16: torch.int16, + HalElementType.SINT_32: torch.int32, + HalElementType.SINT_64: torch.int64, + HalElementType.UINT_8: torch.uint8, + HalElementType.UINT_16: torch.uint16, + HalElementType.UINT_32: torch.uint32, + HalElementType.UINT_64: torch.uint64, +} diff --git a/shark_turbine/runtime/op_reg/compiler.py b/shark_turbine/runtime/op_reg/compiler.py index bb34c017..ae5d35b4 100644 --- a/shark_turbine/runtime/op_reg/compiler.py +++ b/shark_turbine/runtime/op_reg/compiler.py @@ -130,6 +130,9 @@ def compile_standalone_kernel( # Load. vm_instance = device.vm_instance + # TODO: VmModule.wrap_buffer would be better here, but it is still + # unreliable capturing mapped memory from the compiler. + # See: https://github.com/iree-org/iree/issues/17403 vm_module = VmModule.copy_buffer(vm_instance, mapped_memory) # TODO: We should be able to wrap the buffer as below but there are some # subtle ref-counting/shutdown sequencing issues that need to be resolved. diff --git a/tests/runtime/launch_test.py b/tests/runtime/launch_test.py new file mode 100644 index 00000000..1a142161 --- /dev/null +++ b/tests/runtime/launch_test.py @@ -0,0 +1,88 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from parameterized import parameterized_class +import torch +import unittest + +from shark_turbine.aot.params import ( + ParameterArchiveBuilder, +) + +from shark_turbine.runtime import ( + Launchable, +) + +MLIR_NO_PARAMS_ASM = r""" +module @test_module { +func.func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + %0 = arith.muli %arg0, %arg1 : tensor<4xi32> + return %0 : tensor<4xi32> +} +} +""" + +MLIR_PARAMS_ASM = r""" +module @test_module { +util.global private @param = #stream.parameter.named<"param"> : tensor<4xi32> +func.func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + %0 = arith.muli %arg0, %arg1 : tensor<4xi32> + %param = util.global.load @param : tensor<4xi32> + %1 = arith.addi %0, %param : tensor<4xi32> + return %1 : tensor<4xi32> +} +} +""" + +# TODO: Move this to a common utility controlled by project wide env vars. +devices = [[torch.device("cpu")]] +if torch.cuda.is_available(): + devices.append([torch.device("cuda:0")]) + + +@parameterized_class(["device"], devices) +class LaunchableTest(unittest.TestCase): + def testLaunchJit(self): + launch = Launchable.jit_compile(MLIR_NO_PARAMS_ASM) + t1 = torch.tensor([1, 2, 3, 4], dtype=torch.int32).to(self.device) + t2 = torch.tensor([10, 20, 30, 40], dtype=torch.int32).to(self.device) + result = launch(t1, t2) + expected = torch.tensor([10, 40, 90, 160], dtype=torch.int32).to(self.device) + torch.testing.assert_close(expected, result) + + def testLaunchPreload(self): + launch = Launchable.jit_compile(MLIR_NO_PARAMS_ASM) + launch.preload(self.device) + launch._loader = None # Don't let it load anything more. + t1 = torch.tensor([1, 2, 3, 4], dtype=torch.int32).to(self.device) + t2 = torch.tensor([10, 20, 30, 40], dtype=torch.int32).to(self.device) + result = launch(t1, t2) + expected = torch.tensor([10, 40, 90, 160], dtype=torch.int32).to(self.device) + torch.testing.assert_close(expected, result) + + def testLaunchParamsWithoutParams(self): + launch = Launchable.jit_compile(MLIR_PARAMS_ASM) + with self.assertRaisesRegex( + ValueError, "required module 'io_parameters' not registered" + ): + launch.preload(self.device) + + def testLaunchParams(self): + param_archive = ParameterArchiveBuilder() + param_archive.add_tensor("param", torch.tensor([2, 4, 6, 8], dtype=torch.int32)) + provider = param_archive.index.create_provider() + + launch = Launchable.jit_compile(MLIR_PARAMS_ASM, parameter_providers=[provider]) + launch.preload(self.device) + t1 = torch.tensor([1, 2, 3, 4], dtype=torch.int32).to(self.device) + t2 = torch.tensor([10, 20, 30, 40], dtype=torch.int32).to(self.device) + result = launch(t1, t2) + expected = torch.tensor([12, 44, 96, 168], dtype=torch.int32).to(self.device) + torch.testing.assert_close(expected, result) + + +if __name__ == "__main__": + unittest.main()