|
13 | 13 | # |
14 | 14 | # PyTorch has two basic data primitives: ``DataSet`` and ``DataLoader``. |
15 | 15 | # The `torchvision.datasets` ``DataSet`` object includes a ``transforms`` mechanism to |
16 | | -# modify data in-place. Below is an example of how to load that data from the Pytorch open datasets and transform the data to a normalized tensor. |
17 | | - |
| 16 | +# modify data in-place. Below is an example of how to load that data from the PyTorch open datasets and transform the data to a normalized tensor. |
18 | 17 | # This example is using the `torchvision.datasets` which is a subclass from the primitive `torch.utils.data.Dataset`. Note that the primitive dataset doesnt have the built in transforms param like the built in dataset in `torchvision.datasets.` |
19 | 18 | # |
20 | | -# To see more examples and details of how to work with Tensors, Datasets, DataLoaders and Transforms in Pytoch with this example checkout these resources: |
| 19 | +# To see more examples and details of how to work with Tensors, Datasets, DataLoaders and Transforms in PyTorch with this example checkout these resources: |
21 | 20 | # |
22 | 21 | # - `Tensors <quickstart/tensor_tutorial.html>`_ |
23 | 22 | # - `DataSet and DataLoader <quickstart/data_quickstart_tutorial.html>`_ |
|
29 | 28 | import matplotlib.pyplot as plt |
30 | 29 | from torch.utils.data import DataLoader |
31 | 30 | from torchvision import datasets, transforms |
| 31 | +import torch.nn.functional as F |
32 | 32 |
|
33 | 33 | classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] |
34 | 34 |
|
|
51 | 51 | train_dataloader = DataLoader(training_data, batch_size=batch_size, num_workers=0, pin_memory=True) |
52 | 52 | test_dataloader = DataLoader(test_data, batch_size=batch_size, num_workers=0, pin_memory=True) |
53 | 53 |
|
| 54 | +################################ |
54 | 55 | # Creating Models |
55 | 56 | # --------------- |
56 | 57 | # |
57 | 58 | # There are two ways of creating models: in-line or as a class. This |
58 | | -# quickstart will consider an in-line definition. For more examples checkout `building the model <quickstart/build_model_tutorial.html>`_. |
| 59 | +# quickstart will consider a class definition. For more examples checkout `building the model <quickstart/build_model_tutorial.html>`_. |
59 | 60 |
|
60 | 61 | device = 'cuda' if torch.cuda.is_available() else 'cpu' |
61 | 62 | print('Using {} device'.format(device)) |
62 | 63 |
|
63 | | -# in-line model |
64 | | - |
65 | | -model = nn.Sequential( |
66 | | - nn.Flatten(), |
67 | | - nn.Linear(28*28, 512), |
68 | | - nn.ReLU(), |
69 | | - nn.Linear(512, 512), |
70 | | - nn.ReLU(), |
71 | | - nn.Linear(512, len(classes)), |
72 | | - nn.Softmax(dim=1) |
73 | | - ).to(device) |
| 64 | +# Define model |
| 65 | +class NeuralNetwork(nn.Module): |
| 66 | + def __init__(self): |
| 67 | + super(NeuralNetwork, self).__init__() |
| 68 | + self.flatten = nn.Flatten() |
| 69 | + self.layer1 = nn.Linear(28*28, 512) |
| 70 | + self.layer2 = nn.Linear(512, 512) |
| 71 | + self.output = nn.Linear(512, 10) |
| 72 | + |
| 73 | + def forward(self, x): |
| 74 | + x = self.flatten(x) |
| 75 | + x = F.relu(self.layer1(x)) |
| 76 | + x = F.relu(self.layer2(x)) |
| 77 | + x = self.output(x) |
| 78 | + return F.softmax(x, dim=1) |
| 79 | +model = NeuralNetwork().to(device) |
74 | 80 |
|
75 | 81 | print(model) |
76 | 82 |
|
@@ -193,7 +199,7 @@ def test(dataloader, model): |
193 | 199 | print(f'Predicted: "{predicted}", Actual: "{actual}"') |
194 | 200 |
|
195 | 201 | ################################################################## |
196 | | -# Pytorch Quickstart Topics |
| 202 | +# PyTorch Quickstart Topics |
197 | 203 | # ---------------------------------------- |
198 | 204 | # | `Tensors <quickstart/tensor_tutorial.html>`_ |
199 | 205 | # | `DataSets and DataLoaders <quickstart/data_quickstart_tutorial.html>`_ |
|
0 commit comments