From c6959747cdbcb2f056a7229b2aa3c5de3476d1e6 Mon Sep 17 00:00:00 2001 From: Bolin Sun Date: Fri, 20 Sep 2024 16:43:27 -0400 Subject: [PATCH] [BUG] Fixing memory issue encountered while compiling the model `sam` (#466) Closes #326 --- python/hidet/graph/ops/matmul/resolve.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/hidet/graph/ops/matmul/resolve.py b/python/hidet/graph/ops/matmul/resolve.py index 3b878462b..a724237f9 100644 --- a/python/hidet/graph/ops/matmul/resolve.py +++ b/python/hidet/graph/ops/matmul/resolve.py @@ -133,6 +133,15 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]: if a.dtype.nbytes > 4 or b.dtype.nbytes > 4: return None + # If either a or b is a tensor with actual storage (i.e., not symbolic), + # the operator `flatten` calls in the code below will + # require additional memory, and these allocated spaces are somehow not released during model compilation. + # This causes the error described in issue #326. + + can_imperative = hidet.option.get_imperative() + if a.is_symbolic() or b.is_symbolic(): + hidet.option.imperative(False) + if len(a.shape) == 1: # shape: [a] a = a.unsqueeze([0, 1]) # [1, 1, a] if len(b.shape) == 2: # shape: [a, b] @@ -170,6 +179,8 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]: b = flatten(broadcast(b, b_broadcast_shape), start_dim=0, end_dim=-3) c = self.run_batch_matmul(a, b) c = c.reshape(c_shape) + + hidet.option.imperative(can_imperative) return [c] def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]: