diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index f9abe1c542..b96a0c7ad3 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -264,6 +264,7 @@ def export_to_executorch_gen_etrecord( alloc_graph_output: bool = True, memory_config: Optional[MemoryConfig] = None, dump_graphs: bool = False, + mem_alignment: int = 1, ) -> ExecutorchProgramManager: cadence_passes = get_cadence_passes(opt_level) edge_prog_manager = export_to_edge(model, inputs, dump_graphs) @@ -290,6 +291,7 @@ def export_to_executorch_gen_etrecord( mem_algo=mem_algo, alloc_graph_input=alloc_graph_input, alloc_graph_output=alloc_graph_output, + mem_alignment=mem_alignment, ) # Get executorch program after Cadence specific passes diff --git a/backends/cadence/aot/memory_planning.py b/backends/cadence/aot/memory_planning.py index 787ea4fa51..77ae7eb799 100644 --- a/backends/cadence/aot/memory_planning.py +++ b/backends/cadence/aot/memory_planning.py @@ -9,6 +9,7 @@ import collections import itertools import logging +import math import typing from functools import partial from typing import Iterable, List, Optional, Tuple @@ -39,6 +40,10 @@ def get_size(memory_config: MemoryConfig, exir_id: int) -> int: return memory_config.memory_sizes[exir_id - 1] +def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int: + return int(math.ceil(pre_aligned_offset / alignment) * alignment) + + def collect_specs_from_graph_module( graph_module: torch.fx.GraphModule, alloc_graph_input: bool, @@ -95,9 +100,9 @@ def overlap(spec: TensorSpec) -> Optional[TensorSpec]: return None def memory_available(spec: TensorSpec) -> bool: - return spec.mem_offset + spec.allocated_memory <= get_size( - memory_config, spec.mem_id - ) + return get_aligned_offset( + spec.mem_offset + spec.allocated_memory, alignment + ) <= get_size(memory_config, spec.mem_id) # Iterate over all the specs in sorted order for spec in sorted( @@ -116,7 +121,9 @@ def 
memory_available(spec: TensorSpec) -> bool: continue spec.mem_offset = 0 while memory_available(spec) and (overlapped := overlap(spec)): - spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory + spec.mem_offset = get_aligned_offset( + overlapped.mem_offset + overlapped.allocated_memory, alignment + ) if memory_available(spec): allocated_buffers[spec.mem_id].append(spec) bufsizes[spec.mem_id] = max( @@ -202,13 +209,16 @@ def greedy_by_size_for_offset_calculation_with_hierarchy( # calculation of gap incorrect. Moving it out will make the algorithm degenerate # to the naive one, reusing 0 tensor. The paper may have a typo here. prev_offset = max( - allocated_spec.mem_offset + allocated_spec.allocated_memory, + get_aligned_offset( + allocated_spec.mem_offset + allocated_spec.allocated_memory, + alignment, + ), prev_offset, ) if spec.mem_offset is None: - if prev_offset + spec.allocated_memory > get_size( - memory_config, spec.mem_id - ): + if get_aligned_offset( + prev_offset + spec.allocated_memory, alignment + ) > get_size(memory_config, spec.mem_id): continue else: spec.mem_offset = prev_offset @@ -423,6 +433,7 @@ def __init__( ] ] ] = None, + mem_alignment: int = 1, ) -> None: self._init_mem_algos() @@ -433,6 +444,9 @@ def __init__( self.alloc_graph_output = alloc_graph_output self.additional_constraint_gen_passes = additional_constraint_gen_passes + assert mem_alignment > 0, "mem_alignment must be positive" + self.mem_alignment = mem_alignment + def _init_mem_algos(self) -> None: self.available_mem_algos = [ position_based_greedy_with_hierarchy, @@ -459,6 +473,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult: allow_lifetime_and_storage_overlap=(self.opt_level >= 2), alloc_graph_input=self.alloc_graph_input, alloc_graph_output=self.alloc_graph_output, + alignment=self.mem_alignment, ) mem_planning(graph_module) diff --git a/backends/cadence/aot/tests/test_memory_passes.py b/backends/cadence/aot/tests/test_memory_passes.py 
index 53f173b4ae..2800e7e773 100644 --- a/backends/cadence/aot/tests/test_memory_passes.py +++ b/backends/cadence/aot/tests/test_memory_passes.py @@ -14,11 +14,12 @@ from executorch.backends.cadence.aot.pass_utils import count_node from executorch.exir import memory from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.memory_planning import collect_specs_from_nodes from executorch.exir.tests.models import MultiLayerPerceptron class TestMemPlanningPasses(unittest.TestCase): - def test_calculate_peak_memory_pass(self): + def test_calculate_peak_memory_pass(self) -> None: class PeakMemoryTestModel(torch.nn.Module): def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): super().__init__() @@ -32,7 +33,7 @@ def forward(self, x: torch.Tensor): x = self.linear2(x) return x - def calculate_aligned_num_bytes(num: int, alignment: int = 16): + def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int: return math.ceil(num / alignment) * alignment # model 1 @@ -86,7 +87,7 @@ def calculate_aligned_num_bytes(num: int, alignment: int = 16): ) # Align data on a 16 byte boundary self.assertEqual(peak_usage, expected_peak_usage) - def test_zero_memory_pass(self): + def test_zero_memory_pass(self) -> None: class ZeroMem(torch.nn.Module): def forward(self, x): return x[:, 2::3, ...] 
@@ -188,7 +189,7 @@ def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None: f"{spec=} {arg_spec=}", ) - def verify_nop_memory_alloc(self, graph_module): + def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None: for node in graph_module.graph.find_nodes( op="call_function", target=torch.ops.aten._cat_nop.out ): @@ -204,7 +205,7 @@ def verify_nop_memory_alloc(self, graph_module): ): self._verify_select_nop_memory_alloc(node) - def test_optimize_cat_on_placeholders(self): + def test_optimize_cat_on_placeholders(self) -> None: class Cat(torch.nn.Module): def forward(self, x, y): return torch.ops.aten.cat((x, y)) @@ -228,7 +229,7 @@ def forward(self, x, y): self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) self.verify_nop_memory_alloc(graph_module) - def test_optimize_cat_outermost(self): + def test_optimize_cat_outermost(self) -> None: class OptimizeCatFeasible1(torch.nn.Module): def forward(self, x, y): x1 = torch.add(x, 2.4, 3.1) @@ -255,7 +256,7 @@ def forward(self, x, y): self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) self.verify_nop_memory_alloc(graph_module) - def test_optimize_cat_non_outermost(self): + def test_optimize_cat_non_outermost(self) -> None: class OptimizeCatFeasible2(torch.nn.Module): def forward(self, x, y): x1 = torch.add(x, 2.4, 3.1) @@ -282,7 +283,7 @@ def forward(self, x, y): self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) self.verify_nop_memory_alloc(graph_module) - def test_no_optimize_cat_non_outermost(self): + def test_no_optimize_cat_non_outermost(self) -> None: class OptimizeCatInfeasible1(torch.nn.Module): def forward(self, x, y): x1 = torch.add(x, 2.4, 3.1) @@ -308,7 +309,7 @@ def forward(self, x, y): self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) self.verify_nop_memory_alloc(graph_module) - def test_no_optimize_cat_non_outermost1(self): + def test_no_optimize_cat_non_outermost1(self) -> None: 
class OptimizeCatInfeasible2(torch.nn.Module): def forward(self, x, y): x1 = torch.add(x, 2.4, 3.1) @@ -335,7 +336,7 @@ def forward(self, x, y): self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) self.verify_nop_memory_alloc(graph_module) - def test_optimize_cat_with_slice(self): + def test_optimize_cat_with_slice(self) -> None: class OptimizeCatSliceFeasible(torch.nn.Module): def forward(self, x): x1 = torch.add(x, 2.4, 3.1) @@ -364,7 +365,7 @@ def forward(self, x): self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) self.verify_nop_memory_alloc(graph_module) - def test_optimize_cat_with_slice_infeasible(self): + def test_optimize_cat_with_slice_infeasible(self) -> None: class OptimizeCatSliceInfeasible(torch.nn.Module): def forward(self, x, y): x1 = torch.add(x, 2.4, 3.1) @@ -390,7 +391,7 @@ def forward(self, x, y): self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) self.verify_nop_memory_alloc(graph_module) - def test_optimize_slice_Tensor(self): + def test_optimize_slice_Tensor(self) -> None: class SliceTensor(torch.nn.Module): def forward(self, x, y, z): x1 = torch.add(x, 2.4, 3.1) @@ -452,7 +453,7 @@ def forward(self, x, y, z): ) self.verify_nop_memory_alloc(graph_module) - def test_optimize_select_Tensor(self): + def test_optimize_select_Tensor(self) -> None: class SelectTensor(torch.nn.Module): def forward(self, x, y, z): x1 = torch.add(x, 2.4, 3.1) @@ -519,7 +520,7 @@ def forward(self, x, y, z): # TODO: Test fails due to memory planning @unittest.expectedFailure - def test_optimize_cat_with_param(self): + def test_optimize_cat_with_param(self) -> None: class CatWithPadding(torch.nn.Module): def __init__(self, padding_shape): super().__init__() @@ -547,7 +548,7 @@ def forward(self, x, y): self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1) self.verify_nop_memory_alloc(graph_module) - def test_optimize_cat_then_slice_on_mutable_buffer(self): + def 
test_optimize_cat_then_slice_on_mutable_buffer(self) -> None: class CatWithPadding(torch.nn.Module): def __init__(self, padding_shape): super().__init__() @@ -572,7 +573,7 @@ def forward(self, x, y): self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) self.verify_nop_memory_alloc(graph_module) - def test_optimize_cat_with_view(self): + def test_optimize_cat_with_view(self) -> None: class CatViewFeasible(torch.nn.Module): def forward(self, x, y): x1 = torch.add(x, 2.4, 3.1) @@ -599,7 +600,7 @@ def forward(self, x, y): self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) self.verify_nop_memory_alloc(graph_module) - def test_no_optimize_cat_with_repeated_args(self): + def test_no_optimize_cat_with_repeated_args(self) -> None: class CatViewInfeasible(torch.nn.Module): def forward(self, x): x1 = torch.add(x, 2.4, 3.1) @@ -623,7 +624,7 @@ def forward(self, x): self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0) self.verify_nop_memory_alloc(graph_module) - def test_no_optimize_cat_with_placeholder(self): + def test_no_optimize_cat_with_placeholder(self) -> None: class CatViewInfeasible(torch.nn.Module): def forward(self, x, y): # Repeat will be decomposed into a cat. 
The cat cannot be optimized @@ -741,7 +742,7 @@ def forward(self, x) -> torch.Tensor: self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) self.verify_nop_memory_alloc(graph_module) - def test_view_for_unallocated_output(self): + def test_view_for_unallocated_output(self) -> None: class Model(torch.nn.Module): def __init__(self, padding_shape): super().__init__() @@ -764,3 +765,40 @@ def forward(self, x, y): ) self.assertEqual(count_node(graph_module, memory.view), 1) self.verify_nop_memory_alloc(graph_module) + + def test_start_alignment_constraints(self) -> None: + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + add_0 = torch.add(x, y) + add_1 = torch.add(x, add_0) + add_2 = torch.add(add_0, add_1) + add_3 = torch.add(add_1, add_2) + return add_3 + + model = Model() + inputs = (torch.randn(4, 17), torch.randn(4, 17)) + for mem_algo in range(0, 2): + graph_module = ( + compiler.export_to_executorch_gen_etrecord( + model, + inputs, + opt_level=1, + mem_algo=mem_algo, + alloc_graph_input=False, + alloc_graph_output=False, + mem_alignment=37, + ) + .exported_program() + .graph_module + ) + # Assert that all memory allocations are aligned to 37B start address + for spec in collect_specs_from_nodes( + graph_module.graph.nodes, + ignore_graph_input=True, + ignore_graph_output=True, + ): + if spec and spec.mem_offset: + self.assertEqual(spec.mem_offset % 37, 0)