About Loading Weights to text Encoder #44

Open
sushil579 opened this issue Aug 1, 2024 · 1 comment

@sushil579
Hey,

I might be confused here:

You load the vision model in build.vit():

self.vision_encoder = self.build_vision_encoder()  # basically here

and then initialize the text encoder:

self.text_encoder = self.build_text_encoder()

so this basically loads the pretrained BERT.

My question is: shouldn't you load the text encoder weights from the model_path (like b16_17m.pth)?
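
(For context, a minimal sketch of the kind of text-encoder initialization described above, assuming a Hugging Face transformers-style BERT; the names here are illustrative, not the actual repository code:)

from transformers import BertModel

def build_text_encoder():
    # illustrative only: the text encoder starts from generic pretrained BERT weights,
    # not from the multimodal checkpoint given as model_path (e.g. b16_17m.pth)
    return BertModel.from_pretrained("bert-base-uncased")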

@Andy1621
Collaborator

Andy1621 commented Aug 2, 2024

Hi! Both the vision and text encoders' weights will be updated in tasks/shared_utils.py:

# auto resume the latest checkpoint
if config.get("auto_resume", False):
    logger.info("Auto resuming")
    model_latest = join(config.output_dir, "ckpt_latest.pth")
    model_best = join(config.output_dir, "ckpt_best.pth")
    large_num = -1
    for p in os.listdir(config.output_dir):
        if 'ckpt' in p:
            num = p.split('_')[1].split('.')[0]
            if str.isnumeric(num):
                if int(num) > large_num:
                    large_num = int(num)
    if large_num != -1:
        model_latest = join(config.output_dir, f"ckpt_{large_num:02d}.pth")
    if osp.isfile(model_latest):
        config.pretrained_path = model_latest
        config.resume = True
    elif osp.isfile(model_best):
        config.pretrained_path = model_best
        config.resume = True
    else:
        logger.info(f"Not found checkpoint in {config.output_dir}")

if osp.isfile(config.pretrained_path):
    checkpoint = torch.load(config.pretrained_path, map_location="cpu")
    if 'model' in checkpoint.keys():
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint
    if config.resume:
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        scaler.load_state_dict(checkpoint["scaler"])
        start_epoch = checkpoint["epoch"] + 1
        global_step = checkpoint["global_step"]
    elif not pretrain:  # downstream init from pretrained ckpt
        # interpolate positional embeddings.
        if "vit" in config.model.vision_encoder.name:
            pass
        else:
            raise ValueError(
                f" vision encoder: {config.model.vision_encoder.name} not implemented"
            )
        if not config.evaluate or config.get("zero_shot", False):  # finetuning from pretrained weights
            for key in list(state_dict.keys()):
                if "bert" in key:
                    encoder_key = key.replace("bert.", "")
                    state_dict[encoder_key] = state_dict[key]
                    if not has_decoder:
                        del state_dict[key]

                # init text decoder as multimodal encoder (last 6 layers of model.text_encoder)
                # only for generation tasks like VQA
                if has_decoder and "text_encoder" in key:
                    if "layer" in key:
                        encoder_keys = key.split(".")
                        layer_num = int(encoder_keys[4])
                        if layer_num < config.model.text_encoder.fusion_layer:
                            del state_dict[key]
                            continue
                        else:
                            decoder_layer_num = layer_num - config.model.text_encoder.fusion_layer
                            encoder_keys[4] = str(decoder_layer_num)
                            encoder_key = ".".join(encoder_keys)
                    else:
                        encoder_key = key
                    decoder_key = encoder_key.replace("text_encoder", "text_decoder")
                    state_dict[decoder_key] = state_dict[key]
                    del state_dict[key]

    msg = model_without_ddp.load_state_dict(state_dict, strict=False)
    logger.info(msg)
    logger.info(f"Loaded checkpoint from {config.pretrained_path}")
else:
    logger.warning("No pretrained checkpoint provided, training from scratch")
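
To make it concrete: the loop above renames the BERT keys in the checkpoint so they match the text encoder's parameter names, which is how the text encoder also gets its weights from config.pretrained_path (e.g. b16_17m.pth). A minimal standalone sketch of that renaming, with made-up key names:

# sketch of the "bert." prefix stripping above; the key names are hypothetical
state_dict = {
    "text_encoder.bert.embeddings.word_embeddings.weight": "...tensor...",
    "vision_encoder.patch_embed.proj.weight": "...tensor...",
}
has_decoder = False

for key in list(state_dict.keys()):
    if "bert" in key:
        encoder_key = key.replace("bert.", "")  # drop the "bert." prefix
        state_dict[encoder_key] = state_dict[key]
        if not has_decoder:
            del state_dict[key]

print(sorted(state_dict.keys()))
# ['text_encoder.embeddings.word_embeddings.weight', 'vision_encoder.patch_embed.proj.weight']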
