Add option to enforce alignment constraint when planning memory (#8003)
Summary:

Add ability to enforce start alignment in Cadence Memory Planner

Reviewed By: zonglinpeng

Differential Revision: D68762973
Eashan Garg authored and facebook-github-bot committed Feb 5, 2025
1 parent 3e64c4e commit 4fdd28a
Showing 3 changed files with 84 additions and 27 deletions.
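
For context (this sketch is not part of the commit): the new knob is exercised the same way the test at the bottom of this diff does it. TinyAdd and the argument values below are illustrative placeholders; only mem_alignment is new.

import torch
from executorch.backends.cadence.aot import compiler

class TinyAdd(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.add(x, y)

prog_manager = compiler.export_to_executorch_gen_etrecord(
    TinyAdd(),
    (torch.randn(4, 17), torch.randn(4, 17)),
    opt_level=1,
    mem_algo=0,        # index into the registered memory-planning algorithms
    mem_alignment=16,  # new: every planned buffer starts on a 16-byte boundary
)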
2 changes: 2 additions & 0 deletions backends/cadence/aot/compiler.py
@@ -264,6 +264,7 @@ def export_to_executorch_gen_etrecord(
alloc_graph_output: bool = True,
memory_config: Optional[MemoryConfig] = None,
dump_graphs: bool = False,
mem_alignment: int = 1,
) -> ExecutorchProgramManager:
cadence_passes = get_cadence_passes(opt_level)
edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
@@ -290,6 +291,7 @@
mem_algo=mem_algo,
alloc_graph_input=alloc_graph_input,
alloc_graph_output=alloc_graph_output,
mem_alignment=mem_alignment,
)

# Get executorch program after Cadence specific passes
31 changes: 23 additions & 8 deletions backends/cadence/aot/memory_planning.py
@@ -9,6 +9,7 @@
import collections
import itertools
import logging
import math
import typing
from functools import partial
from typing import Iterable, List, Optional, Tuple
@@ -39,6 +40,10 @@ def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
return memory_config.memory_sizes[exir_id - 1]


def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
return int(math.ceil(pre_aligned_offset / alignment) * alignment)
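
A quick illustration (not part of the commit) of the helper's rounding, which bumps an offset up to the next multiple of alignment:

assert get_aligned_offset(0, 37) == 0      # zero is already aligned
assert get_aligned_offset(1, 37) == 37     # rounds up to the next multiple
assert get_aligned_offset(37, 37) == 37    # exact multiples are unchanged
assert get_aligned_offset(100, 37) == 111  # ceil(100 / 37) * 37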


def collect_specs_from_graph_module(
graph_module: torch.fx.GraphModule,
alloc_graph_input: bool,
@@ -95,9 +100,9 @@ def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
return None

def memory_available(spec: TensorSpec) -> bool:
return spec.mem_offset + spec.allocated_memory <= get_size(
memory_config, spec.mem_id
)
return get_aligned_offset(
spec.mem_offset + spec.allocated_memory, alignment
) <= get_size(memory_config, spec.mem_id)

# Iterate over all the specs in sorted order
for spec in sorted(
@@ -116,7 +121,9 @@ def memory_available(spec: TensorSpec) -> bool:
continue
spec.mem_offset = 0
while memory_available(spec) and (overlapped := overlap(spec)):
spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory
spec.mem_offset = get_aligned_offset(
overlapped.mem_offset + overlapped.allocated_memory, alignment
)
if memory_available(spec):
allocated_buffers[spec.mem_id].append(spec)
bufsizes[spec.mem_id] = max(
@@ -202,13 +209,16 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
# calculation of gap incorrect. Moving it out will make the algorithm degenerate
# to the naive one, reusing 0 tensor. The paper may have a typo here.
prev_offset = max(
allocated_spec.mem_offset + allocated_spec.allocated_memory,
get_aligned_offset(
allocated_spec.mem_offset + allocated_spec.allocated_memory,
alignment,
),
prev_offset,
)
if spec.mem_offset is None:
if prev_offset + spec.allocated_memory > get_size(
memory_config, spec.mem_id
):
if get_aligned_offset(
prev_offset + spec.allocated_memory, alignment
) > get_size(memory_config, spec.mem_id):
continue
else:
spec.mem_offset = prev_offset
@@ -423,6 +433,7 @@ def __init__(
]
]
] = None,
mem_alignment: int = 1,
) -> None:
self._init_mem_algos()

@@ -433,6 +444,9 @@ def __init__(
self.alloc_graph_output = alloc_graph_output
self.additional_constraint_gen_passes = additional_constraint_gen_passes

assert mem_alignment > 0, "mem_alignment must be positive"
self.mem_alignment = mem_alignment

def _init_mem_algos(self) -> None:
self.available_mem_algos = [
position_based_greedy_with_hierarchy,
@@ -459,6 +473,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
alloc_graph_input=self.alloc_graph_input,
alloc_graph_output=self.alloc_graph_output,
alignment=self.mem_alignment,
)
mem_planning(graph_module)

78 changes: 59 additions & 19 deletions backends/cadence/aot/tests/test_memory_passes.py
@@ -1,5 +1,7 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

import logging
import math
import unittest
@@ -12,11 +14,12 @@
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.exir import memory
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.memory_planning import collect_specs_from_nodes
from executorch.exir.tests.models import MultiLayerPerceptron


class TestMemPlanningPasses(unittest.TestCase):
def test_calculate_peak_memory_pass(self):
def test_calculate_peak_memory_pass(self) -> None:
class PeakMemoryTestModel(torch.nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
super().__init__()
@@ -30,7 +33,7 @@ def forward(self, x: torch.Tensor):
x = self.linear2(x)
return x

def calculate_aligned_num_bytes(num: int, alignment: int = 16):
def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int:
return math.ceil(num / alignment) * alignment

# model 1
@@ -84,7 +87,7 @@ def calculate_aligned_num_bytes(num: int, alignment: int = 16):
) # Align data on a 16 byte boundary
self.assertEqual(peak_usage, expected_peak_usage)

def test_zero_memory_pass(self):
def test_zero_memory_pass(self) -> None:
class ZeroMem(torch.nn.Module):
def forward(self, x):
return x[:, 2::3, ...]
@@ -186,7 +189,7 @@ def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None:
f"{spec=} {arg_spec=}",
)

def verify_nop_memory_alloc(self, graph_module):
def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:
for node in graph_module.graph.find_nodes(
op="call_function", target=torch.ops.aten._cat_nop.out
):
@@ -202,7 +205,7 @@ def verify_nop_memory_alloc(self, graph_module):
):
self._verify_select_nop_memory_alloc(node)

def test_optimize_cat_on_placeholders(self):
def test_optimize_cat_on_placeholders(self) -> None:
class Cat(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.cat((x, y))
@@ -226,7 +229,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_outermost(self):
def test_optimize_cat_outermost(self) -> None:
class OptimizeCatFeasible1(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
@@ -253,7 +256,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_non_outermost(self):
def test_optimize_cat_non_outermost(self) -> None:
class OptimizeCatFeasible2(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
@@ -280,7 +283,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_no_optimize_cat_non_outermost(self):
def test_no_optimize_cat_non_outermost(self) -> None:
class OptimizeCatInfeasible1(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
@@ -306,7 +309,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_no_optimize_cat_non_outermost1(self):
def test_no_optimize_cat_non_outermost1(self) -> None:
class OptimizeCatInfeasible2(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
@@ -333,7 +336,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_with_slice(self):
def test_optimize_cat_with_slice(self) -> None:
class OptimizeCatSliceFeasible(torch.nn.Module):
def forward(self, x):
x1 = torch.add(x, 2.4, 3.1)
@@ -362,7 +365,7 @@ def forward(self, x):
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_with_slice_infeasible(self):
def test_optimize_cat_with_slice_infeasible(self) -> None:
class OptimizeCatSliceInfeasible(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
@@ -388,7 +391,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_slice_Tensor(self):
def test_optimize_slice_Tensor(self) -> None:
class SliceTensor(torch.nn.Module):
def forward(self, x, y, z):
x1 = torch.add(x, 2.4, 3.1)
@@ -450,7 +453,7 @@ def forward(self, x, y, z):
)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_select_Tensor(self):
def test_optimize_select_Tensor(self) -> None:
class SelectTensor(torch.nn.Module):
def forward(self, x, y, z):
x1 = torch.add(x, 2.4, 3.1)
@@ -517,7 +520,7 @@ def forward(self, x, y, z):

# TODO: Test fails due to memory planning
@unittest.expectedFailure
def test_optimize_cat_with_param(self):
def test_optimize_cat_with_param(self) -> None:
class CatWithPadding(torch.nn.Module):
def __init__(self, padding_shape):
super().__init__()
@@ -545,7 +548,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_then_slice_on_mutable_buffer(self):
def test_optimize_cat_then_slice_on_mutable_buffer(self) -> None:
class CatWithPadding(torch.nn.Module):
def __init__(self, padding_shape):
super().__init__()
@@ -570,7 +573,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_with_view(self):
def test_optimize_cat_with_view(self) -> None:
class CatViewFeasible(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
@@ -597,7 +600,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
self.verify_nop_memory_alloc(graph_module)

def test_no_optimize_cat_with_repeated_args(self):
def test_no_optimize_cat_with_repeated_args(self) -> None:
class CatViewInfeasible(torch.nn.Module):
def forward(self, x):
x1 = torch.add(x, 2.4, 3.1)
@@ -621,7 +624,7 @@ def forward(self, x):
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
self.verify_nop_memory_alloc(graph_module)

def test_no_optimize_cat_with_placeholder(self):
def test_no_optimize_cat_with_placeholder(self) -> None:
class CatViewInfeasible(torch.nn.Module):
def forward(self, x, y):
# Repeat will be decomposed into a cat. The cat cannot be optimized
@@ -739,7 +742,7 @@ def forward(self, x) -> torch.Tensor:
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
self.verify_nop_memory_alloc(graph_module)

def test_view_for_unallocated_output(self):
def test_view_for_unallocated_output(self) -> None:
class Model(torch.nn.Module):
def __init__(self, padding_shape):
super().__init__()
@@ -762,3 +765,40 @@ def forward(self, x, y):
)
self.assertEqual(count_node(graph_module, memory.view), 1)
self.verify_nop_memory_alloc(graph_module)

def test_start_alignment_constraints(self) -> None:
class Model(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor, y: torch.Tensor):
add_0 = torch.add(x, y)
add_1 = torch.add(x, add_0)
add_2 = torch.add(add_0, add_1)
add_3 = torch.add(add_1, add_2)
return add_3

model = Model()
inputs = (torch.randn(4, 17), torch.randn(4, 17))
for mem_algo in range(0, 2):
graph_module = (
compiler.export_to_executorch_gen_etrecord(
model,
inputs,
opt_level=1,
mem_algo=mem_algo,
alloc_graph_input=False,
alloc_graph_output=False,
mem_alignment=37,
)
.exported_program()
.graph_module
)
# Assert that all memory allocations start on the requested 37-byte alignment boundary
for spec in collect_specs_from_nodes(
graph_module.graph.nodes,
ignore_graph_input=True,
ignore_graph_output=True,
):
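# Editorial note: specs with no offset (or an offset of 0, which is trivially aligned) are skipped.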
if spec and spec.mem_offset:
self.assertEqual(spec.mem_offset % 37, 0)
