Commit 0f56d12

aribornstein authored and subramen committed
Update data_quickstart_tutorial.py (pytorch#34)
Addressed open comments
1 parent dc66596 commit 0f56d12

1 file changed

+73
-62
lines changed
beginner_source/quickstart/data_quickstart_tutorial.py

Lines changed: 73 additions & 62 deletions
@@ -7,20 +7,22 @@
77
# Getting Started With Data in PyTorch
88
# ------------------------------------
99
#
10-
# Before we can even think about building a model with PyTorch, we need to first learn how to load and process data. Data can be sourced from local files, cloud datastores and database queries. It comes in all sorts of forms and formats from structured tables to image, audio, text, video files and more.
10+
# Before we start building models with PyTorch, let's first learn how to load and process data. Data can be sourced from local files, cloud datastores and database queries. It comes in all sorts of forms and formats from structured tables to image, audio, text, video files and more.
1111
#
1212

1313
###############################################################
1414
# .. figure:: /_static/img/quickstart/typesdata.png
1515
# :alt: typesdata
16-
#
17-
# Different data types require different python libraries to load and process such as `openCV <https://opencv.org/>`_ and `PIL <https://pillow.readthedocs.io/en/stable/reference/Image.html>`_ for images, `NLTK <https://www.nltk.org/>`_ and `spaCy <https://spacy.io/>`_ for text and `Librosa <https://librosa.org/doc/latest/index.html>`_ for audio.
18-
#
19-
# If not properly organized, code for processing data samples can quickly get messy and become hard to maintain. Since different model architectures can be applied to many data types, we ideally want our dataset code to be decoupled from our model training code. To this end, PyTorch provides a simple Datasets interface for linking managing collections of data.
20-
#
21-
# A whole set of example datasets such as Fashion MNIST that implement this interface are built into PyTorch extension libraries. These are useful for benchmarking and testing your models before training on your own custom datasets.
22-
#
23-
# You can find some of them below.
16+
#
17+
18+
############################################################
19+
# Different data types require different Python libraries to load and process them, such as `OpenCV <https://opencv.org/>`_ and `PIL <https://pillow.readthedocs.io/en/stable/reference/Image.html>`_ for images, `NLTK <https://www.nltk.org/>`_ and `spaCy <https://spacy.io/>`_ for text, and `Librosa <https://librosa.org/doc/latest/index.html>`_ for audio.
20+
#
21+
# If not properly organized, code for processing data samples can quickly get messy and become hard to maintain. Since different model architectures can be applied to many data types, we ideally want our dataset code to be decoupled from our model training code. To this end, PyTorch provides a simple Dataset interface for managing collections of data.
22+
#
23+
# A whole set of example datasets, such as Fashion MNIST, that implement this interface are built into PyTorch extension libraries. They are subclasses of torch.utils.data.Dataset that have parameters and functions specific to the type of data and the particular dataset. The actual data samples can be downloaded from the internet. These datasets are useful for benchmarking and testing your models before training on your own custom datasets.
24+
#
25+
# You can find some of them below.
2426
#
2527
# - `Image Datasets <https://pytorch.org/docs/stable/torchvision/datasets.html>`_
2628
# - `Text Datasets <https://pytorch.org/text/datasets.html>`_
@@ -30,32 +32,29 @@
3032
#################################################################
3133
# Iterating through a Dataset
3234
# ---------------------------
33-
#
34-
# Once we have a Dataset we can index it manually like a list `clothing[index]`.
35-
#
36-
# Here is an example of how to load the fashion MNIST dataset from torch vision.
37-
#
35+
#
36+
# Once we have a Dataset, we can index it manually like a list: `clothing[index]`.
37+
#
38+
# Here is an example of how to load the `Fashion-MNIST <https://research.zalando.com/welcome/mission/research-projects/fashion-mnist/>`_ dataset from torchvision. Fashion-MNIST is a dataset of Zalando’s article images consisting of 60,000 training examples and 10,000 test examples. Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes. Read more `here <https://pytorch.org/docs/stable/torchvision/datasets.html#fashion-mnist>`_.
39+
# To load the FashionMNIST Dataset we need to provide the following three parameters:
40+
# - root is the path where the train/test data is stored.
41+
# - train selects the split: True loads the training set, False loads the test set.
42+
# - download, when set to True, downloads the data from the internet if it is not already available at root.
3843

39-
from torch.utils.data import DataLoader
40-
from torchvision.io import read_image
41-
from torchvision import transforms, utils
42-
import pandas as pd
43-
import torch
44-
import os
45-
import torch
44+
45+
import torch
4646
from torch.utils.data import Dataset
4747
import torchvision.datasets as datasets
4848
import matplotlib.pyplot as plt
4949
import numpy as np
5050

51-
clothing = datasets.FashionMNIST('data', train=True, download=True)
52-
labels_map = {0: 'T-Shirt', 1: 'Trouser', 2: 'Pullover', 3: 'Dress',
53-
4: 'Coat', 5: 'Sandal', 6: 'Shirt', 7: 'Sneaker', 8: 'Bag', 9: 'Ankle Boot'}
54-
figure = plt.figure(figsize=(8, 8))
51+
clothing = datasets.FashionMNIST(root='data', train=True, download=True)
52+
labels_map = {0 : 'T-Shirt', 1 : 'Trouser', 2 : 'Pullover', 3 : 'Dress', 4 : 'Coat', 5 : 'Sandal', 6 : 'Shirt', 7 : 'Sneaker', 8 : 'Bag', 9 : 'Ankle Boot'}
53+
figure = plt.figure(figsize=(8,8))
5554
cols, rows = 3, 3
56-
for i in range(1, cols*rows + 1):
55+
for i in range(1, cols*rows +1):
5756
    sample_idx = np.random.randint(len(clothing))
58-
    img = clothing[sample_idx][0][0, :, :]
57+
    img = clothing[sample_idx][0]  # a PIL image here, because no transform was passed
5958
    figure.add_subplot(rows, cols, i)
6059
    plt.title(labels_map[clothing[sample_idx][1]])
6160
    plt.axis('off')
@@ -74,6 +73,12 @@
7473
# To work with your own data, let's look at a simple custom image Dataset implementation:
7574
#
7675

76+
import os
77+
import torch
78+
import pandas as pd
79+
from torch.utils.data import Dataset
80+
from torchvision import transforms, utils
81+
from torchvision.io import read_image
7782

7883
class CustomImageDataset(Dataset):
7984
    def __init__(self, annotations_file, img_dir, transform=None):
@@ -88,35 +93,46 @@ def __getitem__(self, idx):
8893
        if torch.is_tensor(idx):
8994
            idx = idx.tolist()
9095

91-
        img_name = os.path.join(self.root_dir,
96+
        img_path = os.path.join(self.root_dir,
9297
                                self.img_labels.iloc[idx, 0])
93-
        image = read_image('path_to_image.jpeg')
98+
        image = read_image(img_path)
9499
        label = self.img_labels.iloc[idx, 1:]
95100
        sample = {'image': image, 'label': label}
96101

97102
        if self.transform:
98103
            sample = self.transform(sample)
99104

100-
        return sample
101-
105+
        return sample
106+
102107
#################################################################
103-
# Imports
104-
# -----------------
105-
#
108+
# Imports
109+
# -------
110+
#
106111
# Import os for file handling, torch for PyTorch, `pandas <https://pandas.pydata.org/>`_ for loading labels, `torchvision <https://pytorch.org/blog/pytorch-1.7-released/>`_ to read image files, and Dataset to implement the Dataset interface.
107-
#
112+
#
108113
# Example:
114+
#
109115

116+
import os
117+
import torch
118+
import pandas as pd
119+
from torchvision.io import read_image
120+
from torch.utils.data import Dataset
121+
from torch.utils.data import DataLoader
110122

111123
#################################################################
112124
# Init
113125
# -----------------
114126
#
115127
# The __init__ function runs once, when the Dataset object is created, and handles all first-time setup. In this case we use it to load our annotation labels into memory and to keep track of the directory that holds our image files. Note that different types of data can take different __init__ inputs: you are not limited to an annotations file, a directory path and transforms, but for images this is standard practice.
116-
#
128+
# A sample csv annotations file may look as follows:
129+
# tshirt1.jpg, 0
130+
# tshirt2.jpg, 0
131+
# ......
132+
# ankleboot999.jpg, 9
133+
#
117134
# Example:
118-
#
119-
135+
#
120136

121137
def __init__(self, annotations_file, img_dir, transform=None):
122138
    self.img_labels = pd.read_csv(annotations_file)
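The sample annotations layout described above (one `filename, label` pair per line, no header row) can be parsed as in the following minimal sketch. It uses only the standard-library `csv` module rather than pandas so that it runs without extra dependencies. The constant `SAMPLE_ANNOTATIONS` and the helper `load_annotations` are hypothetical names introduced for illustration. If you load such a headerless file with `pd.read_csv` instead, note that pandas treats the first row as a header by default, so `header=None` would likely be needed.

```python
import csv
import io

# Hypothetical annotations content in the "filename, label" layout shown above.
SAMPLE_ANNOTATIONS = """\
tshirt1.jpg, 0
tshirt2.jpg, 0
ankleboot999.jpg, 9
"""


def load_annotations(handle):
    """Return a list of (filename, label) pairs from a headerless CSV."""
    # skipinitialspace drops the blank that follows each comma.
    reader = csv.reader(handle, skipinitialspace=True)
    # Skip empty rows; every remaining row is [filename, label].
    return [(row[0], int(row[1])) for row in reader if row]


annotations = load_annotations(io.StringIO(SAMPLE_ANNOTATIONS))
print(len(annotations))
print(annotations[0])
```

In a real Dataset the same function could be given an open file handle instead of the in-memory string.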
@@ -127,11 +143,10 @@ def __init__(self, annotations_file, img_dir, transform=None):
127143
# __len__
128144
# -----------------
129145
#
130-
# The __len__ function is very simple here we just need to return the number of samples in our dataset.
131-
#
146+
# The __len__ function is very simple here: we just need to return the number of samples in our dataset.
147+
#
132148
# Example:
133149

134-
135150
def __len__(self):
136151
    return len(self.img_labels)
137152

@@ -140,46 +155,42 @@ def __len__(self):
140155
# -----------------
141156
#
142157
# The __getitem__ function is the most important function in the Dataset interface. It takes a tensor or an index as input and returns the loaded sample from your dataset at the given index.
143-
#
144-
# In this sample if provided a tensor we convert the tensor to a list containing our index. We then load the file at the given index from our image directory as well as the image label from our pandas annotations DataFrame. This image and label are then wrapped in a single sample dictionary which we can apply a Transform on and return. To learn more about Transforms see the next section of the Blitz.
145-
#
158+
#
159+
# In this sample, if we are given a tensor, we convert it to a list containing our index. We then load the file at the given index from our image directory, as well as the image label from our pandas annotations DataFrame. The image and label are then wrapped in a single sample dictionary, to which we can apply a Transform before returning it. To learn more about Transforms, see the next section of the Blitz.
160+
#
146161
# Example:
147-
162+
#
148163

149164
def __getitem__(self, idx):
150165
    if torch.is_tensor(idx):
151166
        idx = idx.tolist()
152-
    img_name = os.path.join(self.root_dir,
167+
    img_path = os.path.join(self.root_dir,
153168
                            self.img_labels.iloc[idx, 0])
154-
    image = read_image('path_to_image.jpeg')
169+
    image = read_image(img_path)
155170
    label = self.img_labels.iloc[idx, 1:]
156171
    sample = {'image': image, 'label': label}
157172
    if self.transform:
158173
        sample = self.transform(sample)
159-
    return sample
174+
    return sample
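To see `__init__`, `__len__` and `__getitem__` working together without image files on disk, here is a minimal sketch in plain Python. `InMemoryDataset` and `double_pixels` are hypothetical names introduced for illustration. The class deliberately does not subclass `torch.utils.data.Dataset`, so it only illustrates the indexing protocol described above, not the PyTorch class itself.

```python
class InMemoryDataset:
    """Hypothetical stand-in exposing the same three methods as above."""

    def __init__(self, samples, labels, transform=None):
        if len(samples) != len(labels):
            raise ValueError("samples and labels must have the same length")
        self.samples = samples
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.samples)

    def __getitem__(self, idx):
        # Wrap the data and its label in one dictionary, then transform it.
        sample = {'image': self.samples[idx], 'label': self.labels[idx]}
        if self.transform:
            sample = self.transform(sample)
        return sample


def double_pixels(sample):
    """Hypothetical transform: scale every pixel value by two."""
    return {'image': [2 * p for p in sample['image']], 'label': sample['label']}


toy = InMemoryDataset([[0, 1], [2, 3], [4, 5]], [0, 0, 9], transform=double_pixels)
print(len(toy))
print(toy[2])
```

Because the transform receives the whole sample dictionary, it can change the image while passing the label through untouched, which mirrors how `self.transform(sample)` is called in the tutorial code.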
160175

161176
#################################################################
162177
# Preparing your data for training with DataLoaders
163-
# -----------------
164-
#
165-
# Now we have a organized mechansim for managing data which is great, but there is still a lot of manual work we would have to do train a model with our Dataset.
166-
#
167-
# For example we would have to manually maintain the code for:
168-
# * Batching
169-
# * Suffling
170-
# * Parallel batch distribution
171-
#
178+
# -------------------------------------------------
179+
#
180+
# Now we have an organized mechanism for managing data, which is great, but there is still a lot of manual work we would have to do to train a model with our Dataset.
181+
#
182+
# For example we would have to manually maintain the code for:
183+
# * Batching
184+
# * Shuffling
185+
# * Parallel batch distribution
186+
#
172187
# The PyTorch DataLoader, *torch.utils.data.DataLoader*, is an iterable that handles all of this complexity for us, enabling us to load a dataset and focus on training our model.
173188

174-
175189
dataloader = DataLoader(clothing, batch_size=4, shuffle=True, num_workers=0)
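To make concrete what batching and shuffling involve, the following plain-Python sketch does both by hand. It is not how `DataLoader` is implemented: `iter_batches` is a hypothetical helper that returns plain lists, whereas the real `DataLoader` also collates samples into tensors and can load batches in worker processes.

```python
import random


def iter_batches(dataset, batch_size, shuffle=False, seed=None):
    """Yield lists of samples; a simplified, single-process illustration."""
    indices = list(range(len(dataset)))
    if shuffle:
        # Shuffle the indices, not the data, so the dataset stays untouched.
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # The last batch may be smaller than batch_size.
        yield [dataset[i] for i in indices[start:start + batch_size]]


toy_data = list(range(10))
batches = list(iter_batches(toy_data, batch_size=4, shuffle=True, seed=0))
print([len(b) for b in batches])  # → [4, 4, 2]
```

Anything that supports `len()` and integer indexing works here, which is the same contract a map-style Dataset provides to `DataLoader`.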
176190

177191
#################################################################
178192
# With this, we have all we need to know to load and process data of any kind in PyTorch to train deep learning models.
179-
#
193+
#
180194
# Next: Learn more about how to `transform that data for training <transforms_tutorial.html>`_.
181195
#
182-
183-
##################################################################
184196
# .. include:: /beginner_source/quickstart/qs_toc.txt
185-
#
