From 6d2e9a483b3a7a413d3f63725be62058995b52c2 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Sun, 3 Mar 2024 00:32:00 +0400 Subject: [PATCH] Update VHR-10 snippet (#1920) * Update VHR-10 snippet * Remove augs --- README.md | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index b14a3791f3b..215d882095b 100644 --- a/README.md +++ b/README.md @@ -119,12 +119,29 @@ TorchGeo includes a number of [*benchmark datasets*](https://torchgeo.readthedoc If you've used [torchvision](https://pytorch.org/vision) before, these datasets should seem very familiar. In this example, we'll create a dataset for the Northwestern Polytechnical University (NWPU) very-high-resolution ten-class ([VHR-10](https://github.com/chaozhong2010/VHR-10_dataset_coco)) geospatial object detection dataset. This dataset can be automatically downloaded, checksummed, and extracted, just like with torchvision. ```python +from torch.utils.data import DataLoader + +from torchgeo.datamodules.utils import collate_fn_detection +from torchgeo.datasets import VHR10 + +# Initialize the dataset dataset = VHR10(root="...", download=True, checksum=True) -dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4) +# Initialize the dataloader with the custom collate function +dataloader = DataLoader( + dataset, + batch_size=128, + shuffle=True, + num_workers=4, + collate_fn=collate_fn_detection, +) + +# Training loop for batch in dataloader: - image = batch["image"] - label = batch["label"] + images = batch["image"] # list of images + boxes = batch["boxes"] # list of boxes + labels = batch["labels"] # list of labels + masks = batch["masks"] # list of masks # train a model, or make predictions using a pre-trained model ```