Skip to content

Commit

Permalink
mnist ddp exmaple
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob Sela committed Dec 30, 2024
1 parent e08a447 commit fa8db3d
Show file tree
Hide file tree
Showing 2 changed files with 284 additions and 4 deletions.
249 changes: 249 additions & 0 deletions docs/source/recipes/torch-dataset-examples/ddp_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
from argparse import ArgumentParser
import os

import fiftyone as fo
from fiftyone.utils.torch import all_gather, local_broadcast_process_authkey

import torch
from tqdm import tqdm
import numpy as np

import utils


def main(local_rank, dataset_name, num_classes, num_epochs, save_dir):

torch.distributed.init_process_group()

# setup local groups
local_group = None
for n in range(
int(
int(os.environ["WORLD_SIZE"]) / int(os.environ["LOCAL_WORLD_SIZE"])
)
):
aux = torch.distributed.new_group()
torch.distributed.barrier()
if int(os.environ["RANK"]) // int(os.environ["LOCAL_WORLD_SIZE"]) == n:
local_group = aux
local_broadcast_process_authkey(local_group)

model = utils.setup_ddp_model(num_classes=num_classes)
model.to(DEVICES[local_rank])
ddp_model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[DEVICES[local_rank]]
)

loss_function = torch.nn.CrossEntropyLoss(reduction="none")

dataset = None
# synchronously load dataset in each trainer
for r in range(int(os.environ["LOCAL_WORLD_SIZE"])):
if local_rank == r:
dataset = fo.load_dataset(dataset_name)
torch.distributed.barrier(local_group)

dataloaders = utils.create_dataloaders_ddp(
dataset,
utils.mnist_get_item,
local_process_group=local_group,
num_workers=4,
batch_size=16,
persistent_workers=True,
)
optimizer = utils.setup_optim(ddp_model)

best_epoch = None
best_loss = np.inf
for epoch in range(num_epochs):
train_epoch(
local_rank,
ddp_model,
dataloaders["train"],
loss_function,
optimizer,
)
validation_loss = validation(
local_rank,
ddp_model,
dataloaders["validation"],
dataset,
loss_function,
)

# average over all trainers
validation_loss = np.mean(all_gather(validation_loss))

if validation_loss < best_loss:
best_loss = validation_loss
best_epoch = epoch
if local_rank == 0:
print(f"New best lost achieved : {best_loss}. Saving model...")
torch.save(model.state_dict(), f"{save_dir}/epoch_{epoch}.pt")

torch.distributed.barrier()

model = utils.setup_ddp_model(
num_classes=num_classes,
weights_path=f"{save_dir}/epoch_{best_epoch}.pt",
).to(DEVICES[local_rank])
model.to(DEVICES[local_rank])
ddp_model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[DEVICES[local_rank]]
)
test_loss = validation(
local_rank,
ddp_model,
dataloaders["test"],
dataset,
loss_function,
save_results=True,
)
test_loss = np.mean(all_gather(test_loss))
classes = [
utils.mnist_index_to_label_string(i) for i in range(num_classes)
]
if local_rank == 0:
results = dataset.match_tags("test").evaluate_classifications(
"predictions",
gt_field="ground_truth",
eval_key="eval",
classes=classes,
k=3,
)

print("Final Test Results:")
print(f"Loss = {test_loss}")
results.print_report(classes=classes)

torch.distributed.destroy_process_group(torch.distributed.group.WORLD)


def train_epoch(local_rank, model, dataloader, loss_function, optimizer):
model.train()

cummulative_loss = 0
pbar = (
tqdm(enumerate(dataloader), total=len(dataloader))
if local_rank == 0
else enumerate(dataloader)
)
for batch_num, batch in pbar:
batch["image"] = batch["image"].to(DEVICES[local_rank])
batch["label"] = batch["label"].to(DEVICES[local_rank])

prediction = model(batch["image"])
loss = torch.mean(loss_function(prediction, batch["label"]))

loss.backward()
optimizer.step()
optimizer.zero_grad()

cummulative_loss = cummulative_loss + loss.detach().cpu().numpy()
if local_rank == 0:
if batch_num % 100 == 0:
pbar.set_description(
f"Average Train Loss = {cummulative_loss / (batch_num + 1):10f}"
)
return cummulative_loss / (batch_num + 1)


@torch.no_grad()
def validation(
local_rank, model, dataloader, dataset, loss_function, save_results=False
):
model.eval()

