Get GEMMs working without minimize_global_loads #167

Merged · 2 commits · Sep 26, 2024
185 changes: 96 additions & 89 deletions lit_tests/kernel/wave/codegen.py

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions lit_tests/kernel/wave/expansion.py
@@ -243,23 +243,23 @@ def test_gemm():
# CHECK-NEXT: placeholder(_name=a
# CHECK-NEXT: placeholder(_name=b
# CHECK-NEXT: placeholder(_name=c
-# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N})
-# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16})
-# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N})
-# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16})
+# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})
+# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16})
+# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})
+# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16})
# CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b])
# CHECK-NEXT: get_result(value=reduction, res_idx=3)
# CHECK-NEXT: get_result(value=reduction, res_idx=2)
# CHECK-NEXT: get_result(value=reduction, res_idx=1)
# CHECK-NEXT: get_result(value=reduction, res_idx=0)
# CHECK-NEXT: write(register_=getresult_0_0_0
-# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}
+# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-NEXT: write(register_=getresult_1_1_0
-# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16}
+# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-NEXT: write(register_=getresult_1_0_0
-# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N}
+# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}
# CHECK-NEXT: write(register_=getresult_0_1_0
-# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 16}
+# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}
# CHECK-NEXT: output

# Reduction subgraph:
@@ -389,11 +389,11 @@ def test_gemm_reduction_expansion_only():
# CHECK-NEXT: placeholder(_name=a
# CHECK-NEXT: placeholder(_name=b
# CHECK-NEXT: placeholder(_name=c
-# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N})
+# CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})
# CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0]
# CHECK-NEXT: get_result(value=reduction, res_idx=0)
# CHECK-NEXT: write(register_=getresult_0_0_0
-# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N})
+# CHECK-SAME: index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)})
# CHECK-NEXT: output(return_vals=(None,))

# Reduction subgraph:
8 changes: 4 additions & 4 deletions lit_tests/kernel/wave/index_sequence_analysis.py
@@ -182,13 +182,13 @@ def test_gemm():
# CHECK-NEXT: placeholder(_name=b
# CHECK-NEXT: placeholder(_name=c
# CHECK-NEXT: register
-# CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2})
+# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: register(
-# CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16})
+# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
# CHECK-NEXT: register(
-# CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2})
+# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: register(
-# CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16})
+# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
# CHECK-NEXT: allocate(
# CHECK-NEXT: allocate(
# CHECK-NEXT: reduction(
16 changes: 8 additions & 8 deletions lit_tests/kernel/wave/minimize_global_loads.py
@@ -131,13 +131,13 @@ def test_gemm():
# CHECK-NEXT: placeholder(_name=b
# CHECK-NEXT: placeholder(_name=c
# CHECK-NEXT: register
-# CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2})
+# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: register(
-# CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16})
+# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
# CHECK-NEXT: register(
-# CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2})
+# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: register(
-# CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16})
+# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
# CHECK-NEXT: allocate(
# CHECK-NEXT: allocate(
# CHECK-NEXT: reduction(
@@ -146,13 +146,13 @@
# CHECK-NEXT: get_result(value=reduction, res_idx=1)
# CHECK-NEXT: get_result(value=reduction, res_idx=0)
# CHECK-NEXT: write(register_=getresult_0_0_0, memory=c
-# CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2})
+# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: write(register_=getresult_1_1_0, memory=c
-# CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + 16})
+# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})
# CHECK-NEXT: write(register_=getresult_1_0_0, memory=c
-# CHECK-SAME: index={M: $WG0*BLOCK_M + 16, N: $WG1*BLOCK_N + BLOCK_N/2})
+# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)})
# CHECK-NEXT: write(register_=getresult_0_1_0, memory=c
-# CHECK-SAME: index={M: $WG0*BLOCK_M, N: $WG1*BLOCK_N + BLOCK_N/2 + 16})
+# CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16})

# Reduction subgraph:
# CHECK: %acc_0_0_0
10 changes: 0 additions & 10 deletions shark_turbine/kernel/ops/wave_ops.py
@@ -780,16 +780,6 @@ def custom_string(self, value_map: dict[str, str]) -> str:
        custom_str += f"acc={self.acc} (index = {self.acc_index}))"
        return custom_str

-    def post_expansion(self, constraints: list["Constraint"]) -> None:
-        """
-        Once the arguments have been expanded, we set their indices,
-        ensuring that the LHS and RHS indices are consistent with their
-        corresponding address spaces.
-        """
-        self.lhs.index = self.lhs_index
-        self.rhs.index = self.rhs_index
-        self.acc.index = self.acc_index


@define_op("read")
@dataclass
15 changes: 7 additions & 8 deletions shark_turbine/kernel/wave/expansion.py
@@ -19,7 +19,7 @@
from .._support.indexing import IndexingContext, IndexSequence
from ...support.logging import get_logger
from .._support.tracing import CapturedTrace
-from .utils import get_mma_dimensional_mapping
+from .utils import get_mma_dimensional_mapping, specialize_index_sequence
from ..lang.global_symbols import *

logger = get_logger("turbine.wave.expansion")
@@ -141,6 +141,7 @@ def compute_stride(
def set_node_index(
    constraints: Sequence[Constraint],
    mma_index: dict[IndexSymbol, int],
+    mma_slices: dict[IndexSymbol, list[fx.Node]],
    dim_tile_size: dict[IndexSymbol, int],
    custom: CustomOp,
    dim_scaling: dict[IndexSymbol, int],
@@ -171,11 +172,7 @@ def set_node_index(
    for dim in custom.indexing_dims:
        index_seq = None
        for constraint in sorted_constraints:
-            mma_check = (
-                isinstance(constraint, HardwareConstraint)
-                and dim in mma_index
-                and isinstance(custom, MMA)
-            )
+            mma_check = isinstance(constraint, HardwareConstraint) and dim in mma_index

            vector_check = (
                isinstance(constraint, HardwareConstraint)
@@ -217,6 +214,8 @@
                    index_seq = constraint.apply(
                        constraint_index, dim, elements_per_thread, stride
                    )
+                    if mma_index:
+                        index_seq = specialize_index_sequence(index_seq, mma_slices, custom)

                else:
                    if index_seq is None:
@@ -246,10 +245,10 @@ def expand_graph(
        dim_scaling = constraints_or_scaling
        node_index_setter = lambda *args: None
    else:
-        mma_index = get_mma_dimensional_mapping(trace)
+        mma_index, mma_slices = get_mma_dimensional_mapping(trace)
        dim_scaling, dim_tile_size = get_dim_scaling(constraints_or_scaling, mma_index)
        node_index_setter = partial(
-            set_node_index, constraints_or_scaling, mma_index, dim_tile_size
+            set_node_index, constraints_or_scaling, mma_index, mma_slices, dim_tile_size
        )

    # Start from the back and expand in the corresponding indexing dimensions of a node
2 changes: 1 addition & 1 deletion shark_turbine/kernel/wave/index_sequence_analysis.py
@@ -24,7 +24,7 @@ def get_vector_shape(
    hardware_constraint: HardwareConstraint,
    symbolic_shape: list[IndexSymbol],
) -> list[int]:
-    mma_indices = get_mma_dimensional_mapping(trace)
+    mma_indices, _ = get_mma_dimensional_mapping(trace)
    return [
        get_hardware_vector_size(dim, hardware_constraint, mma_indices)
        for dim in symbolic_shape
150 changes: 145 additions & 5 deletions shark_turbine/kernel/wave/utils.py
@@ -1,5 +1,4 @@
# Copyright 2024 The IREE Authors
-#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -16,7 +15,16 @@
from .._support.tracing import CapturedTrace
from .._support.indexing import IndexExpr, IndexingContext, IndexSymbol, IndexSequence
from ..lang.global_symbols import *
-from ..ops.wave_ops import get_custom, Output, Write, MMA
+from ..ops.wave_ops import (
+    get_custom,
+    Output,
+    Write,
+    MMA,
+    CustomOp,
+    Reduction,
+    GetResult,
+    IterArg,
+)
from .constraints import Constraint, HardwareConstraint, TilingConstraint
import torch.fx as fx
import shark_turbine.kernel.lang as tkl
@@ -145,7 +153,9 @@ def simplify_index(index: IndexExpr) -> IndexExpr:
    return subs_idxc(index.subs(mapping))


-def get_mma_dimensional_mapping(trace: CapturedTrace) -> dict[IndexSymbol, int]:
+def get_mma_dimensional_mapping(
+    trace: CapturedTrace,
+) -> tuple[dict[IndexSymbol, int], dict[IndexSymbol, list[fx.Node]]]:
    """
    Given a trace, determine the MMA dimensional mapping for all the
    MMA operations in the graph. For example, if we have
@@ -159,7 +169,8 @@ def is_mma(node):
        return isinstance(get_custom(node), MMA)

    mapping: dict[IndexSymbol, int] = {}
-    for node in trace.walk(is_mma):
+    mma_nodes = trace.walk(is_mma)
+    for node in mma_nodes:
        custom: MMA = get_custom(node)
        m, n = custom.acc_type.symbolic_shape[-2:]
        lhs_shape = custom.lhs_type.symbolic_shape
@@ -170,7 +181,7 @@ def is_mma(node):
        mapping[n] = 1
        mapping[k] = 2

-    return mapping
+    return mapping, capture_mma_slices([get_custom(x) for x in mma_nodes])


def get_hardware_vector_size(
@@ -378,3 +389,132 @@ def erase_graph(graph: fx.Graph):
        for user in node.users:
            graph.erase_node(user)
        graph.erase_node(node)


def get_users(
node: fx.Node, reduction: fx.Node = None
) -> tuple[list[fx.Node], fx.Node]:
"""
Return the users of a node, propagating through reductions.
"""
users = []
for user in node.users:
custom = get_custom(user)
if isinstance(custom, Reduction):
# Map init arg to iter arg
reduction = custom
init_arg_idx = custom.init_args.index(node)
users.append(custom.iter_args[init_arg_idx])
continue
if isinstance(custom, Output) and reduction:
# Map output to get result
return_vals = custom.return_vals[0]
get_results = sorted(
[x for x in reduction.users if isinstance(get_custom(x), GetResult)],
lambda x: get_custom(x).res_idx,
)
if isinstance(return_vals, list):
output_idx = return_vals.index(node)
users.append(get_results[output_idx])
else:
users.append(get_results[0])
continue
users.append(user)
return users, reduction


def get_inputs(
    node: fx.Node, reduction: fx.Node = None
) -> tuple[list[fx.Node], fx.Node]:
    """
    Return the inputs of a node, propagating through reductions.
    """
    inputs = []
    for input in node.all_input_nodes:
        custom = get_custom(input)
        if isinstance(custom, GetResult):
            reduction = custom.value
            assert isinstance(
                reduction, Reduction
            ), "GetResult must be used by a Reduction"
            # Map get result to output
            inputs.append(reduction.outputs[custom.res_idx])
            continue
        if isinstance(custom, IterArg):
            # Map iter args to init args
            iter_arg_idx = reduction.iter_args.index(input)
            inputs.append(reduction.init_args[iter_arg_idx])
            continue
        inputs.append(input)
    return inputs, reduction


def bfs(
    node: fx.Node,
    get_neighbors: Callable[[fx.Node, fx.Node], list[fx.Node]],
) -> set[fx.Node]:
    """
    Run BFS on the graph from the given node, following the neighbors
    returned by get_neighbors.
    """
    visited: set[fx.Node] = set()
    queue: list[fx.Node] = []
    visited.add(node)
    queue.append(node)
    reduction = None
    while queue:
        s = queue.pop(0)
        neighbors, reduction = get_neighbors(s, reduction)
        for neighbor in neighbors:
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append(neighbor)
    return visited


def capture_forward_slice(node: fx.Node) -> set[fx.Node]:
    """
    Run BFS on the graph to capture the forward slice of a node.
    """
    return bfs(node, lambda x, y: get_users(x, y))


def capture_backward_slice(node: fx.Node) -> set[fx.Node]:
    """
    Capture backward slice from a node and return the tree.
    Assumes graph is directed.
    """
    return bfs(node, lambda x, y: get_inputs(x, y))


def capture_mma_slices(mma_nodes: list[MMA]) -> dict[IndexSymbol, list[fx.Node]]:
    """
    Given a list of MMA nodes, capture the backward slices of their LHS and RHS
    operands and the forward slice of their accumulators.
    """
    mma_slices = {x: [] for x in [MMA_LHS, MMA_RHS, MMA_ACC]}
    for mma in mma_nodes:
        mma_slices[MMA_LHS] += capture_backward_slice(mma.lhs)
        mma_slices[MMA_RHS] += capture_backward_slice(mma.rhs)
        mma_slices[MMA_ACC] += capture_forward_slice(mma.acc)
    return mma_slices


def specialize_index_sequence(
    index_seq: IndexSequence,
    mma_slices: dict[IndexSymbol, list[fx.Node]],
    custom: CustomOp,
) -> IndexSequence:
    """
    Given an index sequence, specialize it to a LHS, RHS or ACC index sequence
    based on whether the node is used as the LHS, RHS or ACC in the MMA node.
    If the node is not used as any of the operands, return the original index
    sequence with all the MMA symbols zeroed out.
    """
    if isinstance(custom, MMA):
        return index_seq
    operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0}
    for key in mma_slices:
        if custom.fx_node in mma_slices[key]:
            operand_map[key] = 1
Contributor:
IIUC, we want every iteration of for key in mma_slices to be
{MMA_LHS: 1, MMA_RHS: 0, MMA_ACC: 0} then
{MMA_LHS: 0, MMA_RHS: 1, MMA_ACC: 0} and then
{MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 1},

But in the current state wouldn't this be
{MMA_LHS: 1, MMA_RHS: 0, MMA_ACC: 0}
{MMA_LHS: 1, MMA_RHS: 1, MMA_ACC: 0}
{MMA_LHS: 1, MMA_RHS: 1, MMA_ACC: 1}

Although if that is indeed what we are going for, can you explain the intuition behind it? :)

Contributor Author:
So if a node is determined to be in the backward slice of the LHS, then we want to specialize it by substituting {MMA_LHS = 1, all else 0}. For RHS, we want {MMA_RHS = 1, all else 0}. For ACC, {MMA_ACC = 1, all else 0}. And if it's not in the backward slices of the LHS and RHS or the forward slice of the ACC, then {all = 0}. You can think of this as an alternative to propagation. Because we set the entire indices everywhere, we need to specialize them depending on some constraints, and for that we use the forward/backward slices of the MMA operands.
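A minimal standalone sketch of that substitution idea (using sympy directly; the selector symbols mirror MMA_LHS/MMA_RHS/MMA_ACC above, but the example index expression itself is made up for illustration):

import sympy

MMA_LHS, MMA_RHS, MMA_ACC, T0 = sympy.symbols("MMA_LHS MMA_RHS MMA_ACC T0")

# Hypothetical combined index: each operand's term is gated by its selector symbol.
index = MMA_LHS * (T0 % 16) + MMA_RHS * (T0 // 16) + MMA_ACC * 4 * (T0 // 64)

# Specializing for the LHS slice: {MMA_LHS = 1, all else 0}.
print(index.subs({MMA_LHS: 1, MMA_RHS: 0, MMA_ACC: 0}))  # Mod(T0, 16)

# A node in none of the slices: {all = 0}, so every MMA term drops out.
print(index.subs({MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0}))  # 0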

Contributor (@raikonenfnu, Sep 25, 2024):
Makes sense. In that case I think we need to move the operand_map = {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 0} above the if custom.fx_node in mma_slices[key]:. Otherwise the previous state carries over, i.e. we will get:

iter_0 setting MMA_LHS, {MMA_LHS: 1, MMA_RHS: 0, MMA_ACC: 0}
iter_1 setting MMA_RHS,  {MMA_LHS: 1, MMA_RHS: 1, MMA_ACC: 0}
iter_2 setting MMA_ACC, {MMA_LHS: 1, MMA_RHS: 1, MMA_ACC: 1}

Contributor:
Realized I put the wrong state in the previous comment; updated it to make more sense haha

Contributor Author:
Ah yes, we would get carry-over, except for the fact that we return as soon as we get a match. That guarantees that the dictionary's values will only ever have one non-zero entry (= 1).
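A small self-contained illustration of that control flow (plain strings and sets stand in for the real symbols and fx nodes):

def specialize(node, slices):
    operand_map = {"LHS": 0, "RHS": 0, "ACC": 0}
    for key, slice_nodes in slices.items():
        if node in slice_nodes:
            operand_map[key] = 1
            return operand_map  # early return: nothing carries over to later keys
    return operand_map  # no match: all selectors stay 0

# Even if a node appears in several slices, only the first match is recorded.
print(specialize("a", {"LHS": {"a"}, "RHS": {"a"}, "ACC": set()}))
# {'LHS': 1, 'RHS': 0, 'ACC': 0}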

Contributor:
OK, that makes sense; that's why it's implicitly functionally equivalent. Can we still bring it down, though, for better clarity/straightforwardness? :)

            return index_seq.subs(operand_map)
    return index_seq.subs(operand_map)
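For intuition about the slice-capture helpers above, here is a toy run of the same BFS pattern, with a plain dict standing in for the fx graph (the nodes and edges are made up; get_users would play the role of the neighbor function):

# Toy "graph": the accumulator feeds an MMA, whose result is written out.
users = {"acc": ["mma"], "mma": ["write"], "write": []}

def bfs_toy(node, get_neighbors):
    # Same traversal as the bfs helper in this diff.
    visited, queue = {node}, [node]
    while queue:
        s = queue.pop(0)
        neighbors, _ = get_neighbors(s, None)
        for n in neighbors:
            if n not in visited:
                visited.add(n)
                queue.append(n)
    return visited

# Forward slice of "acc", analogous to capture_forward_slice(acc).
print(bfs_toy("acc", lambda n, r: (users[n], r)))  # {'acc', 'mma', 'write'}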