Skip to content

Commit e69bec7

Browse files
committed
Implement DCGAN example with distributed support (DDP)
- DDP is recommended instead of DataParallel https://docs.pytorch.org/docs/stable/generated/torch.nn.DataParallel.html - Fix original DCGAN example, `DataParallel` does not suport xpu devices - Improve code readability Signed-off-by: jafraustro <jaime.fraustro.valdez@intel.com>
1 parent a630ec6 commit e69bec7

File tree

7 files changed

+458
-2
lines changed

7 files changed

+458
-2
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ cpp/mnist/build
66
cpp/dcgan/build
77
dcgan/*.png
88
dcgan/*.pth
9+
distributed/dcgan/*.png
10+
distributed/dcgan/*.pth
911
snli/.data
1012
snli/.vector_cache
1113
snli/results
@@ -23,3 +25,4 @@ docs/venv
2325
# development
2426
.vscode
2527
**/.DS_Store
28+

dcgan/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def __init__(self, ngpu):
152152

153153
def forward(self, input):
154154

155-
if (input.is_cuda or input.is_xpu) and self.ngpu > 1:
155+
if (input.is_cuda) and self.ngpu > 1:
156156
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
157157
else:
158158
output = self.main(input)
@@ -192,7 +192,7 @@ def __init__(self, ngpu):
192192
)
193193

194194
def forward(self, input):
195-
if (input.is_cuda or input.is_xpu) and self.ngpu > 1:
195+
if (input.is_cuda) and self.ngpu > 1:
196196
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
197197
else:
198198
output = self.main(input)

distributed/dcgan/README.md

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Deep Convolution Generative Adversarial Networks
2+
3+
This example implements the paper [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](http://arxiv.org/abs/1511.06434)
4+
5+
The implementation is very close to the Torch implementation [dcgan.torch](https://github.com/soumith/dcgan.torch)
6+
7+
After every 100 training iterations, the files `real_samples.png` and `fake_samples.png` are written to disk
8+
with the samples from the generative model.
9+
10+
After every epoch, models are saved to: `netG_epoch_%d.pth` and `netD_epoch_%d.pth`
11+
12+
## Downloading the dataset
13+
14+
You can download the LSUN dataset by cloning [this repo](https://github.com/fyu/lsun) and running
15+
16+
```
17+
python download.py -c bedroom
18+
```
19+
20+
## Installation
21+
22+
```bash
23+
pip install -r requirements.txt
24+
```
25+
26+
## Running Examples
27+
28+
You can run the examples using `torchrun` to launch distributed training:
29+
30+
```bash
31+
torchrun --nnodes=1 --nproc_per_node=4 main.py --dataset fake
32+
```
33+
34+
For more details, check the `run_examples.sh` script.
35+
36+
## Usage
37+
38+
```
39+
usage: main.py [-h] --dataset DATASET [--dataroot DATAROOT] [--workers WORKERS]
40+
[--batchSize BATCHSIZE] [--imageSize IMAGESIZE] [--nz NZ] [--niter NITER]
41+
[--lr LR] [--beta1 BETA1] [--dry-run] [--ngf NGF] [--ndf NDF] [--netG NETG]
42+
[--netD NETD] [--outf OUTF] [--manualSeed MANUALSEED] [--classes CLASSES]
43+
44+
options:
45+
-h, --help show this help message and exit
46+
--dataset DATASET cifar10 | lsun | mnist |imagenet | folder | lfw | fake
47+
--dataroot DATAROOT path to dataset
48+
--workers WORKERS number of data loading workers
49+
--batchSize BATCHSIZE input batch size
50+
--imageSize IMAGESIZE the height / width of the input image to network
51+
--nz NZ size of the latent z vector
52+
--niter NITER number of epochs to train for
53+
--lr LR learning rate, default=0.0002
54+
--beta1 BETA1 beta1 for adam. default=0.5
55+
--dry-run check a single training cycle works
56+
--ngf NGF
57+
--ndf NDF
58+
--netG NETG path to netG (to continue training)
59+
--netD NETD path to netD (to continue training)
60+
--outf OUTF folder to output images and model checkpoints
61+
--manualSeed MANUALSEED manual seed
62+
--classes CLASSES comma separated list of classes for the lsun data set
63+
```

0 commit comments

Comments
 (0)