 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import itertools
 import logging
 import warnings
+from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple
 
 import torch
 from executorch.exir._warnings import deprecated
 from executorch.exir.memory_planning import (
     _is_out_var_node,
     apply_algo,
+    collect_specs_from_nodes,
+    filter_nodes,
     get_node_tensor_specs,
     MemoryPlanningAlgorithmSuite,
     Verifier,
 )
 from executorch.exir.operator.convert import get_out_args_from_opoverload
 from executorch.exir.pass_base import PassBase, PassResult
-from executorch.exir.tensor import ALIGNMENT
+from executorch.exir.tensor import ALIGNMENT, TensorSpec
+from torch import fx
 from torch.export.exported_program import ExportGraphSignature
+from torch.fx import Node
 
 
 # copied from https://stackoverflow.com/questions/75582932/python-how-can-i-print-the-function-name-of-a-partial-function
@@ -37,6 +43,106 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
     return str(any_callable)
 
 
+def _is_buffer(
+    node: Node, graph_signature: ExportGraphSignature
+) -> Tuple[bool, Optional[str]]:
+    """
+    Check if the node is a buffer according to the provided graph signature.
+    If it is, return its fqn as well.
+    """
+    if node.op == "placeholder":
+        if isinstance(node.target, str):
+            if node.target in graph_signature.inputs_to_buffers:
+                fqn = graph_signature.inputs_to_buffers[node.target]
+                return (True, fqn)
+    return (False, None)
+
+
+def _is_mutable_buffer(
+    node: Node, graph_signature: ExportGraphSignature
+) -> Tuple[bool, Optional[str]]:
+    """
+    Check if the node is a mutable buffer according to the provided graph signature.
+    If it is, return its fqn as well.
+    """
+    if node.op == "placeholder":
+        if isinstance(node.target, str):
+            if node.target in graph_signature.inputs_to_buffers:
+                fqn = graph_signature.inputs_to_buffers[node.target]
+                # if the buffer is mutated then record that
+                if fqn in graph_signature.buffers_to_mutate.values():
+                    return True, fqn
+    return False, None
+
+
+def _get_spec_from_node(node: fx.Node) -> TensorSpec:
+    specs = get_node_tensor_specs(node)
+    return specs[0]
+
+
+def _insert_mutable_buffer_specs(
+    state: "_MemoryPlanningState", gm: torch.fx.GraphModule, gs: ExportGraphSignature
+):
+    for node in gm.graph.nodes:
+        is_mutable, fqn = _is_mutable_buffer(node, gs)
+        if is_mutable:
+            assert fqn
+            spec = _get_spec_from_node(node)
+            if (
+                getattr(spec, "mem_id", None) is not None
+                or getattr(spec, "mem_offset", None) is not None
+            ):
+                raise ValueError(
+                    "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
+                )
+            if fqn not in state.mutable_buffers.keys():
+                state.mutable_buffers[fqn] = set()
+            state.mutable_buffers[fqn].add(spec)
+            continue
+        is_buffer, fqn = _is_buffer(node, gs)
+        # If it is not a mutable buffer it might just appear to be a buffer in this entry point
+        # (think model.get_state()), so cache it and later double-check whether this buffer
+        # ever appears mutable.
+        if is_buffer:
+            assert fqn
+            spec = _get_spec_from_node(node)
+            if (
+                getattr(spec, "mem_id", None) is not None
+                or getattr(spec, "mem_offset", None) is not None
+            ):
+                raise ValueError(
+                    "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned"
+                )
+            if fqn not in state.maybe_mutable_buffers.keys():
+                state.maybe_mutable_buffers[fqn] = set()
+            state.maybe_mutable_buffers[fqn].add(spec)
+
+
+def _check_default_mem_ids(gm: torch.fx.GraphModule):
+    for node in gm.graph.nodes:
+        for spec in collect_specs_from_nodes(
+            filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
+            None,
+            ignore_graph_input=False,
+            ignore_const=False,
+            ignore_out_var_node=False,
+            dedup=False,
+            do_assertion=False,
+            ignore_dynamic_unbound_tensor=False,
+        ):
+            mem_id = getattr(spec, "mem_id", None)
+            if mem_id is not None and mem_id != 1:
+                raise ValueError(
+                    "Cannot share mutable buffers if all other tensors are not on the default mem_id of 1"
+                )
+
+
+@dataclass
+class _MemoryPlanningState:
+    mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+    maybe_mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict)
+    graph_modules: List[torch.fx.GraphModule] = field(default_factory=list)
+
+
 class MemoryPlanningPass(PassBase):
     def __init__(
         self,
@@ -45,6 +151,7 @@ def __init__(
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
         alloc_mutable_buffers: bool = True,
+        share_mutable_buffers: bool = False,
         alignment: int = ALIGNMENT,
     ) -> None:
         r"""
@@ -55,12 +162,18 @@ def __init__(
         """
         if memory_planning_algo is None:
             memory_planning_algo = MemoryPlanningAlgorithmSuite()
+        if share_mutable_buffers and not alloc_mutable_buffers:
+            raise ValueError(
+                "share_mutable_buffers is only meaningful when alloc_mutable_buffers is True"
+            )
         self.memory_planning_algo: Callable[..., List[int]] = memory_planning_algo
         self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
         self.alloc_mutable_buffers = alloc_mutable_buffers
+        self.share_mutable_buffers = share_mutable_buffers
         self.alignment = alignment
+        self.state = _MemoryPlanningState()
 
     def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
         """
@@ -134,9 +247,17 @@ def run(
             graph_signature,
             self.alloc_graph_input,
             self.alloc_graph_output,
-            self.alloc_mutable_buffers,
+            # If we are sharing the mutable buffers then do not allocate them in the
+            # memory planning algo; instead collect all of the specs over all the entry
+            # points and allocate them directly in the run_multimethod call.
+            self.alloc_mutable_buffers and not self.share_mutable_buffers,
         )
 
+        if self.share_mutable_buffers and graph_signature is not None:
+            self.state.graph_modules.append(graph_module)
+            _check_default_mem_ids(graph_module)
+            _insert_mutable_buffer_specs(self.state, graph_module, graph_signature)
+
         # TODO: make the verifier do the work recursively to handle
         # control flow
         verifier = Verifier(
@@ -164,3 +285,31 @@ def run(
         # I dont know if that is a valid thing but if it is we should adjust verify_storage_reuse function
         verifier.verify_storage_reuse()
         return PassResult(graph_module, True)
+
+    def run_multimethod(self):
+        "Resolve any memory planning done across entry points"
+        if self.share_mutable_buffers:
+            arena: int = 0
+
+            # Every spec that shares an fqn is the same tensor! So we give it the same id and offset
+            # anywhere it appears.
+            for fqn, specs_set in self.state.mutable_buffers.items():
+                specs = list(specs_set)
+                # If the same buffer appears in mutable and maybe mutable then we know it is in fact mutable.
+                if fqn in self.state.maybe_mutable_buffers.keys():
+                    specs.extend(self.state.maybe_mutable_buffers[fqn])
+                for spec in specs:
+                    # Assume a default memory planning placed all activations on 1, place shared state on 2.
+                    spec.mem_id = 2
+                    spec.realign(self.alignment)
+                    # State is persistent, so the memory never overlaps.
+                    spec.mem_offset = arena
+                # They should all be the same size since they are the same tensor, so just bump off the first.
+                arena += specs[0].allocated_memory
+
+            for graph_module in self.state.graph_modules:
+                if len(graph_module.meta["non_const_buffer_sizes"]) != 2:
+                    raise ValueError(
+                        "Cannot share mutable state if not using default memory ids"
+                    )
+                graph_module.meta["non_const_buffer_sizes"].append(arena)
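
For orientation, a minimal usage sketch of the multimethod flow this diff introduces. It is a sketch, not the actual integration: `forward_ep` and `get_state_ep` are hypothetical placeholders for two already-lowered entry points that share the same mutable buffers, the import path is assumed, and the real wiring inside to_executorch may differ.

    # Hypothetical driver over two entry points that share mutable state (e.g. a KV cache).
    from executorch.exir.passes import MemoryPlanningPass  # import path assumed

    mp_pass = MemoryPlanningPass(
        alloc_mutable_buffers=True,
        share_mutable_buffers=True,
    )

    # Plan each entry point on its own. With share_mutable_buffers=True the mutable
    # buffer specs are not placed by the per-graph algorithm; they are collected
    # into mp_pass.state instead.
    for ep in (forward_ep, get_state_ep):  # placeholder programs, not defined here
        mp_pass.run(ep.graph_module, ep.graph_signature)

    # Resolve planning across entry points: every spec with the same buffer fqn gets
    # mem_id 2 and the same mem_offset, and each graph module gains a third entry in
    # meta["non_const_buffer_sizes"] holding the size of the shared arena.
    mp_pass.run_multimethod()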