From 7aad31c908c571a7303cf71e6b4cbc41189c541e Mon Sep 17 00:00:00 2001 From: "christopher.rohkohl" Date: Fri, 7 Aug 2015 15:54:48 +0200 Subject: [PATCH] Fix of windows compatibility issues - binary files need to be opened with 'rb' or 'wb' flag - leveldb package is not available for windows - display warning instead of crashing - adapted find_executable with fallback for windows ".exe" extension and let the version check pass - normalization of path separators to linux style - turn off fcntl on windows --- digits/config/caffe_option.py | 9 ++++++++- digits/dataset/images/classification/job.py | 4 ++-- digits/device_query.py | 11 +++++++++-- digits/model/tasks/caffe_train.py | 6 +++--- digits/task.py | 6 ++++-- digits/utils/__init__.py | 21 +++++++++++++++------ scripts/generate_docs.py | 4 ++-- 7 files changed, 43 insertions(+), 18 deletions(-) diff --git a/digits/config/caffe_option.py b/digits/config/caffe_option.py index 70b8be2cc..2f4056477 100644 --- a/digits/config/caffe_option.py +++ b/digits/config/caffe_option.py @@ -65,6 +65,8 @@ def validate(cls, value): if value == '': # Find the executable executable = cls.find_executable('caffe') + if not executable: + executable = cls.find_executable('caffe.exe') if not executable: raise config_option.BadValue('caffe binary not found in PATH') cls.validate_version(executable) @@ -187,6 +189,9 @@ def get_version(executable): elif platform.system() == 'Darwin': # XXX: guess and let the user figure out errors later return (0,11,0) + elif platform.system() == 'Windows': + # XXX: guess and let the user figure out errors later + return (0,12,0) else: print 'WARNING: platform "%s" not supported' % platform.system() return None @@ -197,6 +202,8 @@ def _set_config_dict_value(self, value): else: if value == '': executable = self.find_executable('caffe') + if not executable: + executable = self.find_executable('caffe.exe') else: executable = os.path.join(value, 'build', 'tools', 'caffe') @@ -238,7 +245,7 @@ def apply(self): print 'Did you forget to "make pycaffe"?' raise - if platform.system() == 'Darwin': + if platform.system() == 'Darwin' or platform.system() == 'Windows': # Strange issue with protocol buffers and pickle - see issue #32 sys.path.insert(0, os.path.join( os.path.dirname(caffe.__file__), 'proto')) diff --git a/digits/dataset/images/classification/job.py b/digits/dataset/images/classification/job.py index 326ac789a..21175a8d5 100644 --- a/digits/dataset/images/classification/job.py +++ b/digits/dataset/images/classification/job.py @@ -37,7 +37,7 @@ def __setstate__(self, state): import numpy as np old_blob = caffe_pb2.BlobProto() - with open(task.path(task.mean_file)) as infile: + with open(task.path(task.mean_file),'rb') as infile: old_blob.ParseFromString(infile.read()) data = np.array(old_blob.data).reshape( old_blob.channels, @@ -48,7 +48,7 @@ def __setstate__(self, state): new_blob.num = 1 new_blob.channels, new_blob.height, new_blob.width = data.shape new_blob.data.extend(data.astype(float).flat) - with open(task.path(task.mean_file), 'w') as outfile: + with open(task.path(task.mean_file), 'wb') as outfile: outfile.write(new_blob.SerializeToString()) else: print '\tSetting "%s" status to ERROR because it was created with RGB channels' % self.name() diff --git a/digits/device_query.py b/digits/device_query.py index 0c90614a2..5ad3d69c4 100755 --- a/digits/device_query.py +++ b/digits/device_query.py @@ -111,6 +111,8 @@ def get_library(name): return ctypes.cdll.LoadLibrary('%s.so' % name) elif platform.system() == 'Darwin': return ctypes.cdll.LoadLibrary('%s.dylib' % name) + elif platform.system() == 'Windows': + return ctypes.windll.LoadLibrary('%s.dll' % name) except OSError: pass return None @@ -133,7 +135,10 @@ def get_devices(force_reload=False): cudart = get_library('libcudart') if cudart is None: - return [] + cudart = get_library('cudart64_75') + if cudart is None: + print 'Failed to locate cudart library' + return [] # check CUDA version cuda_version = ctypes.c_int() @@ -180,7 +185,9 @@ def get_nvml_info(device_id): nvml = get_library('libnvidia-ml') if nvml is None: - return None + nvml = get_library('nvml') + if nvml is None: + return None rc = nvml.nvmlInit() if rc != 0: diff --git a/digits/model/tasks/caffe_train.py b/digits/model/tasks/caffe_train.py index 839d01d7d..1c1b56127 100644 --- a/digits/model/tasks/caffe_train.py +++ b/digits/model/tasks/caffe_train.py @@ -233,7 +233,7 @@ def save_files_classification(self): val_data_layer.data_param.backend = caffe_pb2.DataParameter.LMDB if self.use_mean: mean_pixel = None - with open(self.dataset.path(self.dataset.train_db_task().mean_file)) as f: + with open(self.dataset.path(self.dataset.train_db_task().mean_file),'rb') as f: blob = caffe_pb2.BlobProto() blob.MergeFromString(f.read()) mean = np.reshape(blob.data, @@ -1312,7 +1312,7 @@ def get_transformer(self): channel_swap = (2,1,0) if self.use_mean: - with open(self.dataset.path(self.dataset.train_db_task().mean_file)) as infile: + with open(self.dataset.path(self.dataset.train_db_task().mean_file),'rb') as infile: blob = caffe_pb2.BlobProto() blob.MergeFromString(infile.read()) mean_pixel = np.reshape(blob.data, @@ -1331,7 +1331,7 @@ def get_transformer(self): channel_swap = (2,1,0) if self.dataset.mean_file: - with open(self.dataset.path(self.dataset.mean_file)) as infile: + with open(self.dataset.path(self.dataset.mean_file),'rb') as infile: blob = caffe_pb2.BlobProto() blob.MergeFromString(infile.read()) mean_pixel = np.reshape(blob.data, diff --git a/digits/task.py b/digits/task.py index 03d4bbd36..2aaf03c9a 100644 --- a/digits/task.py +++ b/digits/task.py @@ -15,6 +15,8 @@ from config import config_value from status import Status, StatusCls +import platform + # NOTE: Increment this everytime the pickled version changes PICKLE_VERSION = 1 @@ -128,7 +130,7 @@ def path(self, filename, relative=False): path = os.path.join(self.job_dir, filename) if relative: path = os.path.relpath(path, config_value('jobs_dir')) - return str(path) + return str(path).replace("\\","/") def ready_to_queue(self): """ @@ -193,7 +195,7 @@ def run(self, resources): stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=self.job_dir, - close_fds=True, + close_fds=False if platform.system() == 'Windows' else True, ) try: diff --git a/digits/utils/__init__.py b/digits/utils/__init__.py index daac0468e..a272f7941 100644 --- a/digits/utils/__init__.py +++ b/digits/utils/__init__.py @@ -1,13 +1,19 @@ -# Copyright (c) 2014-2015, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2014-2015, NVIDIA CORPORATION. All rights reserved. import os import math -import fcntl import locale from random import uniform from urlparse import urlparse from io import BlockingIOError import inspect +import platform + + +if not platform.system() == 'Windows': + import fcntl +else: + import gevent.os HTTP_TIMEOUT = 6.05 @@ -27,14 +33,18 @@ def nonblocking_readlines(f): Newlines are normalized to the Unix standard. """ fd = f.fileno() - fl = fcntl.fcntl(fd, fcntl.F_GETFL) - fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK) + if not platform.system() == 'Windows': + fl = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK) enc = locale.getpreferredencoding(False) buf = bytearray() while True: try: - block = os.read(fd, 8192) + if not platform.system() == 'Windows': + block = os.read(fd, 8192) + else: + block = gevent.os.tp_read(fd, 8192) except (BlockingIOError, OSError): yield "" continue @@ -42,7 +52,6 @@ def nonblocking_readlines(f): if not block: if buf: yield buf.decode(enc) - buf.clear() break buf.extend(block) diff --git a/scripts/generate_docs.py b/scripts/generate_docs.py index afe55b8a5..6c286fec3 100755 --- a/scripts/generate_docs.py +++ b/scripts/generate_docs.py @@ -157,10 +157,10 @@ def _print_route(self, route): ) filename = os.path.normpath(route['location']['filename']) if filename.startswith(digits_root): - filename = os.path.relpath(filename, digits_root) + filename = os.path.relpath(filename, digits_root).replace("\\","/") self.w('Location: [`%s@%s`](%s#L%s)' % ( filename, route['location']['line'], - os.path.join('..', filename), route['location']['line'], + os.path.join('..', filename).replace("\\","/"), route['location']['line'], )) self.w()