Skip to content

Commit

Permalink
[BUG] Fixing memory issue encountered while compiling the model sam (#466)
Browse files Browse the repository at this point in the history

Closes #326
  • Loading branch information
BolinSNLHM authored and vadiklyutiy committed Dec 19, 2024
1 parent cb07596 commit 3c8c922
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/hidet/graph/ops/matmul/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,15 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]:
if a.dtype.nbytes > 4 or b.dtype.nbytes > 4:
return None

# If either a or b is a tensor with actual storage (i.e., not symbolic),
# the `flatten` operator calls in the code below will require additional memory,
# and for reasons not yet understood this memory is not released during model compilation.
# This causes the error described in issue #326.

can_imperative = hidet.option.get_imperative()
if a.is_symbolic() or b.is_symbolic():
hidet.option.imperative(False)

if len(a.shape) == 1: # shape: [a]
a = a.unsqueeze([0, 1]) # [1, 1, a]
if len(b.shape) == 2: # shape: [a, b]
Expand Down Expand Up @@ -170,6 +179,8 @@ def resolve_generic(self, op: Operator) -> Optional[List[Tensor]]:
b = flatten(broadcast(b, b_broadcast_shape), start_dim=0, end_dim=-3)
c = self.run_batch_matmul(a, b)
c = c.reshape(c_shape)

hidet.option.imperative(can_imperative)
return [c]

def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]:
Expand Down

0 comments on commit 3c8c922

Please sign in to comment.