[TKW] Implement broadcastOp class, lowering and insertion (#176)
The motivation of this PR is to be able to codegen/lower broadcast properly.
With that in mind, we implemented the following:

1. A BroadcastOp class, op definition, and lowering to represent and store
broadcasting information, mainly so that we can query the target shape and
the source operand of the broadcast.
2. Treat broadcast-add as an index conflict during thread-shape analysis and
handle it by emitting a BroadcastOp (a minimal sketch follows).
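
A minimal sketch of the pattern this enables, excerpted from the kernel body of
the lit test added below (all names come from that test):

    lhs = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)  # [M, N] operand
    rhs = tkw.read(b, elements_per_thread=1)                      # [M] operand
    res = lhs + rhs  # rhs is broadcast to [M, N] via the new BroadcastOp
    tkw.write(res, c, elements_per_thread=STORE_ELEMS_PER_THREAD)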

---------

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
raikonenfnu authored Oct 4, 2024
1 parent 64b7d27 commit 7617c94
Showing 4 changed files with 208 additions and 11 deletions.
72 changes: 72 additions & 0 deletions lit_tests/kernel/wave/codegen.py
@@ -1196,6 +1196,78 @@ def repeat(
# CHECK: scf.yield %[[ACC_MAX_0]], %[[ACC_SUM_0]], %[[ACC_MAX_1]], %[[ACC_SUM_1]]


@run_test
def test_broadcast_add():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(1, 1, 1),
vector_shapes={M: 1, N: BLOCK_N},
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

@tkw.wave(constraints)
def test(
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
):
lhs = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
rhs = tkw.read(b, elements_per_thread=1)
res = lhs + rhs
tkw.write(res, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

config = {"backend": "rocm", "device": "hip", "target": "gfx942"}

shape = (256, 128)
a = torch.ones(shape, dtype=torch.float16)
b = torch.ones(shape[0], dtype=torch.float16)
c = torch.zeros(shape, dtype=torch.float16)
with tk.gen.TestLaunchContext(
{
M: shape[0],
N: shape[1],
BLOCK_M: 2,
BLOCK_N: 128,
LOAD_ELEMS_PER_THREAD: 2,
STORE_ELEMS_PER_THREAD: 2,
ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
},
canonicalize=True,
run=False,
run_config=config,
):
print(test(a, b, c).module_op)
# CHECK: func.func @test(%[[ARG0:.+]]: !stream.binding, %[[ARG1:.+]]: !stream.binding, %{{.+}}: !stream.binding)
# CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
# CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index

# Slicing LHS
# CHECK: %[[LHS:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<256x128xf16
# CHECK: %[[LHS_0:.+]] = vector.load %[[LHS]][%[[X_SLICE_0:.+]], %[[Y_SLICE:.+]]] : memref<256x128xf16, strided<[128, 1], offset: ?>>, vector<2xf16>
# CHECK: %[[X_SLICE_1:.+]] = arith.addi %[[X_SLICE_0]], %c1 : index
# CHECK: %[[LHS_1:.+]] = vector.load %[[LHS]][%[[X_SLICE_1]], %[[Y_SLICE]]] : memref<256x128xf16, strided<[128, 1], offset: ?>>, vector<2xf16>

# Slicing RHS
# CHECK: %[[RHS:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<256xf16
# CHECK: %[[RHS_0:.+]] = vector.load %[[RHS]][%[[X_SLICE_0]]] : memref<256xf16, strided<[1], offset: ?>>, vector<1xf16>
# CHECK: %[[RHS_1:.+]] = vector.load %[[RHS]][%[[X_SLICE_1]]] : memref<256xf16, strided<[1], offset: ?>>, vector<1xf16>

# 1st Broadcast-ADD RHS
# CHECK: %[[EXTRACT_0:.+]] = vector.extract %[[RHS_0]][0] : f16 from vector<1xf16>
# CHECK: %[[BCAST_RHS_0:.+]] = vector.splat %[[EXTRACT_0]] : vector<2xf16>
# CHECK: arith.addf %[[LHS_0]], %[[BCAST_RHS_0]] : vector<2xf16>

# 2nd Broadcast-ADD RHS
# CHECK: %[[EXTRACT_1:.+]] = vector.extract %[[RHS_1]][0] : f16 from vector<1xf16>
# CHECK: %[[BCAST_RHS_1:.+]] = vector.splat %[[EXTRACT_1]] : vector<2xf16>
# CHECK: arith.addf %[[LHS_1]], %[[BCAST_RHS_1]] : vector<2xf16>


@run_test
def test_binary_lowerings():
constraints: list[tkw.Constraint] = [
61 changes: 55 additions & 6 deletions shark_turbine/kernel/ops/wave_ops.py
@@ -96,6 +96,12 @@ def maximum(lhs: "Register", rhs: "Register") -> "Register":
...


def broadcast(
arg: "Register", target_shape: Optional[IndexExpr | int] = None
) -> "Register":
...


def sum(
src: "Register",
acc: Optional["Register"] = None,
@@ -496,7 +502,12 @@ def post_expansion(self, constraints: list["Constraint"]) -> None:
@dataclass
class BinaryPyOp(CustomOp, ABC):
"""
Represents a binary python operator.
Represents an elementwise binary python operator.
    DTYPE requirement: lhs and rhs need to have the same dtype.
    Shape requirement: lhs and rhs either have the same shape or
                       their shapes must be broadcastable to
                       one another.
"""

lhs: Any
@@ -522,9 +533,16 @@ def type(self) -> Memory:
lhs_type = get_custom(self.lhs).type
rhs_type = get_custom(self.rhs).type
has_same_type = has_same_custom_type(lhs_type, rhs_type)
if not has_same_type:
raise ValueError("Expected lhs and rhs to have same type post-expansion")
return lhs_type
if has_same_type:
return lhs_type
lhs_dim_set = set(lhs_type.symbolic_shape)
rhs_dim_set = set(rhs_type.symbolic_shape)
if lhs_dim_set.isdisjoint(rhs_dim_set):
raise ValueError(
"BinaryPyOp requires lhs and rhs shape to be at least broadcastable."
)
        broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhs_type
return broadcasted_type
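
For intuition, a standalone sketch of the shape rule above, with plain Python
tuples and sets standing in for symbolic shapes (the helper name is
illustrative, not part of the API):

    def pick_broadcast_shape(lhs_shape: tuple, rhs_shape: tuple) -> tuple:
        # Mirrors BinaryPyOp.type: equal shapes pass through, disjoint shapes
        # are rejected, otherwise the operand whose dims are a superset wins.
        lhs_dims, rhs_dims = set(lhs_shape), set(rhs_shape)
        if lhs_dims == rhs_dims:
            return lhs_shape
        if lhs_dims.isdisjoint(rhs_dims):
            raise ValueError("shapes must be broadcastable to one another")
        return lhs_shape if lhs_dims > rhs_dims else rhs_shape

    assert pick_broadcast_shape(("M", "N"), ("M",)) == ("M", "N")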


@define_interface_op("exp2")
@@ -899,8 +917,11 @@ def captured_vars(self, graph: fx.Graph) -> list[fx.Node]:
return captured_vars

@property
def type(self) -> list[Memory | Register]:
return [get_custom(x).type for x in self.init_args]
def type(self) -> Memory | Register | list[Memory | Register]:
res_types = [get_custom(x).type for x in self.init_args]
if len(res_types) == 1:
res_types = res_types[0]
return res_types

def outputs(self, graph: fx.Graph) -> list[fx.Node]:
for node in graph.nodes:
@@ -1022,6 +1043,34 @@ def type(self) -> "Register":
return get_custom(self.register_).type


@define_op("broadcast")
@dataclass
class Broadcast(CustomOp, ABC):
"""
Represents a Broadcast operation.
    arg: source tensor/value to broadcast.
    target_type: target type whose symbolic_shape is the broadcast shape.
"""

arg: fx.Node
target_type: Sequence[IndexSymbol] = None

@property
def target_shape(self):
return self.target_type.symbolic_shape

@property
def indexing_dims(self) -> list[IndexSymbol]:
return self.target_shape

@property
def type(self) -> Memory:
src_dtype = get_custom(self.arg).type.dtype
dst_type = Register[*self.target_shape, src_dtype]
return dst_type


@define_interface_op("max")
@define_interface_op("sum")
@dataclass
39 changes: 38 additions & 1 deletion shark_turbine/kernel/wave/codegen.py
@@ -46,6 +46,7 @@
from shark_turbine.kernel.lang.global_symbols import *
from ..ops.wave_ops import (
write,
broadcast,
register,
mma,
shuffle,
@@ -79,7 +80,7 @@
WorkgroupConstraint,
TilingConstraint,
)
from .utils import subs_idxc, find_index_bounds
from .utils import subs_idxc, find_index_bounds, get_hardware_vector_map

# Indexing imports.
from .._support.indexing import IndexingContext, IndexExpr, IndexSequence
@@ -1095,6 +1096,42 @@ def handle_extract_slice(emitter: WaveEmitter, node: fx.Node):
emitter.bind_node_proxy(node, IRProxyValue(element))


###############################################################################
# Reshape ops
###############################################################################


@handle_op(broadcast)
def handle_broadcast(emitter: WaveEmitter, node: fx.Node):
try:
register, target_type = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e

# Get thread_shape/size for broadcast.
get_thread_shape = lambda index: max(x.size for x in index.values())
bcast_dim_lane_dim_size = get_thread_shape(node.index)

# Check MLIR shape
vector_src = cast_vector(emitter, register)
vector_type = vector_src.type
# Only support broadcasting vector<1xdtype> for now.
if not VectorType.isinstance(vector_type):
        raise NotImplementedError("Scalar src is not implemented yet for broadcastOp.")
assert vector_type.rank == 1
assert vector_type.shape[0] == 1

# Extract and Splat
# If by chance broadcast size matches current size, we can return src.
    if bcast_dim_lane_dim_size == vector_type.shape[0]:
        emitter.bind_node_proxy(node, IRProxyValue(vector_src))
        return

result_type = VectorType.get([bcast_dim_lane_dim_size], vector_type.element_type)
element = vector_d.extract(vector_src, static_position=[0], dynamic_position=[])
splat = vector_d.splat(result_type, element)
emitter.bind_node_proxy(node, IRProxyValue(splat))
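
For intuition, a NumPy analogue of the extract-and-splat sequence emitted above
(illustrative only; the actual lowering emits vector.extract followed by
vector.splat, as checked in the lit test):

    import numpy as np

    src = np.array([3.0], dtype=np.float16)  # stands in for a vector<1xf16> source
    bcast_dim_lane_dim_size = 2
    element = src[0]                          # vector.extract %src[0]
    splat = np.full(bcast_dim_lane_dim_size, element, dtype=src.dtype)  # vector.splat
    assert splat.shape == (2,) and (splat == element).all()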


###############################################################################
# Miscellanous ops
###############################################################################
47 changes: 43 additions & 4 deletions shark_turbine/kernel/wave/thread_shape_analysis.py
@@ -9,7 +9,7 @@
import torch.fx as fx
from ..ops.wave_ops import *
from ..lang.global_symbols import *
from .utils import capture_forward_slice, capture_backward_slice
from .utils import capture_forward_slice, capture_backward_slice, subs_idxc

logger = get_logger("turbine.wave.thread_shape_analysis")

@@ -24,7 +24,9 @@ def __hash__(self):


def get_dim_sizes(indices: list[IndexSequence]):
dims = frozenset([DimSize(dim, seq.size) for dim, seq in indices.items()])
dims = frozenset(
[DimSize(dim, subs_idxc(seq.size)) for dim, seq in indices.items()]
)
return dims


Expand All @@ -41,6 +43,39 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]):
custom.index[target.dim].size = target.size


def handle_binaryop_conflict(custom_node: CustomOp):
# Analyze if we can resolve conflict with broadcast.
lhs = get_custom(custom_node.lhs)
rhs = get_custom(custom_node.rhs)
lhs_dim_set = set(lhs.type.symbolic_shape)
rhs_dim_set = set(rhs.type.symbolic_shape)
    if lhs_dim_set == rhs_dim_set:
        raise ValueError("Cannot broadcast if lhs and rhs are already the same.")
    if lhs_dim_set.isdisjoint(rhs_dim_set):
        raise ValueError("Cannot broadcast if lhs and rhs have disjoint shapes.")
# Determine the correct indexSize for binaryOp and insert broadcasting.
dst_op = lhs if lhs_dim_set > rhs_dim_set else rhs
broadcast_idx, broadcast_src = (1, rhs) if lhs_dim_set > rhs_dim_set else (0, lhs)
broadcast = Broadcast(broadcast_src.fx_node, dst_op.type)
with custom_node.graph.inserting_before(custom_node.fx_node):
broadcast.add_to_graph(custom_node.graph)
setattr(broadcast.fx_node, "index", dst_op.index)
custom_node.index = dst_op.index
custom_node.update_arg(broadcast_idx, broadcast.fx_node)
return True
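
A tiny standalone illustration of the operand selection above, with plain sets
standing in for the operands' symbolic shapes:

    # The operand whose dim set is the strict subset is the one that gets
    # broadcast; broadcast_idx is its argument index on the binary op.
    lhs_dims, rhs_dims = {"M", "N"}, {"M"}
    broadcast_idx, broadcast_src = (1, "rhs") if lhs_dims > rhs_dims else (0, "lhs")
    assert (broadcast_idx, broadcast_src) == (1, "rhs")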


# Returns True iff all conflicts are handled successfully.
def handle_conflicts(conflicted_ops: set[CustomOp]):
for conflict in conflicted_ops:
custom = get_custom(conflict)
if isinstance(custom, BinaryPyOp):
handle_binaryop_conflict(custom)
else:
return False
return True


def determine_thread_shapes(trace: CapturedTrace):
"""
    This function does analysis and propagation of thread shape. It does so as follows:
@@ -133,10 +168,14 @@ def propagatable_op(node: fx.Node):
    # Go through each index-size bucket, and apply the index-size to ops in the bucket.
cummulative_set = set()
for target_index_size, target_ops in thread_size_to_ops.items():
# Ensure that we do not have any conflicts.
# Try to handle conflicts and remove from target set if successfully handled.
if not cummulative_set.isdisjoint(target_ops):
raise NotImplementedError("NYI: Handling of conflicting thread shape.")
conflicted_ops = cummulative_set.intersection(target_ops)
            if not handle_conflicts(conflicted_ops):
raise NotImplementedError("Failed to handle conflicting thread shape.")
target_ops = target_ops.difference(conflicted_ops)
cummulative_set = cummulative_set.union(target_ops)
# Set target ops's indexSize to be the determined from analysis.
for user in target_ops:
custom_user = get_custom(user)
set_index_size(custom_user, target_index_size)
