Add a Launchable class for directly running IREE modules on PyTorch devices. (#17)

This uses much of the plumbing of the custom ops workflow but is geared
for larger scale integrations. Whereas the custom op system is optimized
for "kernel"-sized work, focusing on specialization and JIT compilation
of variants, this workflow is geared toward integrating entire programs
(either from a VMFB or JIT compiled on the fly for the device in use) as
a Torch callable.

Usage for bring-your-own-VMFB:

```
launch = Launchable.from_vm_module(lambda device: VmModule.mmap(device.vm_instance, "foo.vmfb"))
result = launch(tensor1, tensor2)
```

Usage for JIT compiling:

```
launch = Launchable.jit_compile(MLIR_ASM)
result = launch(tensor1, tensor2)
```

In the first case, it is the caller's responsibility to produce a VMFB
that is valid for the given device. In the JIT case, appropriate
compiler flags and targeting information will be set based on the type
of the device the input tensors are located on (i.e. for ROCM/CUDA, this
will also properly differentiate between heterogeneous devices on the
system and compile a binary for each distinct target).
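
As a rough sketch of how that plays out in practice (the device indices and
MLIR_ASM below are hypothetical; `preload` comes from the new Launchable API),
targets can also be compiled ahead of the first call:

```
# Hypothetical sketch: pre-compile for each device before the first launch.
launch = Launchable.jit_compile(MLIR_ASM)
launch.preload(torch.device("cuda:0"))  # compiles for this device's target
launch.preload(torch.device("cuda:1"))  # separate binary if the arch differs
result = launch(tensor1, tensor2)       # dispatches the cached binary
```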

Limitations:

* The underlying mechanism currently uses the default stream for
synchronization; plumbing through more explicit support remains to be
implemented.
* As a consequence of the above, we also sync the device after each
launch.
* We are waiting for upstream PyTorch patches to land to get UUIDs from
torch devices. Without this, enumeration order has to match, which is
not guaranteed.

Includes workarounds for:

* iree-org/iree#17402
* iree-org/iree#17403

---------

Signed-off-by: Stella Laurenzo <stellaraccident@gmail.com>
stellaraccident authored May 15, 2024
1 parent 8fa1166 commit c4db712
Showing 7 changed files with 365 additions and 2 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,6 +1,7 @@
# Build/test requirements.
Jinja2==3.1.3
numpy==1.26.3
parameterized==0.9.0
pytest==8.0.0
pytest-xdist==3.5.0
mypy==1.8.0
4 changes: 4 additions & 0 deletions shark_turbine/aot/params.py
@@ -256,6 +256,10 @@ class ParameterArchiveBuilder:
    def __init__(self):
        self._index = ParameterIndex()

    @property
    def index(self) -> ParameterIndex:
        return self._index

    def save(self, file_path: Union[str, Path]):
        """Saves the archive."""
        self._index.create_archive_file(str(file_path))
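
A minimal sketch of how the new `index` accessor could feed Launchable's
parameter support; the `create_provider` call and the "model" scope are
assumptions about the iree.runtime API rather than part of this change:

```
# Sketch only: save an archive and hand its index to a launchable.
builder = ParameterArchiveBuilder()
# ... add named parameters to the builder (omitted) ...
builder.save("params.irpa")
# Assumed iree.runtime API: derive a runtime provider from the exposed index.
provider = builder.index.create_provider(scope="model")
launch = Launchable.jit_compile(MLIR_ASM, parameter_providers=[provider])
```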
1 change: 1 addition & 0 deletions shark_turbine/runtime/__init__.py
@@ -5,4 +5,5 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .device import *
from .launch import *
from . import op_reg
22 changes: 20 additions & 2 deletions shark_turbine/runtime/device.py
@@ -139,6 +139,7 @@ class Device:
"_tx_timepoint",
"_fence_capacity",
"compile_target_flags",
"driver_id",
"export_torch_tensor",
"import_torch_tensor",
"instance_cache_key",
@@ -157,6 +158,9 @@ class Device:
    # a meta tensor that describes it.
    export_torch_tensor: Callable[[HalBufferView, torch.Tensor], torch.Tensor]

    # Unique name of the IREE runtime driver associated with this device.
    driver_id: str

    # Cache key that uniquely identifies this device.
    instance_cache_key: str

@@ -222,6 +226,7 @@ def _initialize(self):
        colon_pos = driver_id.find(":")
        if colon_pos >= 0:
            driver_id = driver_id[0:colon_pos]
        self.driver_id = driver_id
        try:
            import_fn = TORCH_TENSOR_IMPORTERS[driver_id]
            export_fn = TORCH_TENSOR_EXPORTERS[driver_id]
@@ -237,7 +242,10 @@ def _initialize(self):
        # TODO: The type cache key should actually be based on the driver id
        # and device characteristics hash.
        self.instance_cache_key = repr(d)
        self._recompute_target_keys()

    def _recompute_target_keys(self):
        self.type_cache_key = f"{self.driver_id}:{';'.join(self.compile_target_flags)}"

    @property
    def hal_device(self) -> HalDevice:
@@ -477,17 +485,27 @@ def _create_cuda_device(torch_device: torch.device, props) -> Optional[Device]:
        device.compile_target_flags = device.compile_target_flags + (
            f"--iree-hal-cuda-llvm-target-arch=sm_{props.major}{props.minor}",
        )
        device._recompute_target_keys()
    return device


def _create_hip_device(torch_device: torch.device, props) -> Optional[Device]:
    # Note that the dlpack device type code for ROCM is 10.
    device = _create_cuda_like_device(torch_device, props, "hip", 10)
    # The gcnArchName comes back like gfx90a:sramecc+:xnack- for a fully
    # specified target. However the IREE target-chip flag only expects the
    # prefix. See: https://github.com/iree-org/iree/issues/17402
    # This should be changed to tunnel through target information unmolested.
    gcn_arch_name: str = props.gcnArchName
    colon_pos = gcn_arch_name.find(":")
    if colon_pos >= 0:
        gcn_arch_name = gcn_arch_name[0:colon_pos]
    if device:
        device.compile_target_flags = device.compile_target_flags + (
            f"--iree-rocm-target-chip={gcn_arch_name}",
        )
        device._recompute_target_keys()
    return device


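
To make the effect of the recomputed keys concrete, here is a small
illustration (the gfx90a/gfx940 values are hypothetical, not output from this
change) of how two ROCM devices with different architectures now get distinct
cache keys and therefore distinct compiled binaries:

```
# Hypothetical key construction, mirroring _recompute_target_keys above.
driver_id = "hip"
compile_target_flags = ("--iree-rocm-target-chip=gfx90a",)
type_cache_key = f"{driver_id}:{';'.join(compile_target_flags)}"
# -> "hip:--iree-rocm-target-chip=gfx90a"; a gfx940 device would produce a
#    different key, so Launchable caches a separate binary for it.
```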
248 changes: 248 additions & 0 deletions shark_turbine/runtime/launch.py
@@ -0,0 +1,248 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Any, Callable, Optional, Sequence, Tuple

import torch
from torch import Tensor

from iree.compiler.api import (
    Session,
    Source,
    Output,
)

from iree.runtime import (
    create_io_parameters_module,
    HalBufferView,
    HalElementType,
    ParameterProvider,
    VmContext,
    VmFunction,
    VmModule,
    VmRef,
    VmVariantList,
)

from ..support.logging import runtime_logger as logger

from .device import (
    get_device_from_torch,
    Device,
)

__all__ = [
    "Launchable",
]

_TargetBinary = Tuple[VmContext, VmFunction]
_Loader = Callable[[Device], _TargetBinary]


class Launchable:
    """Facilities for launching a compiled program (VMFB) on an attached device.

    Like the eager custom-op executor, this follows the usual PyTorch rules
    whereby the device that input tensors reside on dictates where the launch
    happens. Unlike that flow, this does not include any notion of jitting
    or caching. It also has APIs for using parameters, etc.

    You must manage all compilation/target settings yourself and you merely
    assert that a given binary is appropriate for launch on a device type.
    This has various limitations.
    """

    def __init__(self, loader: Optional[_Loader]):
        self._loader = loader
        # Map of Device.type_cache_key -> _TargetBinary for a resolved binary.
        self._target_binaries: dict[str, _TargetBinary] = {}

    @staticmethod
    def jit_compile(
        source: Any,
        *,
        parameter_providers: Sequence[ParameterProvider] = (),
        entry_point: str = "main",
    ) -> "Launchable":
        return Launchable.from_vm_module(
            _jit_callback(source),
            parameter_providers=parameter_providers,
            entry_point=entry_point,
        )

    @staticmethod
    def from_vm_module(
        vm_module_callback: Callable[[Device], VmModule],
        *,
        parameter_providers: Sequence[ParameterProvider] = (),
        entry_point: str = "main",
    ) -> "Launchable":
        def loader(device: Device) -> _TargetBinary:
            vm_instance = device.vm_instance
            modules = [device.create_hal_module()]
            if parameter_providers:
                modules.append(
                    create_io_parameters_module(vm_instance, *parameter_providers)
                )
            main_module = vm_module_callback(device)
            modules.append(main_module)
            vm_context = VmContext(vm_instance, modules)
            main_function = main_module.lookup_function(entry_point)
            return vm_context, main_function

        return Launchable(loader)

    def preload(self, device: torch.device):
        """Pre-loads (or JIT compiles) for the given torch.device."""
        turbine_device = get_device_from_torch(device)
        self._resolve_target_binary(turbine_device)

    def _resolve_target_binary(self, turbine_device: Device) -> _TargetBinary:
        device_key = turbine_device.type_cache_key
        existing = self._target_binaries.get(device_key)
        if existing is not None:
            logger.debug("Launching cached binary for %s", device_key)
            return existing

        # Try the user loader.
        loader = self._loader
        if loader is not None:
            loaded = loader(turbine_device)
            if loaded is not None:
                logger.debug("Cached new binary for %s", device_key)
                self._target_binaries[device_key] = loaded
                return loaded
        raise NotImplementedError(
            f"Could not load a target binary for device {turbine_device}"
        )

    def __call__(self, *args, device: Optional[torch.device] = None):
        turbine_device: Optional[Device] = (
            None if device is None else get_device_from_torch(device)
        )
        arg_list = VmVariantList(len(args))
        # Scan args for tensors and infer device.
        for arg in args:
            if isinstance(arg, Tensor):
                # For pre-compiled launchables, there is no support for anything
                # but contiguous layouts.
                if not arg.is_contiguous():
                    arg = arg.contiguous()
                tensor_device = arg.device
                if device is None:
                    device = tensor_device
                else:
                    if tensor_device != device:
                        raise RuntimeError(
                            f"Cannot launch with tensors from multiple devices: "
                            f"{tensor_device} vs {device}"
                        )
                if turbine_device is None:
                    turbine_device = get_device_from_torch(tensor_device)
                # Since we know we are on the same device, we can use the unsafe
                # import_torch_tensor.
                arg_list.push_ref(turbine_device.import_torch_tensor(arg))
            elif isinstance(arg, int):
                arg_list.push_int(arg)
            elif isinstance(arg, float):
                arg_list.push_float(arg)

        # Having at least one tensor arg is a pre-requisite for normal operation
        if device is None or turbine_device is None:
            raise RuntimeError(
                f"Cannot invoke Launchable {self} without any Tensor args or an explicit device="
            )

        vm_context, vm_function = self._resolve_target_binary(turbine_device)
        ret_list = VmVariantList(1)
        vm_context.invoke(vm_function, arg_list, ret_list)
        torch_results = []
        for i in range(len(ret_list)):
            result = ret_list.get_variant(i)
            if isinstance(result, VmRef):
                buffer_view = result.deref(HalBufferView, True)
                if buffer_view is not None:
                    torch_results.append(
                        _export_torch_tensor(buffer_view, turbine_device)
                    )

        arity = len(torch_results)
        if arity == 1:
            return torch_results[0]
        elif arity == 0:
            return None
        else:
            return torch_results


def _jit_callback(program_source: Any) -> Callable[[Device], VmModule]:
    session = Session()
    if isinstance(program_source, Source):
        source = program_source
    elif isinstance(program_source, str):
        source = Source.wrap_buffer(session, program_source.encode())
    else:
        source = Source.wrap_buffer(session, program_source)

    def callback(device: Device):
        session.set_flags(*device.compile_target_flags)
        inv = session.invocation()
        output = Output.open_membuffer()
        inv.enable_console_diagnostics()
        inv.parse_source(source)
        if not inv.execute():
            # TODO: Capture diagnostics and report.
            raise RuntimeError("JIT compilation failed. See diagnostics.")
        inv.output_vm_bytecode(output)
        mapped_memory = output.map_memory()
        vm_instance = device.vm_instance
        # TODO: VmModule.wrap_buffer would be better here, but it is still
        # unreliable capturing mapped memory from the compiler.
        # See: https://github.com/iree-org/iree/issues/17403
        return VmModule.copy_buffer(vm_instance, mapped_memory)

    return callback


def _export_torch_tensor(bv: HalBufferView, turbine_device: Device) -> Tensor:
    # Usually in the custom op flow, we have strong metadata about the results.
    # But since the whole purpose of this is for interfacing a blackbox, we
    # just infer from IREE type -> torch type. This may be lossy on dtypes
    # that are not an exact match, and the user is expected to bitcast.
    dtype = _INFERRED_ELEMENT_TYPE_TO_DTYPE.get(bv.element_type)
    if dtype is None:
        raise NotImplementedError(
            f"HalBufferView.element_type({bv.element_type}) has no mapping to dtype"
        )
    meta_tensor = torch.empty(bv.shape, dtype=dtype, device="meta")
    return turbine_device.export_torch_tensor(bv, meta_tensor)


# This is a relatively special purpose mapping. We usually don't go this
# way because it is lossy: IREE's types are "fundamental" and lacking
# signed/unsigned distinctions at this layer, so we do the best we can.
# If this becomes less special purpose, move it to conversions.py
_INFERRED_ELEMENT_TYPE_TO_DTYPE: dict[HalElementType, torch.dtype] = {
    HalElementType.BFLOAT_16: torch.bfloat16,
    HalElementType.BOOL_8: torch.bool,
    HalElementType.COMPLEX_64: torch.complex64,
    HalElementType.COMPLEX_128: torch.complex128,
    HalElementType.FLOAT_16: torch.float16,
    HalElementType.FLOAT_32: torch.float32,
    HalElementType.FLOAT_64: torch.float64,
    HalElementType.INT_8: torch.int8,
    HalElementType.INT_16: torch.int16,
    HalElementType.INT_32: torch.int32,
    HalElementType.INT_64: torch.int64,
    HalElementType.SINT_8: torch.int8,
    HalElementType.SINT_16: torch.int16,
    HalElementType.SINT_32: torch.int32,
    HalElementType.SINT_64: torch.int64,
    HalElementType.UINT_8: torch.uint8,
    HalElementType.UINT_16: torch.uint16,
    HalElementType.UINT_32: torch.uint32,
    HalElementType.UINT_64: torch.uint64,
}
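
Because the inference above is lossy (IREE element types carry no signedness
at this layer), a caller may need to reinterpret a returned tensor. A minimal
sketch, assuming the launch result came back as torch.int8 but semantically
holds unsigned bytes:

```
# Sketch: bitcast the inferred dtype without copying (same element size).
result = launch(tensor1, tensor2)     # e.g. inferred as torch.int8
result_u8 = result.view(torch.uint8)  # reinterprets the same storage
```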
3 changes: 3 additions & 0 deletions shark_turbine/runtime/op_reg/compiler.py
@@ -130,6 +130,9 @@ def compile_standalone_kernel(

    # Load.
    vm_instance = device.vm_instance
    # TODO: VmModule.wrap_buffer would be better here, but it is still
    # unreliable capturing mapped memory from the compiler.
    # See: https://github.com/iree-org/iree/issues/17403
    vm_module = VmModule.copy_buffer(vm_instance, mapped_memory)
    # TODO: We should be able to wrap the buffer as below but there are some
    # subtle ref-counting/shutdown sequencing issues that need to be resolved.