From 9152efd90734fa86323c0767cfbbe83e85319d3c Mon Sep 17 00:00:00 2001
From: Evelina <10428420+ekmb@users.noreply.github.com>
Date: Wed, 19 Jul 2023 12:40:02 -0700
Subject: [PATCH] fix pos id - hf update (#7075)

* fix pos id - hf update

Signed-off-by: Evelina

* add missing import

Signed-off-by: Evelina

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Evelina
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 nemo/collections/nlp/models/nlp_model.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py
index 032a7449c27e4..d739efa884851 100644
--- a/nemo/collections/nlp/models/nlp_model.py
+++ b/nemo/collections/nlp/models/nlp_model.py
@@ -16,7 +16,7 @@
 import hashlib
 import json
 import os
-from typing import Any, Optional
+from typing import Any, Mapping, Optional
 
 from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning import Trainer
@@ -385,3 +385,13 @@ def load_from_checkpoint(
         finally:
             cls._set_model_restore_state(is_being_restored=False)
         return checkpoint
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        # starting with transformers v4.31.0, buffer for position_ids is persistent=False
+        if (
+            self.bert_model is not None
+            and "position_ids" not in self.bert_model.embeddings._modules
+            and "bert_model.embeddings.position_ids" in state_dict
+        ):
+            del state_dict["bert_model.embeddings.position_ids"]
+        super(NLPModel, self).load_state_dict(state_dict, strict=strict)
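
--
Reviewer note (not part of the patch): a minimal, self-contained sketch of the failure
mode this change guards against. The Embeddings class, key names, and sizes below are
illustrative stand-ins, not NeMo's actual classes. It demonstrates that once a buffer is
registered with persistent=False (as transformers >= 4.31.0 does for position_ids), a
checkpoint written by an older version carries an extra key that makes a strict load
fail until the stale key is dropped:

    from collections import OrderedDict

    import torch
    import torch.nn as nn


    class Embeddings(nn.Module):
        """Stand-in for a HF BERT embeddings module (illustrative only)."""

        def __init__(self, max_len: int = 16, vocab: int = 100, dim: int = 8):
            super().__init__()
            self.word_embeddings = nn.Embedding(vocab, dim)
            # transformers >= 4.31.0 registers this buffer with persistent=False,
            # so it no longer appears in state_dict()
            self.register_buffer(
                "position_ids", torch.arange(max_len).unsqueeze(0), persistent=False
            )


    model = Embeddings()

    # Simulate a checkpoint written with transformers < 4.31.0, where the
    # buffer was persistent and therefore saved alongside the weights.
    old_ckpt = OrderedDict(model.state_dict())
    old_ckpt["position_ids"] = torch.arange(16).unsqueeze(0)

    try:
        model.load_state_dict(old_ckpt, strict=True)
    except RuntimeError as err:
        print(err)  # Unexpected key(s) in state_dict: "position_ids"

    # The remedy used by the patch: drop the stale key when the live module no
    # longer persists the buffer, then load normally.
    if "position_ids" in old_ckpt and "position_ids" not in model.state_dict():
        del old_ckpt["position_ids"]
    model.load_state_dict(old_ckpt, strict=True)  # succeeds

The patch applies the same idea against the live model: it deletes
"bert_model.embeddings.position_ids" from the incoming state_dict when the current
bert_model no longer carries that entry, then delegates to the parent class's
load_state_dict.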