Commit fa8db3d
Jacob Sela committed Dec 30, 2024
1 parent: e08a447
Showing 2 changed files with 284 additions and 4 deletions.
docs/source/recipes/torch-dataset-examples/ddp_train.py (249 additions, 0 deletions)
from argparse import ArgumentParser
import os

import fiftyone as fo
from fiftyone.utils.torch import all_gather, local_broadcast_process_authkey

import torch
from tqdm import tqdm
import numpy as np

import utils


def main(local_rank, dataset_name, num_classes, num_epochs, save_dir):

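    # torchrun sets RANK, WORLD_SIZE, LOCAL_RANK, LOCAL_WORLD_SIZE and the
    # rendezvous variables in the environment, so init_process_group() needs
    # no arguments and uses the default env:// initialization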
    torch.distributed.init_process_group()

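    # build one process group per node: new_group() is a collective call that
    # every rank must participate in, so all ranks create all groups and keep
    # only the one for their own node before broadcasting the multiprocessing
    # authkey within the node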
    # setup local groups
    local_group = None
    for n in range(
        int(os.environ["WORLD_SIZE"]) // int(os.environ["LOCAL_WORLD_SIZE"])
    ):
        aux = torch.distributed.new_group()
        torch.distributed.barrier()
        if int(os.environ["RANK"]) // int(os.environ["LOCAL_WORLD_SIZE"]) == n:
            local_group = aux
    local_broadcast_process_authkey(local_group)

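    # wrap the model in DistributedDataParallel so gradients are averaged
    # across all ranks on every backward pass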
    model = utils.setup_ddp_model(num_classes=num_classes)
    model.to(DEVICES[local_rank])
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[DEVICES[local_rank]]
    )

    loss_function = torch.nn.CrossEntropyLoss(reduction="none")

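    # each local rank loads the dataset in turn, with a barrier between turns,
    # so no two processes on the same node open it at the same time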
    dataset = None
    # synchronously load dataset in each trainer
    for r in range(int(os.environ["LOCAL_WORLD_SIZE"])):
        if local_rank == r:
            dataset = fo.load_dataset(dataset_name)
        torch.distributed.barrier(local_group)

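    # create_dataloaders_ddp (a local utils helper) is used below as if it
    # returns a dict of dataloaders keyed by split ("train", "validation",
    # "test"), each sharded across the DDP ranks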
    dataloaders = utils.create_dataloaders_ddp(
        dataset,
        utils.mnist_get_item,
        local_process_group=local_group,
        num_workers=4,
        batch_size=16,
        persistent_workers=True,
    )
    optimizer = utils.setup_optim(ddp_model)

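    # train for num_epochs, checkpointing whenever the cross-rank validation
    # loss improves; only the best epoch's weights are used for testing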
    best_epoch = None
    best_loss = np.inf
    for epoch in range(num_epochs):
        train_epoch(
            local_rank,
            ddp_model,
            dataloaders["train"],
            loss_function,
            optimizer,
        )
        validation_loss = validation(
            local_rank,
            ddp_model,
            dataloaders["validation"],
            dataset,
            loss_function,
        )

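        # all_gather (from fiftyone.utils.torch) collects the value from every
        # rank so each process can compute the same cross-rank mean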
        # average over all trainers
        validation_loss = np.mean(all_gather(validation_loss))

        if validation_loss < best_loss:
            best_loss = validation_loss
            best_epoch = epoch
            if local_rank == 0:
                print(f"New best loss achieved: {best_loss}. Saving model...")
                torch.save(model.state_dict(), f"{save_dir}/epoch_{epoch}.pt")

        torch.distributed.barrier()

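    # training is done; reload the weights from the best epoch and wrap them
    # in DDP again for a final, distributed pass over the test split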
    model = utils.setup_ddp_model(
        num_classes=num_classes,
        weights_path=f"{save_dir}/epoch_{best_epoch}.pt",
    ).to(DEVICES[local_rank])
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[DEVICES[local_rank]]
    )
    test_loss = validation(
        local_rank,
        ddp_model,
        dataloaders["test"],
        dataset,
        loss_function,
        save_results=True,
    )
    test_loss = np.mean(all_gather(test_loss))
    classes = [
        utils.mnist_index_to_label_string(i) for i in range(num_classes)
    ]
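    # evaluation and reporting only need to happen once, so they run on local
    # rank 0; every rank then tears down the process group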
    if local_rank == 0:
        results = dataset.match_tags("test").evaluate_classifications(
            "predictions",
            gt_field="ground_truth",
            eval_key="eval",
            classes=classes,
            k=3,
        )

        print("Final Test Results:")
        print(f"Loss = {test_loss}")
        results.print_report(classes=classes)

    torch.distributed.destroy_process_group(torch.distributed.group.WORLD)


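# one full pass over the training dataloader; only local rank 0 drives a tqdm
# progress bar so output from different ranks doesn't interleave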
def train_epoch(local_rank, model, dataloader, loss_function, optimizer):
    model.train()

    cumulative_loss = 0
    pbar = (
        tqdm(enumerate(dataloader), total=len(dataloader))
        if local_rank == 0
        else enumerate(dataloader)
    )
    for batch_num, batch in pbar:
        batch["image"] = batch["image"].to(DEVICES[local_rank])
        batch["label"] = batch["label"].to(DEVICES[local_rank])

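        # loss_function was built with reduction="none", so it returns one
        # loss per sample; take the mean to get the scalar batch loss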
        prediction = model(batch["image"])
        loss = torch.mean(loss_function(prediction, batch["label"]))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        cumulative_loss = cumulative_loss + loss.detach().cpu().numpy()
        if local_rank == 0 and batch_num % 100 == 0:
            pbar.set_description(
                f"Average Train Loss = {cumulative_loss / (batch_num + 1):10f}"
            )
    return cumulative_loss / (batch_num + 1)


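# validation/test pass over a dataloader; when save_results=True, the
# per-sample losses and model predictions are written back to the dataset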
@torch.no_grad()
def validation(
    local_rank, model, dataloader, dataset, loss_function, save_results=False
):
    model.eval()

    cumulative_loss = 0
    pbar = (
        tqdm(enumerate(dataloader), total=len(dataloader))
        if local_rank == 0
        else enumerate(dataloader)
    )
    for batch_num, batch in pbar:
        batch["image"] = batch["image"].to(DEVICES[local_rank])
        batch["label"] = batch["label"].to(DEVICES[local_rank])

        prediction = model(batch["image"])
        loss_individual = (
            loss_function(prediction, batch["label"]).detach().cpu().numpy()
        )

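        # select this batch's samples by ID and attach their losses and
        # predictions so the results can be explored in the FiftyOne App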
        if save_results:
            samples = dataset.select(batch["id"])
            samples.set_values("loss", loss_individual.tolist())

            fo_predictions = [
                fo.Classification(
                    label=utils.mnist_index_to_label_string(
                        np.argmax(sample_logits)
                    ),
                    logits=sample_logits,
                )
                for sample_logits in prediction.detach().cpu().numpy()
            ]
            samples.set_values("predictions", fo_predictions)
            samples.save()

        cumulative_loss = cumulative_loss + np.mean(loss_individual)
        if local_rank == 0 and batch_num % 100 == 0:
            pbar.set_description(
                f"Average Validation Loss = {cumulative_loss / (batch_num + 1):10f}"
            )
    return cumulative_loss / (batch_num + 1)


if __name__ == "__main__":

    # run with
    # torchrun --nnodes=1 --nproc-per-node=6 \
    #     PATH/TO/YOUR/ddp_train.py -d mnist -n 10 -e 3 \
    #     -s /PATH/TO/SAVE/WEIGHTS --devices 2 3 4 5 6 7

    argparser = ArgumentParser()
    argparser.add_argument(
        "-d", "--dataset", type=str, help="name of fiftyone dataset"
    )
    argparser.add_argument(
        "-n",
        "--num_classes",
        type=int,
        help="number of classes in the dataset",
    )
    argparser.add_argument(
        "-e",
        "--epochs",
        type=int,
        help="number of epochs to train for",
        default=5,
    )
    argparser.add_argument(
        "-s",
        "--save_dir",
        type=str,
        help="directory to save checkpoints to",
        default="~/mnist_weights",
    )
    argparser.add_argument(
        "--devices", default=range(torch.cuda.device_count()), nargs="*"
    )

    args = argparser.parse_args()

    assert int(os.environ["LOCAL_WORLD_SIZE"]) == len(args.devices)

    # expand "~" and make sure the checkpoint directory exists before saving
    save_dir = os.path.expanduser(args.save_dir)
    os.makedirs(save_dir, exist_ok=True)

    DEVICES = [torch.device(f"cuda:{d}") for d in args.devices]

    local_rank = int(os.environ["LOCAL_RANK"])

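    # use the forkserver start method for dataloader workers and preload the
    # heavy imports once, so each worker doesn't re-import torch and fiftyone
    # from scratch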
    torch.multiprocessing.set_start_method("forkserver")
    torch.multiprocessing.set_forkserver_preload(["torch", "fiftyone"])

    main(local_rank, args.dataset, args.num_classes, args.epochs, save_dir)