diff --git a/timm/models/eva.py b/timm/models/eva.py index 552965947b..2ac988ea0c 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -677,7 +677,7 @@ def forward_features(self, x): x, rot_pos_embed = self._pos_embed(x) for blk in self.blocks: if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(blk, x, rope=rot_pos_embed) + x = checkpoint(blk, x, rope=rot_pos_embed, use_reentrant=False) else: x = blk(x, rope=rot_pos_embed) x = self.norm(x)