[TKW] Implement broadcastOp class, lowering and insertion
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
raikonenfnu committed Sep 27, 2024
1 parent 192a786 commit ab09ba6
Showing 5 changed files with 238 additions and 2 deletions.
72 changes: 72 additions & 0 deletions lit_tests/kernel/wave/codegen.py
@@ -1051,6 +1051,78 @@ def repeat(
# CHECK: scf.yield %[[ACC_MAX_0]], %[[ACC_SUM_0]], %[[ACC_MAX_1]], %[[ACC_SUM_1]]


@run_test
def test_broadcast_add():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(1, 1, 1),
vector_shapes={M: 1, N: BLOCK_N},
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

@tkw.wave(constraints)
def test(
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
):
lhs = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
rhs = tkw.read(b, elements_per_thread=1)
res = lhs + rhs
tkw.write(res, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

config = {"backend": "rocm", "device": "hip", "target": "gfx942"}

shape = (256, 128)
a = torch.ones(shape, dtype=torch.float16)
b = torch.ones(shape[0], dtype=torch.float16)
c = torch.zeros(shape, dtype=torch.float16)
with tk.gen.TestLaunchContext(
{
M: shape[0],
N: shape[1],
BLOCK_M: 2,
BLOCK_N: 128,
LOAD_ELEMS_PER_THREAD: 2,
STORE_ELEMS_PER_THREAD: 2,
ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
},
canonicalize=True,
run=False,
run_config=config,
):
print(test(a, b, c).module_op)
# CHECK: func.func @test(%[[ARG0:.+]]: !stream.binding, %[[ARG1:.+]]: !stream.binding, %{{.+}}: !stream.binding)
# CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
# CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index

# Slicing LHS
# CHECK: %[[LHS:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<256x128xf16
# CHECK: %[[LHS_0:.+]] = vector.load %[[LHS]][%[[X_SLICE_0:.+]], %[[Y_SLICE:.+]]] : memref<256x128xf16, strided<[128, 1], offset: ?>>, vector<2xf16>
# CHECK: %[[X_SLICE_1:.+]] = arith.addi %[[X_SLICE_0]], %c1 : index
# CHECK: %[[LHS_1:.+]] = vector.load %[[LHS]][%[[X_SLICE_1]], %[[Y_SLICE]]] : memref<256x128xf16, strided<[128, 1], offset: ?>>, vector<2xf16>

# Slicing RHS
# CHECK: %[[RHS:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<256xf16
# CHECK: %[[RHS_0:.+]] = vector.load %[[RHS]][%[[X_SLICE_0]]] : memref<256xf16, strided<[1], offset: ?>>, vector<1xf16>
# CHECK: %[[RHS_1:.+]] = vector.load %[[RHS]][%[[X_SLICE_1]]] : memref<256xf16, strided<[1], offset: ?>>, vector<1xf16>

# Broadcasting RHS
# CHECK: %[[EXTRACT_0:.+]] = vector.extract %[[RHS_0]][0] : f16 from vector<1xf16>
# CHECK: %[[BCAST_RHS_0:.+]] = vector.splat %[[EXTRACT_0]] : vector<2xf16>
# CHECK: %[[EXTRACT_1:.+]] = vector.extract %[[RHS_1]][0] : f16 from vector<1xf16>
# CHECK: %[[BCAST_RHS_1:.+]] = vector.splat %[[EXTRACT_1]] : vector<2xf16>

# Adding
# CHECK: arith.addf %[[LHS_0]], %[[BCAST_RHS_0]] : vector<2xf16>
# CHECK: arith.addf %[[LHS_1]], %[[BCAST_RHS_1]] : vector<2xf16>
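
    # Reference for the expected numerics (a sketch only; this test is
    # generated with run=False, so nothing below is executed by lit): the
    # kernel adds b broadcast along N, so with the all-ones inputs above,
    #
    #   expected = a + b.unsqueeze(1)   # broadcast b over N
    #   assert torch.all(expected == 2.0)
    #
    # i.e. every output element would be 2.0.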


@run_test
def test_binary_lowerings():
constraints: list[tkw.Constraint] = [
39 changes: 38 additions & 1 deletion shark_turbine/kernel/ops/wave_ops.py
@@ -96,6 +96,12 @@ def maximum(lhs: "Register", rhs: "Register") -> "Register":
...


def broadcast(
arg: "Register", target_shape: Optional[IndexExpr | int] = None
) -> "Register":
...


def sum(
src: "Register",
acc: Optional["Register"] = None,
@@ -890,7 +896,10 @@ def captured_vars(self, graph: fx.Graph) -> list[fx.Node]:

@property
def type(self) -> list[Memory | Register]:
return [get_custom(x).type for x in self.init_args]
res_types = [get_custom(x).type for x in self.init_args]
if len(res_types) == 1:
res_types = res_types[0]
return res_types

def outputs(self, graph: fx.Graph) -> list[fx.Node]:
for node in graph.nodes:
@@ -998,6 +1007,34 @@ def type(self) -> "Register":
return get_custom(self.register_).type


@define_op("broadcast")
@dataclass
class Broadcast(CustomOp, ABC):
"""
Represents a Broadcast operation.
    arg: source register/value to broadcast.
    target_type: type whose symbolic shape is the broadcast target.
"""

arg: fx.Node
    target_type: Optional[Memory | Register] = None

@property
def target_shape(self):
return self.target_type.symbolic_shape

@property
def indexing_dims(self) -> list[IndexSymbol]:
return self.target_shape

@property
def type(self) -> Memory:
src_dtype = get_custom(self.arg).type.dtype
dst_type = Register[*self.target_shape, src_dtype]
return dst_type
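
A worked sketch of how these properties resolve, taking the lit test above as the assumed context (arg is the read of b with symbolic shape (M,) in f16, and target_type is the lhs operand's type with symbolic shape (M, N)):

    # target_shape   -> (M, N)
    # indexing_dims  -> (M, N)
    # type           -> Register[M, N, f16]   (target shape, source dtype)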


@define_interface_op("max")
@define_interface_op("sum")
@dataclass
76 changes: 75 additions & 1 deletion shark_turbine/kernel/wave/codegen.py
@@ -44,6 +44,7 @@
from shark_turbine.kernel.lang.global_symbols import *
from ..ops.wave_ops import (
write,
broadcast,
register,
mma,
shuffle,
@@ -77,7 +78,7 @@
WorkgroupConstraint,
TilingConstraint,
)
from .utils import subs_idxc
from .utils import subs_idxc, get_hardware_vector_map

# Indexing imports.
from .._support.indexing import IndexingContext, IndexExpr, IndexSequence
@@ -995,6 +996,79 @@ def handle_extract_slice(emitter: WaveEmitter, node: fx.Node):
emitter.bind_node_proxy(node, IRProxyValue(element))


###############################################################################
# Reshape ops
###############################################################################


@handle_op(broadcast)
def handle_broadcast(emitter: WaveEmitter, node: fx.Node):
try:
register, target_type = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e
custom_arg = get_custom(register)
    src_shape = custom_arg.type.symbolic_shape
target_shape = target_type.symbolic_shape
hw_constraints = [
constraint
for constraint in emitter.constraints
if isinstance(constraint, HardwareConstraint)
]
if len(hw_constraints) != 1:
raise NotImplementedError(
"Only support single HW Constraint for lowering broadcast."
)
hw_constraint = hw_constraints[0]
if not hasattr(hw_constraint, "vector_shapes"):
raise NotImplementedError(
"Only support broadcast for kernel with non-MMA constraints."
)

    # At the time of writing, TKW only expects 1-D vectors (i.e., at most one
    # non-unit dim in vector_shapes), so the N-D broadcast case is not handled here.
broadcast_dims = set(target_shape) - set(src_shape)
if len(broadcast_dims) != 1:
raise NotImplementedError("NYI: Support for multiple broadcasting dims.")
broadcast_dim = next(iter(broadcast_dims))

# Extract vector shape of the broadcasted dimension from vector_shape.
vector_map = hw_constraint.vector_shapes
bcast_dim_wave_size = vector_map[broadcast_dim]
if bcast_dim_wave_size % hw_constraint.threads_per_wave != 0:
raise NotImplementedError(
"Only handle when vector_shape is factor of warp size."
)
bcast_dim_lane_dim_size = bcast_dim_wave_size // hw_constraint.threads_per_wave

    # Ensure every non-broadcast dim in the source has a unit vector_shape.
    src_dims = custom_arg.indexing_dims
if any(vector_map.get(src_dim, 0) != 1 for src_dim in src_dims):
raise NotImplementedError(
"Cannot handle non broadcasted dim being non unit-dim in vector_shape constraints."
)

# Check MLIR shape
vector_src = cast_vector(emitter, register)
vector_type = vector_src.type
# Only support broadcasting vector<1xdtype> for now.
if not VectorType.isinstance(vector_type):
raise NotImplementedError("Scalar src is not implemented yet for shuffleOp.")
assert vector_type.rank == 1
assert vector_type.shape[0] == 1

    # Extract and splat.
    # If the broadcast size already matches the source size, we can reuse src.
    if bcast_dim_lane_dim_size == vector_type.shape[0]:
        emitter.bind_node_proxy(node, IRProxyValue(vector_src))
        return

result_type = VectorType.get([bcast_dim_lane_dim_size], vector_type.element_type)
element = vector_d.extract(vector_src, static_position=[0], dynamic_position=[])
splat = vector_d.splat(result_type, element)
emitter.bind_node_proxy(node, IRProxyValue(splat))


###############################################################################
# Miscellanous ops
###############################################################################
49 changes: 49 additions & 0 deletions shark_turbine/kernel/wave/insert_broadcast.py
@@ -0,0 +1,49 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ...support.logging import get_logger
from shark_turbine.kernel._support.tracing import CapturedTrace
import torch.fx as fx
from ..ops.wave_ops import *
from ..lang.global_symbols import *

logger = get_logger("turbine.wave.insert_broadcast")


def get_allocs(graph: fx.Graph) -> list[CustomOp]:
return [
custom_node
for node in graph.nodes
if isinstance((custom_node := get_custom(node)), Allocate)
]


def insert_broadcast(trace: CapturedTrace):
"""Insert broadcasts to binary ops operands that requires it."""

def is_binary_op(node: fx.Node) -> bool:
return isinstance(get_custom(node), BinaryPyOp)

binary_nodes = trace.walk(is_binary_op)

for node in binary_nodes:
custom_node = get_custom(node)
lhs = get_custom(custom_node.lhs)
rhs = get_custom(custom_node.rhs)
lhs_dim_set = set(lhs.type.symbolic_shape)
rhs_dim_set = set(rhs.type.symbolic_shape)
if lhs_dim_set == rhs_dim_set:
continue
if lhs_dim_set.isdisjoint(rhs_dim_set):
            raise ValueError("Cannot broadcast when lhs and rhs have disjoint shapes.")
        target_type = lhs.type if lhs_dim_set > rhs_dim_set else rhs.type
        broadcast_idx, broadcast_src = (
            (1, rhs) if lhs_dim_set > rhs_dim_set else (0, lhs)
        )
        broadcast = Broadcast(broadcast_src.fx_node, target_type)
with custom_node.graph.inserting_before(custom_node.fx_node):
broadcast.add_to_graph(custom_node.graph)
custom_node.update_arg(broadcast_idx, broadcast.fx_node)
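
For the kernel in the lit test above, the effect of this pass is roughly the following (an illustrative sketch of the traced graph, not generated output):

    # before:  add(read(a) : (M, N), read(b) : (M,))
    # after:   add(read(a) : (M, N), broadcast(read(b), target_type=lhs.type) : (M, N))

Since rhs's dims {M} are a strict subset of lhs's {M, N}, the rhs operand (index 1) is wrapped in a Broadcast whose target type is taken from lhs.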
4 changes: 4 additions & 0 deletions shark_turbine/kernel/wave/wave.py
@@ -29,6 +29,7 @@
safe_subs,
remove_chained_getresult,
)
from .insert_broadcast import insert_broadcast
from .minimize_global_loads import minimize_global_loads
from .decompose_reduce_ops import decompose_reduce_ops
from .barriers import add_shared_memory_barriers
@@ -207,6 +208,9 @@ def _trace_and_get_kernel_signature(
promote_placeholders(graph, self.constraints)
hoist_allocs(graph)

        # Insert broadcasts for implicitly broadcast binary ops.
insert_broadcast(graph)

# Expansion
expand_graph(graph, self.constraints)

