Add flash decode v2 (#340)
harsh-nod authored Dec 20, 2024
1 parent ef8142e commit 20507b7
Showing 10 changed files with 565 additions and 328 deletions.
5 changes: 5 additions & 0 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -102,6 +102,10 @@ def exp2(src: "Register") -> "Register":
...


def log2(src: "Register") -> "Register":
...


def reciprocal(src: "Register") -> "Register":
...

@@ -653,6 +657,7 @@ def infer_type(self):
self.type = broadcasted_type


@define_interface_op("log2")
@define_interface_op("exp2")
@define_interface_op("reciprocal")
@define_interface_op("abs")
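
Note: the new log2 stub mirrors the existing exp2 stub: the bare function exists only so kernels can be traced, while the define_interface_op stack wires each registered name to the shared op definition (including infer_type). A minimal sketch of that stub-plus-registry pattern, with an illustrative registry that is not part of this repository:

import functools
from typing import Callable

OP_REGISTRY: dict[str, Callable] = {}

def define_interface_op(name: str):
    # Register the implementation under `name`, pre-binding the op name.
    def decorator(impl: Callable) -> Callable:
        OP_REGISTRY[name] = functools.partial(impl, name)
        return impl
    return decorator

@define_interface_op("log2")
@define_interface_op("exp2")
def elementwise_unary(op_name: str, src: str) -> str:
    # The real op emits a node into the traced graph; a string suffices here.
    return f"{op_name}({src})"

print(OP_REGISTRY["log2"]("x"))  # log2(x)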
17 changes: 16 additions & 1 deletion iree/turbine/kernel/wave/codegen.py
@@ -14,6 +14,8 @@
import torch.utils._pytree as pytree
from collections import namedtuple

from .symbolic_constraints import SymbolicAlias

from ..compiler.ir import (
Attribute,
DenseElementsAttr,
@@ -55,6 +57,7 @@
read,
reduction,
exp2,
log2,
reciprocal,
abs,
maximum,
@@ -491,7 +494,7 @@ def handle_register(emitter: WaveEmitter, node: fx.Node):
shape, dtype, value = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e
get_thread_shape = lambda index: max(x.size for x in index.values())
get_thread_shape = lambda index: max(subs_idxc(x.size) for x in index.values())
shape = [get_thread_shape(get_custom(node).index)]
vector_shape = cast_py_literal(emitter, shape)
element_type = IrType.parse(dtype.ir_type_asm())
@@ -1120,6 +1123,16 @@ def handle_exp2(source: Value) -> OpResult:
return result


@handle_unary_op(log2)
def handle_log2(source: Value) -> OpResult:
element_type = get_type_or_element_type(source.type)
if _is_float_type(element_type):
result = math_d.log2(source)
else:
raise ValidationError(f"Found unhandled operand type for log2: {element_type}")
return result


@handle_unary_op(reciprocal)
def handle_reciprocal(source: Value) -> OpResult:
element_type = get_type_or_element_type(source.type)
@@ -1351,6 +1364,8 @@ def handle_get_result(emitter: WaveEmitter, node: fx.Node):

@handle_op(operator.getitem)
def handle_getitem(emitter: WaveEmitter, node: fx.Node):
if not node.users:
return
raise NotImplementedError("getitem: Currently only stub implementation")


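
Note: handle_log2 lowers to math_d.log2 only for float element types, matching handle_exp2, and handle_getitem now returns early for getitem nodes with no users (dead results) instead of reaching the NotImplementedError. The early return relies on ordinary torch.fx bookkeeping; a self-contained sketch of the same check on a hand-built graph (the graph below is purely illustrative):

import operator
import torch.fx as fx

graph = fx.Graph()
x = graph.placeholder("x")
pair = graph.call_function(divmod, (x, 3))                # produces a 2-tuple
quot = graph.call_function(operator.getitem, (pair, 0))   # consumed by the output
rem = graph.call_function(operator.getitem, (pair, 1))    # never consumed
graph.output(quot)

for node in graph.nodes:
    if node.target is operator.getitem and not node.users:
        # Mirrors the new early return: dead getitem nodes are skipped, not lowered.
        print("skipping dead getitem:", node.format_node())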
12 changes: 11 additions & 1 deletion iree/turbine/kernel/wave/expansion.py
@@ -9,6 +9,8 @@
from typing import Any, TypeAlias, Sequence, Type, Callable
from functools import partial

from .symbolic_constraints import SymbolicAlias

from .constraints import (
Constraint,
HardwareConstraint,
@@ -17,6 +19,7 @@
)
from ..ops.wave_ops import (
Allocate,
BinaryPyOp,
CustomOp,
GetResult,
Getitem,
@@ -647,11 +650,19 @@ def get_dim_scaling(
if len(hardware_constraints) != 1:
raise ValueError("Exactly one hardware constraint must be provided")

aliased_dims: list[IndexSymbol] = [
constraint.source
for constraint in constraints
if isinstance(constraint, SymbolicAlias)
]

idxc = IndexingContext.current()
for constraint in constraints:
if isinstance(constraint, WorkgroupConstraint) or isinstance(
constraint, TilingConstraint
):
if constraint.dim in aliased_dims:
continue
hw_cons = hardware_constraints[0]
tile_size = idxc.get_static_value(constraint.tile_size)
if constraint.dim not in node.vector_shapes:
@@ -740,7 +751,6 @@ def _handle_reduction_dim(
):
# Rediscover iter args
# TODO: Register iter args with the reduction initially so accessing them is easier
iter_args: list[CustomOp] = []
reduction_subgraph = trace.get_subgraph(reduction.subgraph_name)

# TODO: Handle case where MMAs/ReduceOps do not have Output as direct consumer.
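
Note: get_dim_scaling now skips workgroup and tiling constraints whose dim is the source of a SymbolicAlias, since an aliased dimension takes its scaling from the dimension it is derived from rather than from its own constraint. A toy sketch of that filtering, where the dataclasses are illustrative stand-ins for the real constraint classes and K_SPLIT/K are hypothetical symbols:

from dataclasses import dataclass

@dataclass
class WorkgroupConstraint:
    dim: str
    tile_size: int

@dataclass
class SymbolicAlias:
    source: str  # derived symbol, e.g. a per-split slice of K
    target: str  # the symbol it is derived from

constraints = [
    WorkgroupConstraint("M", 64),
    WorkgroupConstraint("K_SPLIT", 32),
    SymbolicAlias(source="K_SPLIT", target="K"),
]

aliased_dims = [c.source for c in constraints if isinstance(c, SymbolicAlias)]
for c in constraints:
    if isinstance(c, WorkgroupConstraint):
        if c.dim in aliased_dims:
            continue  # K_SPLIT inherits its scaling from K, not from this constraint
        print(f"dim {c.dim} scaled by its own tile size {c.tile_size}")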
121 changes: 104 additions & 17 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -19,9 +19,15 @@
IterArg,
CustomOp,
Reshape,
Output,
)
from .constraints import (
Constraint,
HardwareConstraint,
WorkgroupConstraint,
)
from .constraints import Constraint, HardwareConstraint, WorkgroupConstraint
from .assumptions import Assumption
from .symbolic_constraints import SymbolicAlias
from .._support.tracing import CapturedTrace, IndexingContext
from .._support.indexing import IndexSymbol, IndexSequence
from ..lang.global_symbols import *
@@ -34,7 +40,6 @@
get_inputs,
get_users,
get_largest_index_and_size,
capture_backward_slice,
)
import torch.fx as fx
import numpy as np
@@ -57,6 +62,20 @@ def get_vector_shape(
return vector_shapes


def _get_symbolic_shape_and_vector_shapes(
custom: CustomOp,
aliases: dict[IndexSymbol, SymbolicAlias],
hw_constraint: HardwareConstraint,
):
# When the memory type has symbolic aliases, use the memory type
# as it includes the aliased variables.
symbolic_shape = custom.register_type.symbolic_shape
vector_shapes = custom.vector_shapes
if any([x in custom.memory_type.symbolic_shape for x in aliases]):
symbolic_shape = custom.memory_type.symbolic_shape
return symbolic_shape, vector_shapes


def partition_strided_operators(trace: CapturedTrace, constraints: list[Constraint]):
"""
This function analyzes the index sequence of operators in the graph
@@ -92,16 +111,19 @@ def has_strided_access(node: fx.Node) -> bool:

strided_operators = trace.walk(has_strided_access)
hw_constraint = [c for c in constraints if isinstance(c, HardwareConstraint)][0]
aliases = {c.source: c for c in constraints if isinstance(c, SymbolicAlias)}
for operator in strided_operators:
custom = get_custom(operator)
simplified_index = {
dim: simplify_index(custom.register_index.get(dim, custom.index[dim]))
for dim in custom.index
}

shape = get_vector_shape(
custom.vector_shapes, custom.register_type.symbolic_shape
symbolic_shape, vector_shapes = _get_symbolic_shape_and_vector_shapes(
custom, aliases, hw_constraint
)

shape = get_vector_shape(vector_shapes, symbolic_shape)
elements_per_thread = subs_idxc(custom.elements_per_thread)
max_stride_dim, max_stride = max(
[(dim, seq.stride) for dim, seq in simplified_index.items()],
@@ -126,7 +148,7 @@ def has_strided_access(node: fx.Node) -> bool:
dim: IndexSequence(
simplified_index[dim].start.subs({GPR_NUM: 0}) + offset[j], 1, 1
)
for j, dim in enumerate(custom.register_type.symbolic_shape)
for j, dim in enumerate(symbolic_shape)
}
ops_to_combine.append(write)

@@ -312,12 +334,30 @@ def set_derived_index(trace):
worklist.append((inp, new_index))


def verify_nodes(trace: CapturedTrace):
"""
Verify that all the valid nodes have their index and vector shapes set.
"""
nodes = trace.walk(lambda x: x)
for node in nodes:
custom = get_custom(node)
if isinstance(custom, (Placeholder, Allocate)) and not isinstance(
custom, IterArg
):
continue
if isinstance(custom, (Output, Reduction)):
continue
assert custom.index, f"Index not set for node {custom.fx_node}"
assert custom.vector_shapes, f"Vector shapes not set for node {custom.fx_node}"


def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]):
mma_index = get_mma_dimensional_mapping(trace, get_hardware_constraint(constraints))
trace.walk(partial(set_thread_independent_index, constraints))
set_thread_dependent_index(constraints, mma_index, trace)
set_derived_index(trace)
resolve_thread_shapes(trace, constraints)
verify_nodes(trace)


def compute_stride(
@@ -417,8 +457,11 @@ def set_thread_independent_index(
if isinstance(custom, (Reduction, Placeholder)) and not isinstance(custom, IterArg):
return

hw_cons = get_hardware_constraint(constraints)
constraints = [
c for c in constraints if not isinstance(c, (HardwareConstraint, Assumption))
c
for c in constraints
if not isinstance(c, (HardwareConstraint, Assumption, SymbolicAlias))
]

index = {}
@@ -577,11 +620,21 @@ def should_update_index(
source: CustomOp,
source_index: dict[IndexSymbol, IndexSequence],
source_vector_shapes: dict[IndexSymbol, int],
symbolic_constraints: list[SymbolicAlias],
):
# Get symbolic shape without any aliased variables.
aliased_dims = [x.source for x in symbolic_constraints]
symbolic_shape = set(source.type.symbolic_shape).difference(aliased_dims)

# If all the source indexing dimensions are not present in source vector shapes,
# we should not update the index.
if not set(symbolic_shape).issubset(set(source_vector_shapes.keys())):
return False

# Determine if we should update the idx based on the source.
# We update the source only if the source index provides
# information about all the non-batch dimensions of the source.
non_batch_dims = [x for x in source.indexing_dims if source_vector_shapes[x] > 1]
non_batch_dims = [x for x in symbolic_shape if source_vector_shapes[x] > 1]

# If the source index is smaller than the non-batch dims, check if the
# source index is a subset of the non-batch dims.
@@ -595,12 +648,28 @@
return True


def append_aliased_shapes(source: CustomOp, symbolic_constraints: list[SymbolicAlias]):
"""
Append the aliased shapes to the vector shapes of the source, if they
are present in the source index.
"""
for constraint in symbolic_constraints:
if (
constraint.target in source.vector_shapes
and constraint.source in source.index
):
source.vector_shapes[constraint.source] = constraint.apply(
source.vector_shapes[constraint.target]
)


def propagate_index(
node: CustomOp,
hardware_constraint: HardwareConstraint,
workgroup_constraints: list[WorkgroupConstraint],
mma_index: dict[MMA, dict[IndexSymbol, int]],
visited: set[CustomOp],
symbolic_constraints: list[SymbolicAlias],
):
"""
Propagate the index and vector shapes through the graph
@@ -619,11 +688,14 @@ def propagate_index(
if source in visited:
continue
if not isinstance(source, (Reduction, MMA)):
if not should_update_index(source, source_index, source_vector_shapes):
if not should_update_index(
source, source_index, source_vector_shapes, symbolic_constraints
):
continue
source_index = source.transform_index(source_index)
source.index = combine_indices(source.index, source_index)
source.vector_shapes = source_vector_shapes
append_aliased_shapes(source, symbolic_constraints)
visited.add(source)
for func in [get_inputs, get_users]:
sources, reduction = add_nodes_to_sources(
@@ -656,11 +728,17 @@ def set_thread_dependent_index(
workgroup_constraints = [
c for c in constraints if isinstance(c, WorkgroupConstraint)
]
symbolic_constraints = [c for c in constraints if isinstance(c, SymbolicAlias)]
for source in sources:
visited = visited.union(set([x for x in sources]))
visited.remove(source)
visited = propagate_index(
source, hardware_constraint, workgroup_constraints, mma_index, visited
source,
hardware_constraint,
workgroup_constraints,
mma_index,
visited,
symbolic_constraints,
)


@@ -696,7 +774,7 @@ def create_broadcast(
binary_op.graph
)
custom = get_custom(broadcasted)
custom.vector_shapes = to_broadcast.vector_shapes
custom.vector_shapes = binary_op.vector_shapes
custom.index = deepcopy(target_node.index)
custom.index[broadcast_dim].size = broadcast_size
broadcast_idx = list(binary_op.node_args.values()).index(to_broadcast)
@@ -712,14 +790,21 @@ def resolve_thread_shapes(trace: CapturedTrace, constraints: list[Constraint]):
Currently, the only mismatches that can be resolved are when one of
the shapes is 1 and the other is > 1.
"""

def get_index(custom: CustomOp):
if isinstance(custom, MMA):
return custom.acc.index
return custom.index

binary_ops = trace.walk(lambda node: isinstance(get_custom(node), BinaryPyOp))
for binary_op in binary_ops:
custom = get_custom(binary_op)
# Get the largest dim and shape from the lhs and rhs.
lhs = get_custom(custom.lhs)
rhs = get_custom(custom.rhs)
lhs_dim, lhs_size = get_largest_index_and_size(lhs.index)
rhs_dim, rhs_size = get_largest_index_and_size(rhs.index)

lhs_dim, lhs_size = get_largest_index_and_size(get_index(lhs))
rhs_dim, rhs_size = get_largest_index_and_size(get_index(rhs))

# If they are equal we are done.
if lhs_dim == rhs_dim and lhs_size == rhs_size:
@@ -737,22 +822,24 @@ def resolve_thread_shapes(trace: CapturedTrace, constraints: list[Constraint]):
to_broadcast = rhs if broadcast_rhs else lhs
broadcast_dim = lhs_dim if broadcast_rhs else rhs_dim
broadcast_size = lhs_size if broadcast_rhs else rhs_size
target = lhs if broadcast_rhs else rhs
broadcasted = lhs if broadcast_rhs else rhs

if lhs_dim != rhs_dim:
# If the dimensions don't agree, we can still do this broadcast only if
# the two nodes differ in shape along the broadcasting dimension and the
# broadcasting dimension is the innermost dimension.
missing_dims = set(target.indexing_dims).difference(
set(to_broadcast.indexing_dims)
missing_dims = set(broadcasted.type.symbolic_shape).difference(
set(to_broadcast.type.symbolic_shape)
)
is_only_missing_dim = missing_dims == {broadcast_dim}
is_innermost_dim = broadcast_dim == target.indexing_dims[-1]
is_innermost_dim = broadcast_dim == broadcasted.type.symbolic_shape[-1]

if not is_only_missing_dim and not is_innermost_dim:
raise NotImplementedError(
"Currently only support resolving discrepancies when the broadcasting dimension is the innermost dimension."
)

# Broadcast
create_broadcast(custom, to_broadcast, broadcast_dim, broadcast_size, target)
create_broadcast(
custom, to_broadcast, broadcast_dim, broadcast_size, broadcasted
)
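
Note: resolve_thread_shapes only reconciles per-thread shape mismatches where one operand of the binary op carries a single element along the mismatched dimension; the broadcast node created by create_broadcast plays the role of a size-1 broadcast along the innermost dimension. A plain torch sketch of the shape rule being enforced (not the wave implementation itself):

import torch

lhs = torch.randn(4, 8)        # per-thread shape: 8 elements in the innermost dim
rhs = torch.randn(4, 1)        # per-thread shape: a single element there
out = lhs + rhs.expand(4, 8)   # the inserted broadcast expands the size-1 operand
assert out.shape == (4, 8)
# Mismatches where neither per-thread size is 1 are not resolved by this pass.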
(Diffs for the remaining 6 changed files are not shown here.)