3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -75,12 +75,14 @@ To release a new version, please update the changelog as followed:

- `SpatialTransform2dAffine` auto `in_channels`
- support TensorFlow 2.0.0-beta1
- Update model weights property: it now returns a copy (PR #1010)

### Dependencies Update

### Deprecated

### Fixed
- Fix `tl.models.Model._construct_graph` for list of outputs, e.g. STN case (PR #1010)

### Removed

@@ -89,6 +91,7 @@ To release a new version, please update the changelog as followed:
### Contributors

- @zsdonghao
- @ChrisWu1997: #1010

## [2.1.0]

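The weights-copy entry above is the behavioural change of this PR: `trainable_weights`, `nontrainable_weights` and `all_weights` now hand back a copy of the internal list (see `tensorlayer/models/core.py` further down), so callers can no longer mutate the model's state by accident. A minimal sketch, assuming an illustrative two-layer model (the layer sizes are not from the PR):

```python
import numpy as np
import tensorflow as tf
import tensorlayer as tl

# Illustrative model; only the weights-property behaviour below is the point.
ni = tl.layers.Input([None, 784], name='input')
nn = tl.layers.Dense(n_units=64, act=tf.nn.relu, name='dense1')(ni)
nn = tl.layers.Dense(n_units=10, name='dense2')(nn)
model = tl.models.Model(inputs=ni, outputs=nn)

weights = model.trainable_weights            # now returns a copy of the internal list
weights.append(np.arange(5))                 # mutating the returned list ...
assert len(model.trainable_weights) == len(weights) - 1   # ... does not touch the model
```

This is the same behaviour the new `test_model_weights_copy` test at the bottom of the diff checks against a basic static model.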
3 changes: 1 addition & 2 deletions examples/database/dispatch_tasks.py
@@ -46,6 +46,5 @@

# get the best model
print("all tasks finished")
sess = tf.InteractiveSession()
net = db.find_top_model(sess=sess, model_name='mlp', sort=[("test_accuracy", -1)])
net = db.find_top_model(model_name='mlp', sort=[("test_accuracy", -1)])
print("the best accuracy {} is from model {}".format(net._test_accuracy, net._name))
76 changes: 32 additions & 44 deletions examples/database/task_script.py
@@ -3,69 +3,57 @@
import tensorflow as tf
import tensorlayer as tl

tf.logging.set_verbosity(tf.logging.DEBUG)
# tf.logging.set_verbosity(tf.logging.DEBUG)
tl.logging.set_verbosity(tl.logging.DEBUG)

sess = tf.InteractiveSession()

# connect to database
db = tl.db.TensorHub(ip='localhost', port=27017, dbname='temp', project_name='tutorial')

# load dataset from database
X_train, y_train, X_val, y_val, X_test, y_test = db.find_top_dataset('mnist')

# define placeholder
x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
y_ = tf.placeholder(tf.int64, shape=[None], name='y_')


# define the network
def mlp(x, is_train=True, reuse=False):
with tf.variable_scope("MLP", reuse=reuse):
net = tl.layers.InputLayer(x, name='input')
net = tl.layers.DropoutLayer(net, keep=0.8, is_fix=True, is_train=is_train, name='drop1')
net = tl.layers.DenseLayer(net, n_units=n_units1, act=tf.nn.relu, name='relu1')
net = tl.layers.DropoutLayer(net, keep=0.5, is_fix=True, is_train=is_train, name='drop2')
net = tl.layers.DenseLayer(net, n_units=n_units2, act=tf.nn.relu, name='relu2')
net = tl.layers.DropoutLayer(net, keep=0.5, is_fix=True, is_train=is_train, name='drop3')
net = tl.layers.DenseLayer(net, n_units=10, act=None, name='output')
return net


# define inferences
net_train = mlp(x, is_train=True, reuse=False)
net_test = mlp(x, is_train=False, reuse=True)

# cost for training
y = net_train.outputs
cost = tl.cost.cross_entropy(y, y_, name='xentropy')
correct_prediction = tf.equal(tf.argmax(y, 1), y_)
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# cost and accuracy for evalution
y2 = net_test.outputs
cost_test = tl.cost.cross_entropy(y2, y_, name='xentropy2')
correct_prediction = tf.equal(tf.argmax(y2, 1), y_)
acc_test = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
def mlp():
ni = tl.layers.Input([None, 784], name='input')
net = tl.layers.Dropout(keep=0.8, name='drop1')(ni)
net = tl.layers.Dense(n_units=n_units1, act=tf.nn.relu, name='relu1')(net)
net = tl.layers.Dropout(keep=0.5, name='drop2')(net)
net = tl.layers.Dense(n_units=n_units2, act=tf.nn.relu, name='relu2')(net)
net = tl.layers.Dropout(keep=0.5, name='drop3')(net)
net = tl.layers.Dense(n_units=10, act=None, name='output')(net)
M = tl.models.Model(inputs=ni, outputs=net)
return M

network = mlp()

# cost and accuracy
cost = tl.cost.cross_entropy

def acc(y, y_):
correct_prediction = tf.equal(tf.argmax(y, 1), tf.convert_to_tensor(y_, tf.int64))
return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# define the optimizer
train_params = tl.layers.get_variables_with_name('MLP', True, False)
train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_params)

# initialize all variables in the session
sess.run(tf.global_variables_initializer())
train_op = tf.optimizers.Adam(learning_rate=0.0001)

# train the network
# tl.utils.fit(
# network, train_op, cost, X_train, y_train, acc=acc, batch_size=500, n_epoch=20, print_freq=5,
# X_val=X_val, y_val=y_val, eval_train=False
# )

tl.utils.fit(
sess, net_train, train_op, cost, X_train, y_train, x, y_, acc=acc, batch_size=500, n_epoch=1, print_freq=5,
X_val=X_val, y_val=y_val, eval_train=False
network, train_op=train_op, cost=cost, X_train=X_train, y_train=y_train,
acc=acc, batch_size=256, n_epoch=20, X_val=X_val, y_val=y_val, eval_train=False,
)

# evaluation and save result that match the result_key
test_accuracy = tl.utils.test(sess, net_test, acc_test, X_test, y_test, x, y_, batch_size=None, cost=cost_test)
test_accuracy = tl.utils.test(network, acc, X_test, y_test, batch_size=None, cost=cost)
test_accuracy = float(test_accuracy)

# save model into database
db.save_model(net_train, model_name='mlp', name=str(n_units1) + '-' + str(n_units2), test_accuracy=test_accuracy)
db.save_model(network, model_name='mlp', name=str(n_units1) + '-' + str(n_units2), test_accuracy=test_accuracy)
# in other script, you can load the model as follow
# net = db.find_model(sess=sess, model_name=str(n_units1)+'-'+str(n_units2)
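The comment above still shows the old session-based `find_model(sess=sess, ...)` call; after this PR a saved model comes back without a session, as in `examples/database/dispatch_tasks.py` earlier in the diff. A minimal sketch:

```python
# Hedged sketch of loading the best saved model back from another script,
# following the session-free call used in dispatch_tasks.py above.
db = tl.db.TensorHub(ip='localhost', port=27017, dbname='temp', project_name='tutorial')
net = db.find_top_model(model_name='mlp', sort=[("test_accuracy", -1)])
print("the best accuracy {} is from model {}".format(net._test_accuracy, net._name))
```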

@@ -66,8 +66,8 @@ def get_model(inputs_shape):

## 2. Spatial transformer module (sampler)
stn = SpatialTransformer2dAffine(out_size=(40, 40), in_channels=20)
s = stn((nn, ni))
nn = stn((nn, ni))
s = nn

## 3. Classifier
nn = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME')(nn)
2 changes: 1 addition & 1 deletion tensorlayer/db.py
@@ -641,7 +641,7 @@ def run_top_task(self, task_name=None, sort=None, **kwargs):
logging.info("[Database] Start Task: key: {} sort: {} push time: {}".format(task_name, sort, _datetime))
_script = _script.decode('utf-8')
with tf.Graph().as_default(): # # as graph: # clear all TF graphs
exec (_script, globals())
exec(_script, globals())

# set status to finished
_ = self.db.Task.find_one_and_update({'_id': _id}, {'$set': {'status': 'finished'}})
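For context, a minimal sketch of how a stored task script such as `task_script.py` above ends up in this `exec` call. That the task's saved hyper-parameters (e.g. `n_units1` and `n_units2`, which the script references but never defines) are merged into `globals()` before the `exec` is an assumption based on how the script is written, not something shown in this hunk:

```python
import tensorlayer as tl

# Hedged sketch: pick the newest matching task and run its stored script inside
# a fresh tf.Graph, as in the run_top_task hunk above. The task name is illustrative.
db = tl.db.TensorHub(ip='localhost', port=27017, dbname='temp', project_name='tutorial')
db.run_top_task(task_name='mnist', sort=[("time", -1)])
```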
3 changes: 0 additions & 3 deletions tensorlayer/layers/spatial_transformer.py
@@ -257,7 +257,6 @@ def __repr__(self):
return s.format(classname=self.__class__.__name__, **self.__dict__)

def build(self, inputs_shape):
print("inputs_shape ", inputs_shape)
if self.in_channels is None and len(inputs_shape) != 2:
raise AssertionError("The dimension of theta layer input must be rank 2, please reshape or flatten it")
if self.in_channels:
@@ -267,7 +266,6 @@ def build(self, inputs_shape):
# shape = [inputs_shape[1], 6]
self.in_channels = inputs_shape[0][-1] # zsdonghao
shape = [self.in_channels, 6]
print("shape", shape)
self.W = self._get_weights("weights", shape=tuple(shape), init=tl.initializers.Zeros())
identity = np.reshape(np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32), newshape=(6, ))
self.b = self._get_weights("biases", shape=(6, ), init=tl.initializers.Constant(identity))
@@ -282,7 +280,6 @@ def forward(self, inputs):
n_channels is identical to that of U.
"""
theta_input, U = inputs
print("inputs", inputs)
theta = tf.nn.tanh(tf.matmul(theta_input, self.W) + self.b)
outputs = transformer(U, theta, out_size=self.out_size)
# automatically set batch_size and channels
8 changes: 5 additions & 3 deletions tensorlayer/models/core.py
@@ -401,7 +401,7 @@ def trainable_weights(self):
if layer.trainable_weights is not None:
self._trainable_weights.extend(layer.trainable_weights)

return self._trainable_weights
return self._trainable_weights.copy()

@property
def nontrainable_weights(self):
@@ -415,7 +415,7 @@ def nontrainable_weights(self):
if layer.nontrainable_weights is not None:
self._nontrainable_weights.extend(layer.nontrainable_weights)

return self._nontrainable_weights
return self._nontrainable_weights.copy()

@property
def all_weights(self):
@@ -429,7 +429,7 @@ def all_weights(self):
if layer.all_weights is not None:
self._all_weights.extend(layer.all_weights)

return self._all_weights
return self._all_weights.copy()

@property
def config(self):
@@ -669,6 +669,8 @@ def _construct_graph(self):

visited_node_names = set()
for out_node in output_nodes:
if out_node.visited:
continue
queue_node.put(out_node)

while not queue_node.empty():
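The two-line guard added above matters when the same node is reachable from more than one entry in `outputs`, e.g. when one output is an intermediate tensor on the path to another, which is exactly the STN case this PR fixes. A minimal sketch of such a model (layer sizes are illustrative, not from the PR):

```python
import tensorflow as tf
import tensorlayer as tl

# The second output 'mid' already lies on the path to 'out', so _construct_graph
# must skip output nodes it has already visited instead of enqueuing them twice.
ni = tl.layers.Input([None, 40, 40, 1])
mid = tl.layers.Flatten()(ni)                 # intermediate node, also exposed as an output
out = tl.layers.Dense(n_units=10)(mid)
model = tl.models.Model(inputs=ni, outputs=[out, mid])
```

The new `test_STN` case below exercises the real thing with `outputs=[nn, s]`, where `s` is the spatial-transformer output that also feeds the classifier head.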
39 changes: 38 additions & 1 deletion tests/layers/test_layernode.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import unittest

@@ -193,6 +192,44 @@ def MyModel():
self.assertEqual(net.all_layers[1].model._nodes_fixed, True)
self.assertEqual(net.all_layers[1].model.all_layers[0]._nodes_fixed, True)

def test_STN(self):
print('-' * 20, 'test STN', '-' * 20)

def get_model(inputs_shape):
ni = Input(inputs_shape)

## 1. Localisation network
# use MLP as the localisation net
nn = Flatten()(ni)
nn = Dense(n_units=20, act=tf.nn.tanh)(nn)
nn = Dropout(keep=0.8)(nn)
# you can also use a CNN instead of an MLP as the localisation net

## 2. Spatial transformer module (sampler)
stn = SpatialTransformer2dAffine(out_size=(40, 40), in_channels=20)
# s = stn((nn, ni))
nn = stn((nn, ni))
s = nn

## 3. Classifier
nn = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME')(nn)
nn = Conv2d(16, (3, 3), (2, 2), act=tf.nn.relu, padding='SAME')(nn)
nn = Flatten()(nn)
nn = Dense(n_units=1024, act=tf.nn.relu)(nn)
nn = Dense(n_units=10, act=tf.identity)(nn)

M = Model(inputs=ni, outputs=[nn, s])
return M

net = get_model([None, 40, 40, 1])

inputs = np.random.randn(2, 40, 40, 1).astype(np.float32)
o1, o2 = net(inputs, is_train=True)
self.assertEqual(o1.shape, (2, 10))
self.assertEqual(o2.shape, (2, 40, 40, 1))

self.assertEqual(len(net._node_by_depth), 10)


if __name__ == '__main__':

9 changes: 9 additions & 0 deletions tests/models/test_model_core.py
@@ -392,6 +392,15 @@ def test_get_layer(self):
except Exception as e:
print(e)

def test_model_weights_copy(self):
print('-' * 20, 'test_model_weights_copy', '-' * 20)
model_basic = basic_static_model()
model_weights = model_basic.trainable_weights
ori_len = len(model_weights)
model_weights.append(np.arange(5))
new_len = len(model_weights)
self.assertEqual(new_len - 1, ori_len)


if __name__ == '__main__':
