Skip to content

Commit

Permalink
fix image one sample bug (#156)
Browse files Browse the repository at this point in the history
  • Loading branch information
Superjomn authored Jan 15, 2018
1 parent 9f0872b commit 362dae5
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
20 changes: 19 additions & 1 deletion demo/mxnet/mxnet_demo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np
import mxnet as mx
import logging

import mxnet as mx
Expand All @@ -22,6 +24,8 @@
# scalar0 is used to record scalar metrics while MXNet is training. We will record accuracy.
# In the visualization, we can see the accuracy is increasing as more training steps happen.
scalar0 = logger.scalar("scalars/scalar0")
image0 = logger.image("images/image0", 1)
histogram0 = logger.histogram("histogram/histogram0", num_buckets=100)

# Record training steps
cnt_step = 0
Expand All @@ -42,6 +46,19 @@ def _callback(param):
cnt_step += 1
return _callback

def add_image_histogram():
def _callback(iter_no, sym, arg, aux):
image0.start_sampling()
weight = arg['fullyconnected1_weight'].asnumpy()
shape = [100, 50]
data = weight.flatten()

image0.add_sample(shape, list(data))
histogram0.add_record(iter_no, list(data))

image0.finish_sampling()
return _callback


# Start to build CNN in MXNet, train MNIST dataset. For more info, check MXNet's official website:
# https://mxnet.incubator.apache.org/tutorials/python/mnist.html
Expand Down Expand Up @@ -81,7 +98,8 @@ def _callback(param):
eval_metric='acc',
# integrate our customized callback method
batch_end_callback=[add_scalar()],
num_epoch=2)
epoch_end_callback=[add_image_histogram()],
num_epoch=5)

test_iter = mx.io.NDArrayIter(mnist['test_data'], None, batch_size)
prob = lenet_model.predict(test_iter)
Expand Down
2 changes: 1 addition & 1 deletion visualdl/python/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def scalar(self, tag, type='float'):
}
return type2scalar[type](tag)

def image(self, tag, num_samples, step_cycle):
def image(self, tag, num_samples, step_cycle=1):
"""
Create an image writer that used to write image data.
"""
Expand Down
4 changes: 4 additions & 0 deletions visualdl/server/lib.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pprint
import re
import sys
import time
import urllib
from tempfile import NamedTemporaryFile
Expand Down Expand Up @@ -131,6 +132,7 @@ def get_invididual_image(storage, mode, tag, step_index, max_size=80):
with storage.mode(mode) as reader:
res = re.search(r".*/([0-9]+$)", tag)
# remove suffix '/x'
offset = 0
if res:
offset = int(res.groups()[0])
tag = tag[:tag.rfind('/')]
Expand Down Expand Up @@ -206,4 +208,6 @@ def retry(ntimes, function, time2sleep, *args, **kwargs):
try:
return function(*args, **kwargs)
except:
error_info = '\n'.join(map(str, sys.exc_info()))
logger.error("Unexpected error: %s" % error_info)
time.sleep(time2sleep)

0 comments on commit 362dae5

Please sign in to comment.