Skip to content

Commit

Permalink
Reduce fusion overhead (#284)
Browse files Browse the repository at this point in the history
* type_inference: reduce overhead when fusing by skipping both the list_passed check and the restoration of a workunit's original arguments (if the workunit already has all its types)

* Tracer: memoize retrieval of safety info to reduce overhead and relax restriction on TIDFunc accesses
  • Loading branch information
NaderAlAwar authored Aug 5, 2024
1 parent 437ffb7 commit 359b08a
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 26 deletions.
12 changes: 6 additions & 6 deletions pykokkos/core/fusion/access_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self, tid_name: str, view_args: Dict[str, int]):
self.view_args = view_args

# Map from each view (str) + dimension (int) to a tuple of its AccessIndex, its AccessMode, and the index expression as a string
self.access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]] = {}
self.access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = {}
self.current_iters: List[Tuple[str, bool]] = []

def visit_For(self, node: ast.For) -> None:
Expand All @@ -119,7 +119,7 @@ def visit_Call(self, node: ast.Call) -> None:
if arg.id in self.view_args:
rank: int = self.view_args[arg.id]
for i in range(rank):
self.access_indices[(arg.id, i)] = (AccessIndex.All, AccessMode.ReadWrite)
self.access_indices[(arg.id, i)] = (AccessIndex.All, AccessMode.ReadWrite, "")

def visit_Subscript(self, node: ast.Subscript) -> None:
current_node: ast.Subscript = node
Expand Down Expand Up @@ -160,7 +160,7 @@ def visit_Subscript(self, node: ast.Subscript) -> None:
index_to_set: AccessIndex
mode_to_set: AccessMode

existing_access: Optional[Tuple[AccessIndex, AccessMode]] = self.access_indices.get((view_name, i))
existing_access: Optional[Tuple[AccessIndex, AccessMode, str]] = self.access_indices.get((view_name, i))
if existing_access is None:
index_to_set = new_index
mode_to_set = AccessMode.Read if isinstance(node.ctx, ast.Load) else AccessMode.Write
Expand Down Expand Up @@ -191,10 +191,10 @@ def visit_Subscript(self, node: ast.Subscript) -> None:
if mode_to_set is AccessMode.Write and isinstance(node.parent, ast.AugAssign):
mode_to_set = AccessMode.ReadWrite

self.access_indices[(view_name, i)] = (index_to_set, mode_to_set)
self.access_indices[(view_name, i)] = (index_to_set, mode_to_set, index_node_str)


def get_view_write_indices_and_modes(AST: ast.FunctionDef, view_args: Dict[str, int]) -> Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]]:
def get_view_write_indices_and_modes(AST: ast.FunctionDef, view_args: Dict[str, int]) -> Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]:
"""
Get information from the AST needed for fusion safety
Expand All @@ -207,6 +207,6 @@ def get_view_write_indices_and_modes(AST: ast.FunctionDef, view_args: Dict[str,
tid_name: str = AST.args.args[0].arg
visitor = WriteIndicesVisitor(tid_name, view_args)
visitor.visit(AST)
access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]] = visitor.access_indices
access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = visitor.access_indices

return access_indices
59 changes: 41 additions & 18 deletions pykokkos/core/fusion/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class TracerOperation:
entity_name: str
args: Dict[str, Any]
dependencies: Set[DataDependency]
access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]]
access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]

def __hash__(self) -> int:
return self.op_id
Expand Down Expand Up @@ -81,6 +81,10 @@ def __init__(self) -> None:
# Map from data version to tracer operation
self.data_operation: Dict[DataDependency, TracerOperation] = {}

# Cache expensive operations that require traversing the AST
self.access_modes_cache: Dict[Tuple[str, str], Dict[str, AccessMode]] = {}
self.safety_cache: Dict[Tuple[str, str], Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]] = {}

def log_operation(
self,
future: Optional[Future],
Expand Down Expand Up @@ -108,10 +112,19 @@ def log_operation(
entity: PyKokkosEntity = parser.get_entity(entity_name)
AST: ast.FunctionDef = entity.AST

cache_key: Tuple[str, str] = (parser.path, entity_name)

dependencies: Set[DataDependency]
access_modes: Dict[str, AccessMode]
dependencies, access_modes = self.get_data_dependencies(kwargs, AST)
access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]] = self.get_safety_info(kwargs, AST)
dependencies, access_modes = self.get_data_dependencies(kwargs, AST, cache_key)

access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]

if cache_key in self.safety_cache:
access_indices = self.safety_cache[cache_key]
else:
access_indices = self.get_safety_info(kwargs, AST)
self.safety_cache[cache_key] = access_indices

tracer_op = TracerOperation(self.op_id, future, name, policy, workunit, operation, parser, entity_name, dict(kwargs), dependencies, access_indices)
self.op_id += 1
Expand All @@ -120,7 +133,7 @@ def log_operation(

self.operations[tracer_op] = None

def get_safety_info(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]]:
def get_safety_info(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]:
"""
Get the view access indices needed to check for safety
Expand All @@ -141,10 +154,10 @@ def get_safety_info(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Dict[

# Map from view name (str) + dimension (int) to the type of
# access to that view's dimension
write_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]] = get_view_write_indices_and_modes(AST, view_name_and_rank)
write_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = get_view_write_indices_and_modes(AST, view_name_and_rank)

