Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to enforce alignment constraint when planning memory #8003

Merged
merged 1 commit into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def export_to_executorch_gen_etrecord(
alloc_graph_output: bool = True,
memory_config: Optional[MemoryConfig] = None,
dump_graphs: bool = False,
mem_alignment: int = 1,
) -> ExecutorchProgramManager:
cadence_passes = get_cadence_passes(opt_level)
edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
Expand All @@ -290,6 +291,7 @@ def export_to_executorch_gen_etrecord(
mem_algo=mem_algo,
alloc_graph_input=alloc_graph_input,
alloc_graph_output=alloc_graph_output,
mem_alignment=mem_alignment,
)

# Get executorch program after Cadence specific passes
Expand Down
31 changes: 23 additions & 8 deletions backends/cadence/aot/memory_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import collections
import itertools
import logging
import math
import typing
from functools import partial
from typing import Iterable, List, Optional, Tuple
Expand Down Expand Up @@ -39,6 +40,10 @@ def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
return memory_config.memory_sizes[exir_id - 1]


def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
    """Round ``pre_aligned_offset`` up to the nearest multiple of ``alignment``.

    Args:
        pre_aligned_offset: Byte offset before alignment is applied.
        alignment: Alignment in bytes; a value of 1 leaves the offset
            unchanged.

    Returns:
        The smallest multiple of ``alignment`` that is greater than or equal
        to ``pre_aligned_offset``.

    Raises:
        ZeroDivisionError: If ``alignment`` is 0.
    """
    # Use exact integer ceiling division (-(-a // b) == ceil(a / b)) rather
    # than float division: a float quotient silently loses precision once the
    # offset exceeds 2**53 and could yield an offset *below* the input.
    return -(-pre_aligned_offset // alignment) * alignment


def collect_specs_from_graph_module(
graph_module: torch.fx.GraphModule,
alloc_graph_input: bool,
Expand Down Expand Up @@ -95,9 +100,9 @@ def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
return None

def memory_available(spec: TensorSpec) -> bool:
return spec.mem_offset + spec.allocated_memory <= get_size(
memory_config, spec.mem_id
)
return get_aligned_offset(
spec.mem_offset + spec.allocated_memory, alignment
) <= get_size(memory_config, spec.mem_id)

# Iterate over all the specs in sorted order
for spec in sorted(
Expand All @@ -116,7 +121,9 @@ def memory_available(spec: TensorSpec) -> bool:
continue
spec.mem_offset = 0
while memory_available(spec) and (overlapped := overlap(spec)):
spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory
spec.mem_offset = get_aligned_offset(
overlapped.mem_offset + overlapped.allocated_memory, alignment
)
if memory_available(spec):
allocated_buffers[spec.mem_id].append(spec)
bufsizes[spec.mem_id] = max(
Expand Down Expand Up @@ -202,13 +209,16 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
# calculation of gap incorrect. Moving it out will make the algorithm degenerate
# to the naive one, reusing 0 tensor. The paper may have a typo here.
prev_offset = max(
allocated_spec.mem_offset + allocated_spec.allocated_memory,
get_aligned_offset(
allocated_spec.mem_offset + allocated_spec.allocated_memory,
alignment,
),
prev_offset,
)
if spec.mem_offset is None:
if prev_offset + spec.allocated_memory > get_size(
memory_config, spec.mem_id
):
if get_aligned_offset(
prev_offset + spec.allocated_memory, alignment
) > get_size(memory_config, spec.mem_id):
continue
else:
spec.mem_offset = prev_offset
Expand Down Expand Up @@ -423,6 +433,7 @@ def __init__(
]
]
] = None,
mem_alignment: int = 1,
) -> None:
self._init_mem_algos()

Expand All @@ -433,6 +444,9 @@ def __init__(
self.alloc_graph_output = alloc_graph_output
self.additional_constraint_gen_passes = additional_constraint_gen_passes

assert mem_alignment > 0, "mem_alignment must be positive"
self.mem_alignment = mem_alignment

def _init_mem_algos(self) -> None:
self.available_mem_algos = [
position_based_greedy_with_hierarchy,
Expand All @@ -459,6 +473,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
alloc_graph_input=self.alloc_graph_input,
alloc_graph_output=self.alloc_graph_output,
alignment=self.mem_alignment,
)
mem_planning(graph_module)

Expand Down
76 changes: 57 additions & 19 deletions backends/cadence/aot/tests/test_memory_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.exir import memory
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.memory_planning import collect_specs_from_nodes
from executorch.exir.tests.models import MultiLayerPerceptron


class TestMemPlanningPasses(unittest.TestCase):
def test_calculate_peak_memory_pass(self):
def test_calculate_peak_memory_pass(self) -> None:
class PeakMemoryTestModel(torch.nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
super().__init__()
Expand All @@ -32,7 +33,7 @@ def forward(self, x: torch.Tensor):
x = self.linear2(x)
return x

def calculate_aligned_num_bytes(num: int, alignment: int = 16):
def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int:
return math.ceil(num / alignment) * alignment

# model 1
Expand Down Expand Up @@ -86,7 +87,7 @@ def calculate_aligned_num_bytes(num: int, alignment: int = 16):
) # Align data on a 16 byte boundary
self.assertEqual(peak_usage, expected_peak_usage)

def test_zero_memory_pass(self):
def test_zero_memory_pass(self) -> None:
class ZeroMem(torch.nn.Module):
def forward(self, x):
return x[:, 2::3, ...]
Expand Down Expand Up @@ -188,7 +189,7 @@ def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None:
f"{spec=} {arg_spec=}",
)

def verify_nop_memory_alloc(self, graph_module):
def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:
for node in graph_module.graph.find_nodes(
op="call_function", target=torch.ops.aten._cat_nop.out
):
Expand All @@ -204,7 +205,7 @@ def verify_nop_memory_alloc(self, graph_module):
):
self._verify_select_nop_memory_alloc(node)

