Skip to content

Commit 5b406f3

Browse files
authored
Merge pull request #78 from tensorlayer/add_jittor
Add jittor
2 parents 50160c9 + 80f3f6c commit 5b406f3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+8831
-277
lines changed

examples/basic_tutorials/cifar10_cnn.py

Lines changed: 260 additions & 151 deletions
Large diffs are not rendered by default.

examples/basic_tutorials/cifar10_cnn_dist.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
# -*- coding: utf-8 -*-
33

44
import os
5-
os.environ['TL_BACKEND'] = 'paddle'
5+
# os.environ['TL_BACKEND'] = 'paddle'
6+
os.environ['TL_BACKEND'] = 'jittor'
67
# os.environ['TL_BACKEND'] = 'tensorflow'
78
# os.environ['TL_BACKEND'] = 'mindspore'
89
# os.environ['TL_BACKEND'] = 'torch'

examples/basic_tutorials/cifar10_cnn_train.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@
44
# TensorlayerX目前支持包括TensorFlow、Pytorch、PaddlePaddle、MindSpore作为计算后端,指定计算后端的方法也非常简单,只需要设置环境变量即可
55

66
import os
7-
os.environ['TL_BACKEND'] = 'paddle'
7+
# os.environ['TL_BACKEND'] = 'paddle'
8+
9+
os.environ['TL_BACKEND'] = 'jittor'
810
# os.environ['TL_BACKEND'] = 'tensorflow'
911
# os.environ['TL_BACKEND'] = 'mindspore'
1012
# os.environ['TL_BACKEND'] = 'torch'
1113

14+
15+
1216
import tensorlayerx as tlx
1317
from tensorlayerx.nn import Module
1418
from tensorlayerx.nn import (Conv2d, Linear, Flatten, MaxPool2d, BatchNorm2d)
@@ -54,6 +58,7 @@ def forward(self, x):
5458
z = self.linear1(z)
5559
z = self.linear2(z)
5660
return z
61+
5762

5863

5964
# get the network
@@ -70,7 +75,7 @@ def forward(self, x):
7075

7176
# 定义损失函数、优化器等
7277
loss_fn=tlx.losses.softmax_cross_entropy_with_logits
73-
optimizer = tlx.optimizers.Adam(learning_rate)
78+
optimizer = tlx.optimizers.Adam(lr=learning_rate)
7479
metrics = tlx.metrics.Accuracy()
7580

7681

examples/basic_tutorials/gradient_clip_mixed_tensorflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
# -*- coding: utf-8 -*-
33
# The tensorlayerx and tensorflow operators can be mixed
44
import os
5-
os.environ['TL_BACKEND'] = 'tensorflow'
5+
# os.environ['TL_BACKEND'] = 'tensorflow'
66
# os.environ['TL_BACKEND'] = 'paddle'
77
# os.environ['TL_BACKEND'] = 'torch'
8+
os.environ['TL_BACKEND'] = 'jittor'
89

910

1011
import time

0 commit comments

Comments
 (0)