cummulative_loss = 0
pbar = (
tqdm(enumerate(dataloader), total=len(dataloader))
if local_rank == 0
else enumerate(dataloader)
)
for batch_num, batch in pbar:
with torch.no_grad():
batch["image"] = batch["image"].to(DEVICES[local_rank])
batch["label"] = batch["label"].to(DEVICES[local_rank])

prediction = model(batch["image"])
loss_individual = (
loss_function(prediction, batch["label"])
.detach()
.cpu()
.numpy()
)

if save_results:
samples = dataset._dataset.select(batch["id"])
samples.set_values("loss", loss_individual.tolist())

fo_predictions = [
fo.Classification(
label=utils.mnist_index_to_label_string(
np.argmax(sample_logits)
),
logits=sample_logits,
)
for sample_logits in prediction.detach().cpu().numpy()
]
samples.set_values("predictions", fo_predictions)
samples.save()

cummulative_loss = cummulative_loss + np.mean(loss_individual)
if local_rank == 0:
if batch_num % 100 == 0:
pbar.set_description(
f"Average Validation Loss = {cummulative_loss / (batch_num + 1):10f}"
)
return cummulative_loss / (batch_num + 1)


if __name__ == "__main__":

# run with
# torchrun --nnodes=1 --nproc-per-node=6 \
# PATH/TO/YOUR/ddp_train.py -d mnist -n 10 -e 3 \
# -s /PATH/TO/SAVE/WEIGHTS --devices 2 3 4 5 6 7

argparser = ArgumentParser()
argparser.add_argument(
"-d", "--dataset", type=str, help="name of fiftyone dataset"
)
argparser.add_argument(
"-n",
"--num_classes",
type=int,
help="number of classes in the dataset",
)
argparser.add_argument(
"-e",
"--epochs",
type=int,
help="number of epochs to train for",
default=5,
)
argparser.add_argument(
"-s",
"--save_dir",
type=str,
help="directory to save checkpoints to",
default="~/mnist_weights",
)
argparser.add_argument(
"--devices", default=range(torch.cuda.device_count()), nargs="*"
)

args = argparser.parse_args()

assert int(os.environ["LOCAL_WORLD_SIZE"]) == len(args.devices)

DEVICES = [torch.device(f"cuda:{d}") for d in args.devices]

local_rank = int(os.environ["LOCAL_RANK"])

torch.multiprocessing.set_start_method("forkserver")
torch.multiprocessing.set_forkserver_preload(["torch", "fiftyone"])

main(
local_rank, args.dataset, args.num_classes, args.epochs, args.save_dir
)
39 changes: 35 additions & 4 deletions docs/source/recipes/torch-dataset-examples/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@ def mnist_index_to_label_string(index):


convert_and_normalize = transforms.Compose(
[
transforms.ToTensor(),
]
[transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True)]
)


Expand Down Expand Up @@ -142,7 +140,7 @@ def create_dataloaders(
def setup_model(num_classes, weights_path=None):
model = resnet18(weights=ResNet18_Weights.DEFAULT)
linear_head = torch.nn.Linear(512, num_classes)
torch.nn.init.xavier_uniform(linear_head.weight)
torch.nn.init.xavier_uniform_(linear_head.weight)
model.fc = linear_head
if weights_path is not None:
model.load_state_dict(torch.load(weights_path, weights_only=True))
Expand All @@ -154,6 +152,39 @@ def setup_optim(model, lr=0.01, l2=0.00001):
return optimizer


### DDP utils ###


def setup_ddp_model(**kwargs):
model = setup_model(**kwargs)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
return model


def create_dataloaders_ddp(
dataset, get_item, cache_fields=None, local_process_group=None, **kwargs
):
split_tags = ["train", "validation", "test"]
dataloaders = {}
for split_tag in split_tags:
split = dataset.match_tags(split_tag).to_torch(
get_item,
cache_fields=cache_fields,
local_process_group=local_process_group,
)
shuffle = True if split_tag == "train" else False
dataloader = torch.utils.data.DataLoader(
split,
worker_init_fn=FiftyOneTorchDataset.worker_init,
sampler=torch.utils.data.DistributedSampler(
split, shuffle=shuffle
),
**kwargs,
)
dataloaders[split_tag] = dataloader
return dataloaders


if __name__ == "__main__":
# this is just here to multiprocessing works when we call these functions in a notebook
pass

0 comments on commit fa8db3d

Please sign in to comment.