From 7e218202df8f44cf359265a6cc213d1187762419 Mon Sep 17 00:00:00 2001 From: Yingkai Sha Date: Sun, 29 Dec 2024 08:59:14 -0700 Subject: [PATCH] bugfix model loading for rollout_metrics.py --- applications/rollout_metrics.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/applications/rollout_metrics.py b/applications/rollout_metrics.py index e847a21..9722c71 100644 --- a/applications/rollout_metrics.py +++ b/applications/rollout_metrics.py @@ -200,16 +200,28 @@ def predict(rank, world_size, conf, p): drop_last=False, ) - # load model - model = load_model(conf, load_weights=True).to(device) - - # Warning -- see next line + # flag for distributed inference distributed = conf["predict"]["mode"] in ["ddp", "fsdp"] - if distributed: # A new field needs to be added to predict + # ================================================================================ # + if conf["predict"]["mode"] == "none": + model = load_model(conf, load_weights=True).to(device) + + elif conf["predict"]["mode"] == "ddp": + model = load_model(conf).to(device) + # if conf["trainer"].get("compile", False): + # model = torch.compile(model) + model = distributed_model_wrapper(conf, model, device) + ckpt = os.path.join(save_loc, "checkpoint.pt") + checkpoint = torch.load(ckpt, map_location=device) + load_msg = model.module.load_state_dict(checkpoint["model_state_dict"], strict=False) + load_state_dict_error_handler(load_msg) + + elif conf["predict"]["mode"] == "fsdp": + model = load_model(conf, load_weights=True).to(device) model = distributed_model_wrapper(conf, model, device) - if conf["predict"]["mode"] == "fsdp": - # Load model weights (if any), an optimizer, scheduler, and gradient scaler - model = load_model_state(conf, model, device) + # Load model weights (if any), an optimizer, scheduler, and gradient scaler + model = load_model_state(conf, model, device) + # ================================================================================ # model.eval()