Skip to content

Commit

Permalink
Use real op
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaraccident committed Apr 23, 2024
1 parent 6a2f212 commit 0feccb8
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 3 deletions.
21 changes: 21 additions & 0 deletions shark_turbine/ops/iree.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import cast

from ..support.ir_imports import (
Operation,
RankedTensorType,
StringAttr,
Value,
Expand Down Expand Up @@ -60,3 +61,23 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
key = cast(AttrArg, ksel.arg_descs[0])
_emit_tensor_trace(kb, cast(str, key.v), [kb.arg_bindings[1]])
kb.yield_results(kb.arg_bindings[1])


@CustomOp.register(library=IREE_LIBRARY)
class _test_add(CustomOp):
signature = "_test_add(Tensor t1, Tensor t2) -> (Tensor)"

def select(self, ksel: KernelSelection):
t1_desc = ksel.arg_tensor(0)
t1_desc.specialize_all_dims()
t2_desc = ksel.arg_tensor(1)
t2_desc.specialize_all_dims()
result_desc = ksel.return_new_tensor(t1_desc.t.shape, t1_desc.t.dtype)
result_desc.specialize_all_dims()

def generate(self, ksel: KernelSelection, kb: KernelBuilder):
t1, t2 = kb.arg_bindings
result = Operation.create(
"tosa.add", results=[t1.type], operands=[t1, t2]
).result
kb.yield_results(result)
4 changes: 3 additions & 1 deletion shark_turbine/runtime/op_reg/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ def compile_standalone_kernel(
with kb.ip, Location.unknown():
ksel.op.generate(ksel, kb)
kb.module_op.verify()
# DO NOT SUBMIT: https://github.com/iree-org/iree/issues/17132
enable_debug_info = False
module_asm = kb.module_op.get_asm(
binary=True, enable_debug_info=True, print_generic_op_form=False
binary=True, enable_debug_info=enable_debug_info, print_generic_op_form=False
)
generation_time = default_timer() - start

Expand Down
4 changes: 2 additions & 2 deletions tests/runtime/device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,10 @@ def testFromTorchDevice(self):
print(device.dump_device_info())

def testJit(self):
t = torch.tensor([1, 2, 3, 4, 5]).to("cuda:0")
from shark_turbine.ops import iree as iree_ops

iree_ops.trace_tensor("FOO", t)
t = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]).to("cuda:0")
print(iree_ops._test_add(t, t))


if __name__ == "__main__":
Expand Down

0 comments on commit 0feccb8

Please sign in to comment.