bugfix model loading for rollout_metrics.py
yingkaisha committed Dec 29, 2024
1 parent 4a36d1d · commit 7e21820
Showing 1 changed file with 20 additions and 8 deletions.
applications/rollout_metrics.py: 28 changes (20 additions & 8 deletions)
@@ -200,16 +200,28 @@ def predict(rank, world_size, conf, p):
         drop_last=False,
     )
 
-    # load model
-    model = load_model(conf, load_weights=True).to(device)
-
-    # Warning -- see next line
-    # flag for distributed inference
-    distributed = conf["predict"]["mode"] in ["ddp", "fsdp"]
-    if distributed:  # A new field needs to be added to predict
+    # ================================================================================ #
+    if conf["predict"]["mode"] == "none":
+        model = load_model(conf, load_weights=True).to(device)
+
+    elif conf["predict"]["mode"] == "ddp":
+        model = load_model(conf).to(device)
+        # if conf["trainer"].get("compile", False):
+        #     model = torch.compile(model)
+        model = distributed_model_wrapper(conf, model, device)
+        ckpt = os.path.join(save_loc, "checkpoint.pt")
+        checkpoint = torch.load(ckpt, map_location=device)
+        load_msg = model.module.load_state_dict(checkpoint["model_state_dict"], strict=False)
+        load_state_dict_error_handler(load_msg)
+
+    elif conf["predict"]["mode"] == "fsdp":
+        model = load_model(conf, load_weights=True).to(device)
         model = distributed_model_wrapper(conf, model, device)
-        if conf["predict"]["mode"] == "fsdp":
-            # Load model weights (if any), an optimizer, scheduler, and gradient scaler
-            model = load_model_state(conf, model, device)
+        # Load model weights (if any), an optimizer, scheduler, and gradient scaler
+        model = load_model_state(conf, model, device)
+    # ================================================================================ #
 
     model.eval()

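Note on the change: the new code dispatches on conf["predict"]["mode"], so the predict section of the config now needs an explicit mode entry set to "none", "ddp", or "fsdp" (the old code only checked for the two distributed modes). A minimal sketch of that requirement follows; the conf dict is a hypothetical excerpt for illustration, not the repository's actual config schema, and only the "mode" key is implied by the diff above.

# Hypothetical excerpt of the inference config: only "mode" is implied by the
# diff above; any other keys under "predict" are omitted here.
conf = {
    "predict": {
        "mode": "none",  # "none" = plain load; "ddp"/"fsdp" = wrap model, then load checkpoint/state
    },
}

# The loading branch above assumes one of these three values.
mode = conf["predict"]["mode"]
if mode not in ("none", "ddp", "fsdp"):
    raise ValueError(f"unsupported predict mode: {mode!r}")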