def test_optimize_cat_on_placeholders(self):
def test_optimize_cat_on_placeholders(self) -> None:
class Cat(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.cat((x, y))
Expand All @@ -228,7 +229,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_outermost(self):
def test_optimize_cat_outermost(self) -> None:
class OptimizeCatFeasible1(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
Expand All @@ -255,7 +256,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_non_outermost(self):
def test_optimize_cat_non_outermost(self) -> None:
class OptimizeCatFeasible2(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
Expand All @@ -282,7 +283,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_no_optimize_cat_non_outermost(self):
def test_no_optimize_cat_non_outermost(self) -> None:
class OptimizeCatInfeasible1(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
Expand All @@ -308,7 +309,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_no_optimize_cat_non_outermost1(self):
def test_no_optimize_cat_non_outermost1(self) -> None:
class OptimizeCatInfeasible2(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
Expand All @@ -335,7 +336,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_with_slice(self):
def test_optimize_cat_with_slice(self) -> None:
class OptimizeCatSliceFeasible(torch.nn.Module):
def forward(self, x):
x1 = torch.add(x, 2.4, 3.1)
Expand Down Expand Up @@ -364,7 +365,7 @@ def forward(self, x):
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_with_slice_infeasible(self):
def test_optimize_cat_with_slice_infeasible(self) -> None:
class OptimizeCatSliceInfeasible(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
Expand All @@ -390,7 +391,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_slice_Tensor(self):
def test_optimize_slice_Tensor(self) -> None:
class SliceTensor(torch.nn.Module):
def forward(self, x, y, z):
x1 = torch.add(x, 2.4, 3.1)
Expand Down Expand Up @@ -452,7 +453,7 @@ def forward(self, x, y, z):
)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_select_Tensor(self):
def test_optimize_select_Tensor(self) -> None:
class SelectTensor(torch.nn.Module):
def forward(self, x, y, z):
x1 = torch.add(x, 2.4, 3.1)
Expand Down Expand Up @@ -519,7 +520,7 @@ def forward(self, x, y, z):

# TODO: Test fails due to memory planning
@unittest.expectedFailure
def test_optimize_cat_with_param(self):
def test_optimize_cat_with_param(self) -> None:
class CatWithPadding(torch.nn.Module):
def __init__(self, padding_shape):
super().__init__()
Expand Down Expand Up @@ -547,7 +548,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_then_slice_on_mutable_buffer(self):
def test_optimize_cat_then_slice_on_mutable_buffer(self) -> None:
class CatWithPadding(torch.nn.Module):
def __init__(self, padding_shape):
super().__init__()
Expand All @@ -572,7 +573,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
self.verify_nop_memory_alloc(graph_module)

def test_optimize_cat_with_view(self):
def test_optimize_cat_with_view(self) -> None:
class CatViewFeasible(torch.nn.Module):
def forward(self, x, y):
x1 = torch.add(x, 2.4, 3.1)
Expand All @@ -599,7 +600,7 @@ def forward(self, x, y):
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
self.verify_nop_memory_alloc(graph_module)

def test_no_optimize_cat_with_repeated_args(self):
def test_no_optimize_cat_with_repeated_args(self) -> None:
class CatViewInfeasible(torch.nn.Module):
def forward(self, x):
x1 = torch.add(x, 2.4, 3.1)
Expand All @@ -623,7 +624,7 @@ def forward(self, x):
self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
self.verify_nop_memory_alloc(graph_module)

def test_no_optimize_cat_with_placeholder(self):
def test_no_optimize_cat_with_placeholder(self) -> None:
class CatViewInfeasible(torch.nn.Module):
def forward(self, x, y):
# Repeat will be decomposed into a cat. The cat cannot be optimized
Expand Down Expand Up @@ -741,7 +742,7 @@ def forward(self, x) -> torch.Tensor:
self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
self.verify_nop_memory_alloc(graph_module)

def test_view_for_unallocated_output(self):
def test_view_for_unallocated_output(self) -> None:
class Model(torch.nn.Module):
def __init__(self, padding_shape):
super().__init__()
Expand All @@ -764,3 +765,40 @@ def forward(self, x, y):
)
self.assertEqual(count_node(graph_module, memory.view), 1)
self.verify_nop_memory_alloc(graph_module)

def test_start_alignment_constraints(self) -> None:
    """Check that planned buffers start at multiples of ``mem_alignment``.

    A deliberately odd alignment (37 bytes) is used so that offsets cannot
    satisfy the constraint by accident, e.g. through power-of-two tensor
    sizes.
    """

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor, y: torch.Tensor):
            add_0 = torch.add(x, y)
            add_1 = torch.add(x, add_0)
            add_2 = torch.add(add_0, add_1)
            add_3 = torch.add(add_1, add_2)
            return add_3

    # Single source of truth for the alignment passed to the compiler and
    # the alignment verified below.
    mem_alignment = 37

    model = Model()
    inputs = (torch.randn(4, 17), torch.randn(4, 17))
    for mem_algo in range(2):
        # subTest reports which memory-planning algorithm failed.
        with self.subTest(mem_algo=mem_algo):
            graph_module = (
                compiler.export_to_executorch_gen_etrecord(
                    model,
                    inputs,
                    opt_level=1,
                    mem_algo=mem_algo,
                    alloc_graph_input=False,
                    alloc_graph_output=False,
                    mem_alignment=mem_alignment,
                )
                .exported_program()
                .graph_module
            )
            # Assert that every planned allocation starts at an address that
            # is a multiple of `mem_alignment` bytes.
            for spec in collect_specs_from_nodes(
                graph_module.graph.nodes,
                ignore_graph_input=True,
                ignore_graph_output=True,
            ):
                # Specs with no offset (None) or offset 0 are skipped; an
                # offset of 0 is trivially aligned.
                if spec and spec.mem_offset:
                    self.assertEqual(spec.mem_offset % mem_alignment, 0)
Loading