[TKW] Update IR interpreter (#182)
* Add `arith.andi`, `arith.cmpi`, `vector.maskedload`, `vector.gather`,
`vector.constant_mask`, `vector.insertelement`, `vector.splat`; support
non-splatted constants.
* Add `interpret_ndrange` helper

---------

Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
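As a sketch of the non-splatted-constant support mentioned above: the two materialization paths the interpreter now takes, with plain Python values standing in for the MLIR attribute (the data here is hypothetical; the real code reads `op.attributes["value"]`):

import numpy as np
import torch

shape = [4]
dtype = torch.float16

# Splat constant: a single scalar broadcast across the whole vector.
splat_value = 1.0
value = torch.full(shape, splat_value, dtype=dtype)

# Non-splat constant: the elements round-trip through numpy; a plain list
# stands in for the dense attribute here.
dense_values = [0.0, 1.0, 2.0, 3.0]
value = torch.from_numpy(np.array(dense_values)).type(dtype=dtype)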
Hardcode84 authored Oct 4, 2024
1 parent 207efd9 commit 64b7d27
Showing 2 changed files with 113 additions and 31 deletions.
4 changes: 4 additions & 0 deletions mypy.ini
@@ -20,6 +20,10 @@ ignore_errors = True
[mypy-shark_turbine.kernel.*]
ignore_errors = True

# TODO: Some pytorch errors.
[mypy-shark_turbine.tools.interpreter]
ignore_errors = True

# Ignore all typing errors in tests/tools (these depend on TK).
[mypy-tests.tools.*]
ignore_errors = True
140 changes: 109 additions & 31 deletions shark_turbine/tools/interpreter.py
@@ -4,32 +4,33 @@
import re
from typing import Callable
from collections import namedtuple
import numpy as np

logger = get_logger("turbine.wave.interpreter")


from ..kernel.compiler.ir import (
Context,
F16Type,
F32Type,
IndexType,
IntegerAttr,
IntegerType,
Module,
Operation,
Value,
VectorType,
amdgpu_d,
arith_d,
builtin_d,
flow_d,
func_d,
gpu_d,
llvm_d,
memref_d,
scf_d,
stream_d,
vector_d,
)


@@ -53,17 +54,12 @@ def get_dtype(self, dtype):
return torch.float32
if type(dtype) == F16Type:
return torch.float16
if type(dtype) == IndexType:
return torch.int64
if dtype == IntegerType.get_signless(1):
return torch.bool
raise NotImplementedError(f"Unsupported dtype: {dtype}")

def callback(self, op: Operation) -> None:
if (
op.operation.parent.name == "func.func"
@@ -80,11 +76,13 @@ def callback(self, op: Operation) -> None:
elif vtype == VectorType:
shape = op.value.type.shape
dtype = op.value.type.element_type
val = op.attributes["value"]
dtype = self.get_dtype(dtype)
if val.is_splat:
val = val.get_splat_value().value
value = torch.full(shape, val, dtype=dtype)
else:
value = torch.from_numpy(np.array(val)).type(dtype=dtype)
else:
raise NotImplementedError(f"Unsupported constant type: {vtype}")
case arith_d.MulIOp:
@@ -112,6 +110,21 @@ def callback(self, op: Operation) -> None:
self.symbol_table[op.operands[0]]
// self.symbol_table[op.operands[1]]
)
case arith_d.AndIOp:
value = (
self.symbol_table[op.operands[0]]
& self.symbol_table[op.operands[1]]
)
case arith_d.CmpIOp:
lhs = self.symbol_table[op.lhs]
rhs = self.symbol_table[op.rhs]
pred = int(op.predicate)
if pred == int(arith_d.CmpIPredicate.slt):
value = lhs < rhs
else:
raise NotImplementedError(
f"Unsupported predicate: {op.predicate}"
)
case amdgpu_d.LDSBarrierOp:
return
case amdgpu_d.MFMAOp:
@@ -136,11 +149,10 @@ def callback(self, op: Operation) -> None:
)
# Row-major load
offset = [0 for _ in range(len(load_indices))]
for i in range(*result_shape):
ind = [int(x) + y for x, y in zip(load_indices, offset)]
value[i] = memref[*ind]
offset[-1] += 1
case vector_d.ExtractStridedSliceOp:
vector = self.symbol_table[op.vector]
value = vector[[int(x) for x in op.offsets]]
@@ -154,11 +166,69 @@ def callback(self, op: Operation) -> None:
result_shape = vector.shape
# Row-major store
offset = [0 for _ in range(len(store_indices))]
for i in range(*result_shape):
memref[
*[int(x) + y for x, y in zip(store_indices, offset)]
] = vector[i]
offset[-1] += 1
case vector_d.MaskedStoreOp:
store_indices = []
for index in op.indices:
store_indices.append(self.symbol_table[index])
vector = self.symbol_table[op.valueToStore]
memref = self.symbol_table[op.base]
mask = self.symbol_table[op.mask]
result_type = vector.type
result_shape = vector.shape
# Row-major store
offset = [0 for _ in range(len(store_indices))]
for i in range(*result_shape):
if mask[i]:
ind = [int(x) + y for x, y in zip(store_indices, offset)]
memref[*ind] = vector[i]

offset[-1] += 1
case vector_d.ConstantMaskOp:
shape = op.result.type.shape
value = torch.ones(shape, dtype=torch.bool)
case vector_d.GatherOp:
load_indices = []
for index in op.indices:
load_indices.append(self.symbol_table[index])
logger.debug("Gather indices: %s", load_indices)
memref = self.symbol_table[op.base]
mask = self.symbol_table[op.mask]
index_vec = self.symbol_table[op.index_vec]
pass_thru = self.symbol_table[op.pass_thru]
result_type = op.result.type
result_shape = result_type.shape
result_dtype = result_type.element_type
value = torch.zeros(
*result_shape, dtype=self.get_dtype(result_dtype)
)
# Row-major load
offset = [0 for _ in range(len(load_indices))]
for i in range(*result_shape):
if mask[i]:
off = [
slice(int(x) + y, None)
for x, y in zip(load_indices, offset)
]
m = memref[off].flatten()
value[i] = m[index_vec[i]]
else:
value[i] = pass_thru[i]
case vector_d.InsertElementOp:
source = self.symbol_table[op.source]
value = self.symbol_table[op.dest].clone()
position = self.symbol_table[op.position]
value[int(position[0])] = source
case vector_d.SplatOp:
mtype = op.result.type
shape = mtype.shape
dtype = mtype.element_type
input = self.symbol_table[op.input][0]
value = torch.full(shape, input, dtype=self.get_dtype(dtype))
case stream_d.DispatchWorkgroupIDOp:
index = int(op.attributes["dimension"])
value = self.workgroup_ids[index]
@@ -214,7 +284,7 @@ def callback(self, op: Operation) -> None:
case _:
raise NotImplementedError(f"Unsupported operation: {op}")

if type(op) not in (vector_d.StoreOp, vector_d.MaskedStoreOp):
self.symbol_table[op.result] = value

def walk_operations(self, operation: Operation, callback: Callable) -> None:
@@ -237,6 +307,14 @@ def interpret(self, asm: str) -> None:
operation = module.operation
self.walk_operations(operation, self.callback)

@staticmethod
def interpret_ndrange(
asm: str, workgroup_count: list[int], workgroup_size: list[int]
):
for wg in np.ndindex(*workgroup_count):
for t in np.ndindex(*workgroup_size):
Interpreter([*wg], [*t]).interpret(asm)
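# A hypothetical invocation of interpret_ndrange; the file name and launch
# dimensions here are illustrative, not part of this commit:
#
#   asm = open("kernel.mlir").read()
#   Interpreter.interpret_ndrange(
#       asm,
#       workgroup_count=[2, 1, 1],
#       workgroup_size=[64, 1, 1],
#   )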


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MLIR Interpreter")
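As a standalone illustration of the new `vector.gather` handling above: each active lane reads from a row-major flattened view of the memref starting at the base indices, while masked-off lanes fall back to the pass-through vector. A minimal sketch of those semantics in plain torch, with made-up data:

import torch

memref = torch.arange(16, dtype=torch.float32).reshape(4, 4)
base_indices = [1, 0]                           # gather base (the op's `indices`)
index_vec = torch.tensor([0, 2, 5, 7])          # per-lane offsets into the flattened tail
mask = torch.tensor([True, True, False, True])  # which lanes actually load
pass_thru = torch.zeros(4)                      # fallback for masked-off lanes

# Row-major flattened view starting at the base indices, as in the loop above.
flat = memref[tuple(slice(int(b), None) for b in base_indices)].flatten()
value = torch.where(mask, flat[index_vec], pass_thru)
# value == tensor([4., 6., 0., 11.])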
