From 20acce080cd39151b70263bf36a8b9e32c0ee7b9 Mon Sep 17 00:00:00 2001 From: "christopher.rohkohl" Date: Fri, 7 Aug 2015 15:54:48 +0200 Subject: [PATCH] Fix of windows compatibility issues - binary files need to be opened with 'rb' or 'wb' flag - leveldb package is not available for windows - display warning instead of crashing - adapted find_executable with fallback for windows ".exe" extension and let the version check pass - normalization of path separators to linux style - turn off fcntl on windows - Making tests pass and support of file upload - portable cudart localization - Make tests pass + refactoring of existing modifications --- digits/config/caffe_option.py | 9 ++++- digits/dataset/images/classification/job.py | 4 +- .../images/generic/test_lmdb_creator.py | 2 +- digits/device_query.py | 18 ++++++++- digits/job.py | 2 +- .../model/images/classification/test_views.py | 4 +- digits/model/images/classification/views.py | 15 ++++++-- digits/model/images/generic/test_views.py | 4 +- digits/model/images/generic/views.py | 15 ++++++-- digits/model/tasks/caffe_train.py | 6 +-- digits/task.py | 6 ++- digits/utils/__init__.py | 21 ++++++++--- digits/utils/test_image.py | 37 ++++++++++++++++--- docs/API.md | 14 +++---- docs/FlaskRoutes.md | 24 ++++++------ examples/classification/example.py | 2 +- scripts/generate_docs.py | 4 +- scripts/test_generate_docs.py | 9 ++++- tools/create_db.py | 2 +- tools/test_create_db.py | 6 +-- tools/test_parse_folder.py | 20 +++++++++- 21 files changed, 161 insertions(+), 63 deletions(-) diff --git a/digits/config/caffe_option.py b/digits/config/caffe_option.py index 70b8be2cc..3779bc193 100644 --- a/digits/config/caffe_option.py +++ b/digits/config/caffe_option.py @@ -65,6 +65,8 @@ def validate(cls, value): if value == '': # Find the executable executable = cls.find_executable('caffe') + if not executable: + executable = cls.find_executable('caffe.exe') if not executable: raise config_option.BadValue('caffe binary not found in PATH') cls.validate_version(executable) @@ -187,6 +189,9 @@ def get_version(executable): elif platform.system() == 'Darwin': # XXX: guess and let the user figure out errors later return (0,11,0) + elif platform.system() == 'Windows': + # XXX: guess and let the user figure out errors later + return (0,11,0) else: print 'WARNING: platform "%s" not supported' % platform.system() return None @@ -197,6 +202,8 @@ def _set_config_dict_value(self, value): else: if value == '': executable = self.find_executable('caffe') + if not executable: + executable = self.find_executable('caffe.exe') else: executable = os.path.join(value, 'build', 'tools', 'caffe') @@ -238,7 +245,7 @@ def apply(self): print 'Did you forget to "make pycaffe"?' raise - if platform.system() == 'Darwin': + if platform.system() == 'Darwin' or platform.system() == 'Windows': # Strange issue with protocol buffers and pickle - see issue #32 sys.path.insert(0, os.path.join( os.path.dirname(caffe.__file__), 'proto')) diff --git a/digits/dataset/images/classification/job.py b/digits/dataset/images/classification/job.py index 326ac789a..21175a8d5 100644 --- a/digits/dataset/images/classification/job.py +++ b/digits/dataset/images/classification/job.py @@ -37,7 +37,7 @@ def __setstate__(self, state): import numpy as np old_blob = caffe_pb2.BlobProto() - with open(task.path(task.mean_file)) as infile: + with open(task.path(task.mean_file),'rb') as infile: old_blob.ParseFromString(infile.read()) data = np.array(old_blob.data).reshape( old_blob.channels, @@ -48,7 +48,7 @@ def __setstate__(self, state): new_blob.num = 1 new_blob.channels, new_blob.height, new_blob.width = data.shape new_blob.data.extend(data.astype(float).flat) - with open(task.path(task.mean_file), 'w') as outfile: + with open(task.path(task.mean_file), 'wb') as outfile: outfile.write(new_blob.SerializeToString()) else: print '\tSetting "%s" status to ERROR because it was created with RGB channels' % self.name() diff --git a/digits/dataset/images/generic/test_lmdb_creator.py b/digits/dataset/images/generic/test_lmdb_creator.py index eed80fc91..2b372ce1c 100755 --- a/digits/dataset/images/generic/test_lmdb_creator.py +++ b/digits/dataset/images/generic/test_lmdb_creator.py @@ -158,7 +158,7 @@ def _save_mean(mean, filename): blob.channels = 1 blob.height, blob.width = mean.shape blob.data.extend(mean.astype(float).flat) - with open(filename, 'w') as outfile: + with open(filename, 'wb') as outfile: outfile.write(blob.SerializeToString()) elif filename.endswith(('.jpg', '.jpeg', '.png')): diff --git a/digits/device_query.py b/digits/device_query.py index 0c90614a2..31a1e9bd0 100755 --- a/digits/device_query.py +++ b/digits/device_query.py @@ -111,12 +111,24 @@ def get_library(name): return ctypes.cdll.LoadLibrary('%s.so' % name) elif platform.system() == 'Darwin': return ctypes.cdll.LoadLibrary('%s.dylib' % name) + elif platform.system() == 'Windows': + return ctypes.windll.LoadLibrary('%s.dll' % name) except OSError: pass return None devices = None +def get_cudart(): + if not platform.system() == 'Windows': + return get_library('libcudart') + + arch = platform.architecture()[0] + for ver in range(90,50,-5): + cudart = get_library('cudart%s_%d' % (arch[:2], ver)) + if not cudart is None: + return cudart + def get_devices(force_reload=False): """ Returns a list of c_cudaDeviceProp's @@ -131,7 +143,7 @@ def get_devices(force_reload=False): return devices devices = [] - cudart = get_library('libcudart') + cudart = get_cudart() if cudart is None: return [] @@ -180,7 +192,9 @@ def get_nvml_info(device_id): nvml = get_library('libnvidia-ml') if nvml is None: - return None + nvml = get_library('nvml') + if nvml is None: + return None rc = nvml.nvmlInit() if rc != 0: diff --git a/digits/job.py b/digits/job.py index 7283ea1a2..be8a69780 100644 --- a/digits/job.py +++ b/digits/job.py @@ -118,7 +118,7 @@ def path(self, filename, relative=False): path = os.path.join(self._dir, filename) if relative: path = os.path.relpath(path, config_value('jobs_dir')) - return str(path) + return str(path).replace("\\","/") def path_is_local(self, path): """assert that a path is local to _dir""" diff --git a/digits/model/images/classification/test_views.py b/digits/model/images/classification/test_views.py index abcebc690..d0bd09e89 100644 --- a/digits/model/images/classification/test_views.py +++ b/digits/model/images/classification/test_views.py @@ -338,7 +338,7 @@ def test_classify_one(self): category = self.imageset_paths.keys()[0] image_path = self.imageset_paths[category][0] image_path = os.path.join(self.imageset_folder, image_path) - with open(image_path) as infile: + with open(image_path,'rb') as infile: # StringIO wrapping is needed to simulate POST file upload. image_upload = (StringIO(infile.read()), 'image.png') @@ -360,7 +360,7 @@ def test_classify_one_json(self): category = self.imageset_paths.keys()[0] image_path = self.imageset_paths[category][0] image_path = os.path.join(self.imageset_folder, image_path) - with open(image_path) as infile: + with open(image_path,'rb') as infile: # StringIO wrapping is needed to simulate POST file upload. image_upload = (StringIO(infile.read()), 'image.png') diff --git a/digits/model/images/classification/views.py b/digits/model/images/classification/views.py index f5c1bc739..85f5342e1 100644 --- a/digits/model/images/classification/views.py +++ b/digits/model/images/classification/views.py @@ -25,6 +25,7 @@ from forms import ImageClassificationModelForm from job import ImageClassificationModelJob from digits.status import Status +import platform NAMESPACE = '/models/images/classification' @@ -246,9 +247,17 @@ def image_classification_model_classify_one(): if 'image_url' in flask.request.form and flask.request.form['image_url']: image = utils.image.load_image(flask.request.form['image_url']) elif 'image_file' in flask.request.files and flask.request.files['image_file']: - with tempfile.NamedTemporaryFile() as outfile: - flask.request.files['image_file'].save(outfile.name) - image = utils.image.load_image(outfile.name) + if not platform.system() == 'Windows': + with tempfile.NamedTemporaryFile() as outfile: + flask.request.files['image_file'].save(outfile.name) + image = utils.image.load_image(outfile.name) + else: + # prevent temporary file permission errors + outfile = tempfile.mkstemp(suffix='.bin') + flask.request.files['image_file'].save(outfile[1]) + image = utils.image.load_image(outfile[1]) + os.close(outfile[0]) + os.remove(outfile[1]) else: raise werkzeug.exceptions.BadRequest('must provide image_url or image_file') diff --git a/digits/model/images/generic/test_views.py b/digits/model/images/generic/test_views.py index 3116f0eb6..1ffff6f59 100644 --- a/digits/model/images/generic/test_views.py +++ b/digits/model/images/generic/test_views.py @@ -338,7 +338,7 @@ def test_model_json(self): def test_infer_one(self): image_path = os.path.join(self.imageset_folder, self.test_image) - with open(image_path) as infile: + with open(image_path,'rb') as infile: # StringIO wrapping is needed to simulate POST file upload. image_upload = (StringIO(infile.read()), 'image.png') @@ -355,7 +355,7 @@ def test_infer_one(self): def test_infer_one_json(self): image_path = os.path.join(self.imageset_folder, self.test_image) - with open(image_path) as infile: + with open(image_path,'rb') as infile: # StringIO wrapping is needed to simulate POST file upload. image_upload = (StringIO(infile.read()), 'image.png') diff --git a/digits/model/images/generic/views.py b/digits/model/images/generic/views.py index 86d926e92..b167a8a86 100644 --- a/digits/model/images/generic/views.py +++ b/digits/model/images/generic/views.py @@ -22,6 +22,7 @@ from forms import GenericImageModelForm from job import GenericImageModelJob from digits.status import Status +import platform NAMESPACE = '/models/images/generic' @@ -222,9 +223,17 @@ def generic_image_model_infer_one(): if 'image_url' in flask.request.form and flask.request.form['image_url']: image = utils.image.load_image(flask.request.form['image_url']) elif 'image_file' in flask.request.files and flask.request.files['image_file']: - with tempfile.NamedTemporaryFile() as outfile: - flask.request.files['image_file'].save(outfile.name) - image = utils.image.load_image(outfile.name) + if not platform.system() == 'Windows': + with tempfile.NamedTemporaryFile() as outfile: + flask.request.files['image_file'].save(outfile.name) + image = utils.image.load_image(outfile.name) + else: + # prevent temporary file permission errors + outfile = tempfile.mkstemp(suffix='.bin') + flask.request.files['image_file'].save(outfile[1]) + image = utils.image.load_image(outfile[1]) + os.close(outfile[0]) + os.remove(outfile[1]) else: raise werkzeug.exceptions.BadRequest('must provide image_url or image_file') diff --git a/digits/model/tasks/caffe_train.py b/digits/model/tasks/caffe_train.py index 839d01d7d..1c1b56127 100644 --- a/digits/model/tasks/caffe_train.py +++ b/digits/model/tasks/caffe_train.py @@ -233,7 +233,7 @@ def save_files_classification(self): val_data_layer.data_param.backend = caffe_pb2.DataParameter.LMDB if self.use_mean: mean_pixel = None - with open(self.dataset.path(self.dataset.train_db_task().mean_file)) as f: + with open(self.dataset.path(self.dataset.train_db_task().mean_file),'rb') as f: blob = caffe_pb2.BlobProto() blob.MergeFromString(f.read()) mean = np.reshape(blob.data, @@ -1312,7 +1312,7 @@ def get_transformer(self): channel_swap = (2,1,0) if self.use_mean: - with open(self.dataset.path(self.dataset.train_db_task().mean_file)) as infile: + with open(self.dataset.path(self.dataset.train_db_task().mean_file),'rb') as infile: blob = caffe_pb2.BlobProto() blob.MergeFromString(infile.read()) mean_pixel = np.reshape(blob.data, @@ -1331,7 +1331,7 @@ def get_transformer(self): channel_swap = (2,1,0) if self.dataset.mean_file: - with open(self.dataset.path(self.dataset.mean_file)) as infile: + with open(self.dataset.path(self.dataset.mean_file),'rb') as infile: blob = caffe_pb2.BlobProto() blob.MergeFromString(infile.read()) mean_pixel = np.reshape(blob.data, diff --git a/digits/task.py b/digits/task.py index 03d4bbd36..2aaf03c9a 100644 --- a/digits/task.py +++ b/digits/task.py @@ -15,6 +15,8 @@ from config import config_value from status import Status, StatusCls +import platform + # NOTE: Increment this everytime the pickled version changes PICKLE_VERSION = 1 @@ -128,7 +130,7 @@ def path(self, filename, relative=False): path = os.path.join(self.job_dir, filename) if relative: path = os.path.relpath(path, config_value('jobs_dir')) - return str(path) + return str(path).replace("\\","/") def ready_to_queue(self): """ @@ -193,7 +195,7 @@ def run(self, resources): stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=self.job_dir, - close_fds=True, + close_fds=False if platform.system() == 'Windows' else True, ) try: diff --git a/digits/utils/__init__.py b/digits/utils/__init__.py index daac0468e..50286f204 100644 --- a/digits/utils/__init__.py +++ b/digits/utils/__init__.py @@ -2,17 +2,23 @@ import os import math -import fcntl import locale from random import uniform from urlparse import urlparse from io import BlockingIOError import inspect +import platform + + +if not platform.system() == 'Windows': + import fcntl +else: + import gevent.os HTTP_TIMEOUT = 6.05 def is_url(url): - return url is not None and urlparse(url).scheme != "" + return url is not None and urlparse(url).scheme != "" and not os.path.exists(url) def wait_time(): """Wait a random number of seconds""" @@ -27,14 +33,18 @@ def nonblocking_readlines(f): Newlines are normalized to the Unix standard. """ fd = f.fileno() - fl = fcntl.fcntl(fd, fcntl.F_GETFL) - fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK) + if not platform.system() == 'Windows': + fl = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK) enc = locale.getpreferredencoding(False) buf = bytearray() while True: try: - block = os.read(fd, 8192) + if not platform.system() == 'Windows': + block = os.read(fd, 8192) + else: + block = gevent.os.tp_read(fd, 8192) except (BlockingIOError, OSError): yield "" continue @@ -42,7 +52,6 @@ def nonblocking_readlines(f): if not block: if buf: yield buf.decode(enc) - buf.clear() break buf.extend(block) diff --git a/digits/utils/test_image.py b/digits/utils/test_image.py index 94fe3690a..cada76579 100644 --- a/digits/utils/test_image.py +++ b/digits/utils/test_image.py @@ -7,6 +7,8 @@ import mock import PIL.Image import numpy as np +import os +import platform from . import image as _, errors @@ -52,9 +54,24 @@ def check_good_file(self, args): orig_mode, suffix, pixel, new_mode = args orig = PIL.Image.new(orig_mode, (10,10), pixel) - with tempfile.NamedTemporaryFile(suffix='.' + suffix) as tmp: - orig.save(tmp.name) - new = _.load_image(tmp.name) + + if not platform.system() == 'Windows': + with tempfile.NamedTemporaryFile(suffix='.' + suffix) as tmp: + orig.save(tmp.name) + new = _.load_image(tmp.name) + else: + # temp files cause permission errors so just generate the name + tmp = tempfile.mkstemp(suffix='.' + suffix) + orig.save(tmp[1]) + new = _.load_image(tmp[1]) + try: + # sometimes on windows the file is not closed yet + # which can cause an exception + os.close(tmp[0]) + os.remove(tmp[1]) + except: + pass + assert new is not None, 'load_image should never return None' assert new.mode == new_mode, 'Image mode should be "%s", not "%s\nargs - %s' % (new_mode, new.mode, args) @@ -94,16 +111,26 @@ def test_corrupted_file(self): corrupted = encoded[:size/2] + encoded[size/2:][::-1] # Save the corrupted image to a temporary file. - f = tempfile.NamedTemporaryFile(delete=False) + if not platform.system() == 'Windows': + f = tempfile.NamedTemporaryFile(delete=False) + fname = f.name + else: + # prevent temporary file permissions error on windows + fname = tempfile.mkstemp(suffix='.bin') + f = os.fdopen(fname[0],'wb') + fname = fname[1] + f.write(corrupted) f.close() assert_raises( errors.LoadImageError, _.load_image, - f.name, + fname, ) + os.remove(fname) + class TestResizeImage(): @classmethod diff --git a/docs/API.md b/docs/API.md index e41d7a4cf..4c99d67d5 100644 --- a/docs/API.md +++ b/docs/API.md @@ -1,6 +1,6 @@ # REST API -*Generated Aug 10, 2015* +*Generated Aug 13, 2015* DIGITS exposes its internal functionality through a REST API. You can access these endpoints by performing a GET or POST on the route, and a JSON object will be returned. @@ -94,7 +94,7 @@ Location: [`digits/model/views.py@31`](../digits/model/views.py#L31) Methods: **POST** -Location: [`digits/model/images/classification/views.py@53`](../digits/model/images/classification/views.py#L53) +Location: [`digits/model/images/classification/views.py@54`](../digits/model/images/classification/views.py#L54) ### `/models/images/classification/classify_many.json` @@ -106,7 +106,7 @@ Location: [`digits/model/images/classification/views.py@53`](../digits/model/ima Methods: **POST** -Location: [`digits/model/images/classification/views.py@290`](../digits/model/images/classification/views.py#L290) +Location: [`digits/model/images/classification/views.py@299`](../digits/model/images/classification/views.py#L299) ### `/models/images/classification/classify_one.json` @@ -118,7 +118,7 @@ Location: [`digits/model/images/classification/views.py@290`](../digits/model/im Methods: **POST** -Location: [`digits/model/images/classification/views.py@236`](../digits/model/images/classification/views.py#L236) +Location: [`digits/model/images/classification/views.py@237`](../digits/model/images/classification/views.py#L237) ### `/models/images/generic.json` @@ -130,7 +130,7 @@ Location: [`digits/model/images/classification/views.py@236`](../digits/model/im Methods: **POST** -Location: [`digits/model/images/generic/views.py@49`](../digits/model/images/generic/views.py#L49) +Location: [`digits/model/images/generic/views.py@50`](../digits/model/images/generic/views.py#L50) ### `/models/images/generic/infer_many.json` @@ -138,7 +138,7 @@ Location: [`digits/model/images/generic/views.py@49`](../digits/model/images/gen Methods: **POST** -Location: [`digits/model/images/generic/views.py@264`](../digits/model/images/generic/views.py#L264) +Location: [`digits/model/images/generic/views.py@273`](../digits/model/images/generic/views.py#L273) ### `/models/images/generic/infer_one.json` @@ -146,5 +146,5 @@ Location: [`digits/model/images/generic/views.py@264`](../digits/model/images/ge Methods: **POST** -Location: [`digits/model/images/generic/views.py@214`](../digits/model/images/generic/views.py#L214) +Location: [`digits/model/images/generic/views.py@215`](../digits/model/images/generic/views.py#L215) diff --git a/docs/FlaskRoutes.md b/docs/FlaskRoutes.md index 65bceabe9..529d5566a 100644 --- a/docs/FlaskRoutes.md +++ b/docs/FlaskRoutes.md @@ -1,6 +1,6 @@ # Flask Routes -*Generated Aug 10, 2015* +*Generated Aug 13, 2015* Documentation on the various routes used internally for the web application. @@ -288,7 +288,7 @@ Location: [`digits/model/views.py@55`](../digits/model/views.py#L55) Methods: **POST** -Location: [`digits/model/images/classification/views.py@53`](../digits/model/images/classification/views.py#L53) +Location: [`digits/model/images/classification/views.py@54`](../digits/model/images/classification/views.py#L54) ### `/models/images/classification/classify_many` @@ -300,7 +300,7 @@ Location: [`digits/model/images/classification/views.py@53`](../digits/model/ima Methods: **GET**, **POST** -Location: [`digits/model/images/classification/views.py@290`](../digits/model/images/classification/views.py#L290) +Location: [`digits/model/images/classification/views.py@299`](../digits/model/images/classification/views.py#L299) ### `/models/images/classification/classify_one` @@ -312,7 +312,7 @@ Location: [`digits/model/images/classification/views.py@290`](../digits/model/im Methods: **GET**, **POST** -Location: [`digits/model/images/classification/views.py@236`](../digits/model/images/classification/views.py#L236) +Location: [`digits/model/images/classification/views.py@237`](../digits/model/images/classification/views.py#L237) ### `/models/images/classification/large_graph` @@ -320,7 +320,7 @@ Location: [`digits/model/images/classification/views.py@236`](../digits/model/im Methods: **GET** -Location: [`digits/model/images/classification/views.py@225`](../digits/model/images/classification/views.py#L225) +Location: [`digits/model/images/classification/views.py@226`](../digits/model/images/classification/views.py#L226) ### `/models/images/classification/new` @@ -328,7 +328,7 @@ Location: [`digits/model/images/classification/views.py@225`](../digits/model/im Methods: **GET** -Location: [`digits/model/images/classification/views.py@32`](../digits/model/images/classification/views.py#L32) +Location: [`digits/model/images/classification/views.py@33`](../digits/model/images/classification/views.py#L33) ### `/models/images/classification/top_n` @@ -336,7 +336,7 @@ Location: [`digits/model/images/classification/views.py@32`](../digits/model/ima Methods: **POST** -Location: [`digits/model/images/classification/views.py@374`](../digits/model/images/classification/views.py#L374) +Location: [`digits/model/images/classification/views.py@383`](../digits/model/images/classification/views.py#L383) ### `/models/images/generic` @@ -348,7 +348,7 @@ Location: [`digits/model/images/classification/views.py@374`](../digits/model/im Methods: **POST** -Location: [`digits/model/images/generic/views.py@49`](../digits/model/images/generic/views.py#L49) +Location: [`digits/model/images/generic/views.py@50`](../digits/model/images/generic/views.py#L50) ### `/models/images/generic/infer_many` @@ -356,7 +356,7 @@ Location: [`digits/model/images/generic/views.py@49`](../digits/model/images/gen Methods: **GET**, **POST** -Location: [`digits/model/images/generic/views.py@264`](../digits/model/images/generic/views.py#L264) +Location: [`digits/model/images/generic/views.py@273`](../digits/model/images/generic/views.py#L273) ### `/models/images/generic/infer_one` @@ -364,7 +364,7 @@ Location: [`digits/model/images/generic/views.py@264`](../digits/model/images/ge Methods: **GET**, **POST** -Location: [`digits/model/images/generic/views.py@214`](../digits/model/images/generic/views.py#L214) +Location: [`digits/model/images/generic/views.py@215`](../digits/model/images/generic/views.py#L215) ### `/models/images/generic/large_graph` @@ -372,7 +372,7 @@ Location: [`digits/model/images/generic/views.py@214`](../digits/model/images/ge Methods: **GET** -Location: [`digits/model/images/generic/views.py@203`](../digits/model/images/generic/views.py#L203) +Location: [`digits/model/images/generic/views.py@204`](../digits/model/images/generic/views.py#L204) ### `/models/images/generic/new` @@ -380,7 +380,7 @@ Location: [`digits/model/images/generic/views.py@203`](../digits/model/images/ge Methods: **GET** -Location: [`digits/model/images/generic/views.py@29`](../digits/model/images/generic/views.py#L29) +Location: [`digits/model/images/generic/views.py@30`](../digits/model/images/generic/views.py#L30) ### `/models/visualize-lr` diff --git a/examples/classification/example.py b/examples/classification/example.py index 3cd339278..00763cf72 100755 --- a/examples/classification/example.py +++ b/examples/classification/example.py @@ -65,7 +65,7 @@ def get_transformer(deploy_file, mean_file=None): if mean_file: # set mean pixel - with open(mean_file) as infile: + with open(mean_file,'rb') as infile: blob = caffe_pb2.BlobProto() blob.MergeFromString(infile.read()) if blob.HasField('shape'): diff --git a/scripts/generate_docs.py b/scripts/generate_docs.py index afe55b8a5..6c286fec3 100755 --- a/scripts/generate_docs.py +++ b/scripts/generate_docs.py @@ -157,10 +157,10 @@ def _print_route(self, route): ) filename = os.path.normpath(route['location']['filename']) if filename.startswith(digits_root): - filename = os.path.relpath(filename, digits_root) + filename = os.path.relpath(filename, digits_root).replace("\\","/") self.w('Location: [`%s@%s`](%s#L%s)' % ( filename, route['location']['line'], - os.path.join('..', filename), route['location']['line'], + os.path.join('..', filename).replace("\\","/"), route['location']['line'], )) self.w() diff --git a/scripts/test_generate_docs.py b/scripts/test_generate_docs.py index dcc5d737a..8c7e2af20 100644 --- a/scripts/test_generate_docs.py +++ b/scripts/test_generate_docs.py @@ -5,6 +5,7 @@ import tempfile import itertools import unittest +import platform try: import flask.ext.autodoc @@ -25,8 +26,11 @@ def check_doc_file(generator, doc_filename): """ Checks that the output generated by generator matches the contents of doc_filename """ - with tempfile.NamedTemporaryFile(suffix='.md') as tmp_file: - generator.generate(tmp_file.name) + # overcome temporary file permission errors + tmp_file_name = tempfile.mkstemp(suffix='.md') + os.close(tmp_file_name[0]) + with open(tmp_file_name[1],'w+') as tmp_file: + generator.generate(tmp_file_name[1]) tmp_file.seek(0) with open(doc_filename) as doc_file: # memory friendly @@ -41,6 +45,7 @@ def check_doc_file(generator, doc_filename): print '(Previous)', doc_line print '(New) ', tmp_line raise RuntimeError('%s needs to be regenerated. Use scripts/generate_docs.py' % doc_filename) + os.remove(tmp_file_name[1]) def test_api_md(): with app.app_context(): diff --git a/tools/create_db.py b/tools/create_db.py index 55af4dbc8..05c5f75a4 100755 --- a/tools/create_db.py +++ b/tools/create_db.py @@ -281,7 +281,7 @@ def create(self, input_file, width, height, blob.channels, blob.height, blob.width = data.shape blob.data.extend(data.astype(float).flat) - with open(mean_file, 'w') as outfile: + with open(mean_file, 'wb') as outfile: outfile.write(blob.SerializeToString()) elif mean_file.lower().endswith(('.jpg', '.jpeg', '.png')): image = PIL.Image.fromarray(mean) diff --git a/tools/test_create_db.py b/tools/test_create_db.py index 4e7a8f063..c32e6bc82 100644 --- a/tools/test_create_db.py +++ b/tools/test_create_db.py @@ -36,7 +36,7 @@ def setUpClass(cls): os.close(fd) # Use the example picture to construct a test input file - with open(cls.input_file, 'w') as f: + with open(cls.input_file, 'wb') as f: f.write('digits/static/images/mona_lisa.jpg 0') @classmethod @@ -96,7 +96,7 @@ def setUpClass(cls): cls.db_name = tempfile.mkdtemp(dir=cls.tmpdir) cls.db = _.DbCreator(cls.db_name) _handle, cls.image_path = tempfile.mkstemp(dir=cls.tmpdir, suffix='.jpg') - with open(cls.image_path, 'w') as outfile: + with open(cls.image_path, 'wb') as outfile: PIL.Image.fromarray(np.zeros((10,10,3),dtype=np.uint8)).save(outfile, format='JPEG', quality=100) @classmethod @@ -151,7 +151,7 @@ def test_set_mapsize(self): # create textfile fd, input_file = tempfile.mkstemp() os.close(fd) - with open(input_file, 'w') as f: + with open(input_file, 'wb') as f: f.write('digits/static/images/mona_lisa.jpg 0') # create DbCreator object diff --git a/tools/test_parse_folder.py b/tools/test_parse_folder.py index 6122de031..600fdea2b 100644 --- a/tools/test_parse_folder.py +++ b/tools/test_parse_folder.py @@ -4,6 +4,7 @@ import tempfile import shutil import itertools +import platform from nose.tools import raises, assert_raises import mock @@ -25,7 +26,15 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - shutil.rmtree(cls.tmpdir) + if platform.system() is not 'Windows': + shutil.rmtree(cls.tmpdir) + else: + # there is a temp-file permission problem which might + # prevent removal of the data in windows + try: + shutil.rmtree(cls.tmpdir) + except: + pass def test_dir(self): assert _.validate_folder(self.tmpdir) == True @@ -47,7 +56,14 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - shutil.rmtree(cls.tmpdir) + if not platform.system() == 'Windows': + shutil.rmtree(cls.tmpdir) + else: + # prevent temporary file permissions error + try: + shutil.rmtree(cls.tmpdir) + except: + pass def test_missing_file(self): assert _.validate_output_file(None) == True, 'all new files should be valid'