-
Notifications
You must be signed in to change notification settings - Fork 359
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7212d0e
commit 77351cf
Showing
24 changed files
with
1,983 additions
and
651 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# DeepSpeed DCGAN Example | ||
This example is adapted from the | ||
[DCGAN example in the DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples/tree/master/training/gan) | ||
repository. It is intended to demonstrate a simple use case of DeepSpeed with Determined. | ||
|
||
## Files | ||
* **model.py**: The DCGANTrial definition. | ||
* **gan_model.py**: Network definitions for generator and discriminator. | ||
* **data.py**: Dataset loading/downloading code. | ||
|
||
### Configuration Files | ||
* **ds_config.json**: The DeepSpeed config file. | ||
* **mnist.yaml**: Determined config to train the model on mnist on a cluster. | ||
|
||
## Data | ||
This repo supports the same datasets as the original example: `["imagenet", "lfw", "lsun", "cifar10", "mnist", "fake", "celeba"]`. The `cifar10` and `mnist` datasets will be downloaded as needed, and the `fake` dataset is synthetic and does not need a `dataroot`; the rest must be mounted on the agent. For `lsun`, the `data_config.classes` setting must be set to a comma-separated list of class names. The `folder` dataset can be used to load an arbitrary torchvision `ImageFolder` that is mounted on the agent. | ||
|
||
## To Run Locally | ||
|
||
It is recommended to run this from within one of our agent docker images, found at | ||
https://hub.docker.com/r/determinedai/pytorch-ngc/tags | ||
|
||
After installing docker and pulling an image, users can launch a container via | ||
`docker run --gpus=all -v ~/path/to/repo:/src/proj -it <image name>` | ||
|
||
Install necessary dependencies via `pip install determined mpi4py` | ||
|
||
Then, run the following command: | ||
``` | ||
python trainer.py | ||
``` | ||
|
||
Any additional configs can be specified in `mnist.yaml` and `ds_config.json` accordingly. | ||
|
||
## To Run on Cluster | ||
If you have not yet installed Determined, installation instructions can be found | ||
under `docs/install-admin.html` or at https://docs.determined.ai/latest/index.html | ||
|
||
Run the following command: | ||
``` | ||
det experiment create mnist.yaml . | ||
``` | ||
The other configurations can be run by specifying the appropriate configuration file in place | ||
of `mnist.yaml`. | ||
|
||
## Results | ||
Training `mnist` should yield reasonable-looking fake digit images on the images tab in TensorBoard after ~5k steps. | ||
|
||
Training `cifar10` does not converge as convincingly, but should look image-like after ~10k steps. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import contextlib | ||
import os | ||
from typing import cast | ||
|
||
import filelock | ||
import torch | ||
import torchvision.datasets as dset | ||
import torchvision.transforms as transforms | ||
|
||
# Number of image channels for each supported dataset name: "mnist" is the
# only single-channel entry, every other dataset uses three channels.
CHANNELS_BY_DATASET = {
    dataset_name: 1 if dataset_name == "mnist" else 3
    for dataset_name in (
        "imagenet",
        "folder",
        "lfw",
        "lsun",
        "cifar10",
        "mnist",
        "fake",
        "celeba",
    )
}
|
||
|
||
def _build_transform(image_size, channels, center_crop):
    """Compose the preprocessing pipeline applied to every image.

    The steps are Resize -> (optional) CenterCrop -> ToTensor -> Normalize,
    using a mean and std of 0.5 for each of the ``channels`` channels.
    """
    steps = [transforms.Resize(image_size)]
    if center_crop:
        steps.append(transforms.CenterCrop(image_size))
    steps.append(transforms.ToTensor())
    half = (0.5,) * channels
    steps.append(transforms.Normalize(half, half))
    return transforms.Compose(steps)


def get_dataset(data_config: dict) -> torch.utils.data.Dataset:
    """Build the torchvision dataset described by ``data_config``.

    Args:
        data_config: Dataset settings. Keys read here:

            * ``dataset``: dataset name (``imagenet``, ``folder``, ``lfw``,
              ``lsun``, ``cifar10``, ``mnist``, ``fake`` or ``celeba``).
            * ``dataroot``: directory holding (or receiving) the data.
              Required for every dataset except ``fake``.
            * ``image_size``: size passed to ``Resize`` / ``CenterCrop``.
            * ``classes``: comma-separated class names (``lsun`` only).

    Returns:
        The constructed dataset.

    Raises:
        ValueError: If ``dataroot`` is missing and the dataset is not ``fake``.
        Exception: If the dataset name is not recognized.
    """
    dataroot = data_config.get("dataroot", None)
    if dataroot is None:
        # The default must be given to ``.get()``, not to ``str()``:
        # ``str(value, "")`` is the *decoding* form and raises TypeError for
        # any non-bytes value, which used to mask the ValueError below.
        requested = data_config.get("dataset", "")
        if str(requested).lower() != "fake":
            raise ValueError(
                '`dataroot` parameter is required for dataset "%s"' % requested
            )
        # Nothing is read from or written to disk, so no lock is needed.
        context = contextlib.nullcontext()
    else:
        # Ensure that only one local process attempts to download/validate datasets at once.
        context = filelock.FileLock(os.path.join(dataroot, ".lock"))

    with context:
        name = data_config["dataset"]
        if name in ("imagenet", "folder", "lfw", "celeba"):
            # All of these are loaded as a plain image folder.
            dataset = dset.ImageFolder(
                root=data_config["dataroot"],
                transform=_build_transform(
                    data_config["image_size"], channels=3, center_crop=True
                ),
            )
        elif name == "lsun":
            classes = [c + "_train" for c in data_config["classes"].split(",")]
            dataset = dset.LSUN(
                root=data_config["dataroot"],
                classes=classes,
                transform=_build_transform(
                    data_config["image_size"], channels=3, center_crop=True
                ),
            )
        elif name == "cifar10":
            dataset = dset.CIFAR10(
                root=data_config["dataroot"],
                download=True,
                transform=_build_transform(
                    data_config["image_size"], channels=3, center_crop=False
                ),
            )
        elif name == "mnist":
            dataset = dset.MNIST(
                root=data_config["dataroot"],
                download=True,
                transform=_build_transform(
                    data_config["image_size"], channels=1, center_crop=False
                ),
            )
        elif name == "fake":
            dataset = dset.FakeData(
                image_size=(3, data_config["image_size"], data_config["image_size"]),
                transform=transforms.ToTensor(),
            )
        else:
            raise Exception(f"Unknown dataset {name}")
    return cast(torch.utils.data.Dataset, dataset)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{ | ||
"train_batch_size": 64, | ||
"optimizer": { | ||
"type": "Adam", | ||
"params": { | ||
"lr": 0.0002, | ||
"betas": [ | ||
0.5, | ||
0.999 | ||
], | ||
"eps": 1e-8 | ||
} | ||
}, | ||
"steps_per_print": 10 | ||
} |
Oops, something went wrong.