# Now need to convert view name to view ID
safety_info: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]] = {}
safety_info: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = {}
for (name, dim), access_index in write_indices.items():
view_id: int = view_args[name]
safety_info[(view_id, dim)] = access_index
Expand Down Expand Up @@ -225,7 +238,7 @@ def fuse(self, operations: List[TracerOperation], strategy: str) -> List[TracerO

raise RuntimeError(f"Unrecognized fusion strategy '{strategy}'")

def is_safe_to_fuse(self, current: List[TracerOperation], current_views: Set[ViewType], current_safety_info: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]], next: TracerOperation, next_views: Set[ViewType]) -> bool:
def is_safe_to_fuse(self, current: List[TracerOperation], current_views: Set[ViewType], current_safety_info: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode, str]], next: TracerOperation, next_views: Set[ViewType]) -> bool:
"""
Check whether the next operation is safe to fuse with the
current operations
Expand All @@ -244,16 +257,20 @@ def is_safe_to_fuse(self, current: List[TracerOperation], current_views: Set[Vie
for dim in range(view.rank()):
key: Tuple[int, int] = (id(view), dim)

# assert key in current_safety_info and key in next_safety_info
assert key in current_safety_info
assert key in next_safety_info

current_access_index, current_access_mode = current_safety_info[key]
next_access_index, next_access_mode = next_safety_info[key]
current_access_index, current_access_mode, current_index_str = current_safety_info[key]
next_access_index, next_access_mode, next_index_str = next_safety_info[key]

if current_access_mode == AccessMode.Read and next_access_mode == AccessMode.Read:
continue

# If the same function on the thread index is used to
# index both views then this will not prevent fusion.
if current_access_index == AccessIndex.TIDFunc and next_access_index == AccessIndex.TIDFunc and current_index_str == next_index_str:
continue

if current_access_index.value > AccessIndex.TID.value or next_access_index.value > AccessIndex.TID.value:
return False

Expand Down Expand Up @@ -377,7 +394,7 @@ def fuse_naive(self, operations: List[TracerOperation]) -> List[TracerOperation]

return fused_ops

def fuse_safety_info(self, info_0: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]], info_1: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]]) -> Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]]:
def fuse_safety_info(self, info_0: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode, str]], info_1: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode, str]]) -> Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode, str]]:
"""
Fuse the safety info of two separate operations
Expand All @@ -386,13 +403,13 @@ def fuse_safety_info(self, info_0: Dict[Tuple[int, int], Tuple[AccessIndex, Acce
:returns: the fused safety info
"""

fused_info: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode]] = {}
fused_info: Dict[Tuple[int, int], Tuple[AccessIndex, AccessMode, str]] = {}
for key, value in info_0.items():
if key not in info_1:
fused_info[key] = value
else:
other_index, other_mode = info_1[key]
current_index, current_mode = value
other_index, other_mode, other_index_str = info_1[key]
current_index, current_mode, current_index_str = value

index_to_set: AccessIndex
mode_to_set: AccessMode
Expand All @@ -407,7 +424,7 @@ def fuse_safety_info(self, info_0: Dict[Tuple[int, int], Tuple[AccessIndex, Acce
else:
mode_to_set = AccessMode.ReadWrite

fused_info[key] = (index_to_set, mode_to_set)
fused_info[key] = (index_to_set, mode_to_set, other_index_str)

for key, value in info_1.items():
# Already handled in the previous loop
Expand Down Expand Up @@ -442,7 +459,7 @@ def fuse_operations(self, operations: List[TracerOperation], fused_safety_info:
parsers: List[Parser] = []
args: Dict[str, Dict[str, Any]] = {}
dependencies: Set[DataDependency] = set()
safety_info: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode]] = {}
safety_info: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = {}

for index, op in enumerate(operations):
assert isinstance(op.policy, RangePolicy) and policy.begin == op.policy.begin and policy.end == op.policy.end
Expand All @@ -463,12 +480,13 @@ def fuse_operations(self, operations: List[TracerOperation], fused_safety_info:

return TracerOperation(None, future, fused_name, policy, workunits, operation, parsers, fused_name, args, dependencies, fused_safety_info)

def get_data_dependencies(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Tuple[Set[DataDependency], Dict[str, AccessMode]]:
def get_data_dependencies(self, kwargs: Dict[str, Any], AST: ast.FunctionDef, cache_key: Tuple[str, str]) -> Tuple[Set[DataDependency], Dict[str, AccessMode]]:
"""
Get the data dependencies of an operation from its input arguments
:param kwargs: the keyword arguments passed to the workunit
:param AST: the AST of the input workunit
:param cache_key: the key used to cache the results of traversing the AST
:returns: the set of data dependencies and the access modes of the views
"""

Expand All @@ -489,7 +507,12 @@ def get_data_dependencies(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) ->
if isinstance(value, ViewType):
view_args.add(arg)

access_modes: Dict[str, AccessMode] = get_view_access_modes(AST, view_args)
access_modes: Dict[str, AccessMode]
if cache_key in self.access_modes_cache:
access_modes = self.access_modes_cache[cache_key]
else:
access_modes = get_view_access_modes(AST, view_args)
self.access_modes_cache[cache_key] = access_modes

# Second pass to check if the views are dependencies
for arg, value in kwargs.items():
Expand Down
2 changes: 0 additions & 2 deletions pykokkos/core/type_inference/args_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,6 @@ def get_type_info(
is_missing_annotations: bool = (
workunit_str in ORIGINAL_PARAMS
or
list_passed
or
check_missing_annotations(this_tree.args.args)
)

Expand Down

0 comments on commit 359b08a

Please sign in to comment.