Skip to content

Commit e375282

Browse files
committed
Add support for Intel GPU to MNIST examples
* Add support for Intel GPU to MNIST example * Add support for Intel GPU to MNIST Forward-Forward example * Add support for Intel GPU to MNIST using RNN example and update README with optional arguments * Refactor argument parsing in MNIST examples. There is no need to use `default=False` with `store_true` Signed-off-by: jafraustro <jaime.fraustro.valdez@intel.com>
1 parent 5dfeb46 commit e375282

File tree

5 files changed

+42
-11
lines changed

5 files changed

+42
-11
lines changed

mnist/main.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -82,28 +82,33 @@ def main():
8282
help='learning rate (default: 1.0)')
8383
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
8484
help='Learning rate step gamma (default: 0.7)')
85-
parser.add_argument('--no-cuda', action='store_true', default=False,
85+
parser.add_argument('--no-cuda', action='store_true',
8686
help='disables CUDA training')
87-
parser.add_argument('--no-mps', action='store_true', default=False,
87+
parser.add_argument('--no-mps', action='store_true',
8888
help='disables macOS GPU training')
89-
parser.add_argument('--dry-run', action='store_true', default=False,
89+
parser.add_argument('--no-xpu', action='store_true',
90+
help='disables Intel GPU training')
91+
parser.add_argument('--dry-run', action='store_true',
9092
help='quickly check a single pass')
9193
parser.add_argument('--seed', type=int, default=1, metavar='S',
9294
help='random seed (default: 1)')
9395
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
9496
help='how many batches to wait before logging training status')
95-
parser.add_argument('--save-model', action='store_true', default=False,
97+
parser.add_argument('--save-model', action='store_true',
9698
help='For Saving the current Model')
9799
args = parser.parse_args()
98100
use_cuda = not args.no_cuda and torch.cuda.is_available()
99101
use_mps = not args.no_mps and torch.backends.mps.is_available()
102+
use_xpu = not args.no_mps and torch.xpu.is_available()
100103

101104
torch.manual_seed(args.seed)
102105

103106
if use_cuda:
104107
device = torch.device("cuda")
105108
elif use_mps:
106109
device = torch.device("mps")
110+
elif use_xpu:
111+
device = torch.device("xpu")
107112
else:
108113
device = torch.device("cpu")
109114

mnist_forward_forward/README.md

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ optional arguments:
1818
--lr LR learning rate (default: 0.03)
1919
--no_cuda disables CUDA training
2020
--no_mps disables MPS training
21+
--no_xpu disables XPU training
2122
--seed SEED random seed (default: 1)
2223
--save_model For saving the current Model
2324
--train_size TRAIN_SIZE

mnist_forward_forward/main.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,20 @@ def train(self, x_pos, x_neg):
102102
help="learning rate (default: 0.03)",
103103
)
104104
parser.add_argument(
105-
"--no_cuda", action="store_true", default=False, help="disables CUDA training"
105+
"--no_cuda", action="store_true", help="disables CUDA training"
106106
)
107107
parser.add_argument(
108-
"--no_mps", action="store_true", default=False, help="disables MPS training"
108+
"--no_mps", action="store_true", help="disables MPS training"
109+
)
110+
parser.add_argument(
111+
"--no_xpu", action="store_true", help="disables XPU training"
109112
)
110113
parser.add_argument(
111114
"--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
112115
)
113116
parser.add_argument(
114117
"--save_model",
115118
action="store_true",
116-
default=False,
117119
help="For saving the current Model",
118120
)
119121
parser.add_argument(
@@ -126,7 +128,6 @@ def train(self, x_pos, x_neg):
126128
parser.add_argument(
127129
"--save-model",
128130
action="store_true",
129-
default=False,
130131
help="For Saving the current Model",
131132
)
132133
parser.add_argument(
@@ -139,10 +140,13 @@ def train(self, x_pos, x_neg):
139140
args = parser.parse_args()
140141
use_cuda = not args.no_cuda and torch.cuda.is_available()
141142
use_mps = not args.no_mps and torch.backends.mps.is_available()
143+
use_xpu = not args.no_xpu and torch.xpu.is_available()
142144
if use_cuda:
143145
device = torch.device("cuda")
144146
elif use_mps:
145147
device = torch.device("mps")
148+
elif use_xpu:
149+
device = torch.device("xpu")
146150
else:
147151
device = torch.device("cpu")
148152

mnist_rnn/README.md

+17
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,20 @@ pip install -r requirements.txt
88
python main.py
99
# CUDA_VISIBLE_DEVICES=2 python main.py # to specify GPU id to ex. 2
1010
```
11+
12+
```bash
13+
optional arguments:
14+
-h, --help show this help message and exit
15+
--batch_size input batch_size for training (default:64)
16+
--testing_batch_size input batch size for testing (default: 1000)
17+
--epochs EPOCHS number of epochs to train (default: 14)
18+
--lr LR learning rate (default: 0.1)
19+
--gamma learning rate step gamma (default: 0.7)
20+
--cuda enables CUDA training
21+
--xpu enables XPU training
22+
--mps enables macos GPU training
23+
--seed SEED random seed (default: 1)
24+
--save_model For saving the current Model
25+
--log_interval how many batches to wait before logging training status
26+
--dry-run quickly check a single pass
27+
```

mnist_rnn/main.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -93,22 +93,26 @@ def main():
9393
help='learning rate step gamma (default: 0.7)')
9494
parser.add_argument('--cuda', action='store_true', default=False,
9595
help='enables CUDA training')
96-
parser.add_argument('--mps', action="store_true", default=False,
96+
parser.add_argument('--mps', action="store_true",
9797
help="enables MPS training")
98-
parser.add_argument('--dry-run', action='store_true', default=False,
98+
parser.add_argument('--xpu', action='store_true',
99+
help='enables XPU training')
100+
parser.add_argument('--dry-run', action='store_true',
99101
help='quickly check a single pass')
100102
parser.add_argument('--seed', type=int, default=1, metavar='S',
101103
help='random seed (default: 1)')
102104
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
103105
help='how many batches to wait before logging training status')
104-
parser.add_argument('--save-model', action='store_true', default=False,
106+
parser.add_argument('--save-model', action='store_true',
105107
help='for Saving the current Model')
106108
args = parser.parse_args()
107109

108110
if args.cuda and not args.mps:
109111
device = "cuda"
110112
elif args.mps and not args.cuda:
111113
device = "mps"
114+
elif args.xpu:
115+
device = "xpu"
112116
else:
113117
device = "cpu"
114118

0 commit comments

Comments
 (0)