Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Windows Compatibility #199

Merged
merged 1 commit into from
Aug 13, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion digits/config/caffe_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def validate(cls, value):
if value == '<PATHS>':
# Find the executable
executable = cls.find_executable('caffe')
if not executable:
executable = cls.find_executable('caffe.exe')
if not executable:
raise config_option.BadValue('caffe binary not found in PATH')
cls.validate_version(executable)
Expand Down Expand Up @@ -187,6 +189,9 @@ def get_version(executable):
elif platform.system() == 'Darwin':
# XXX: guess and let the user figure out errors later
return (0,11,0)
elif platform.system() == 'Windows':
# XXX: guess and let the user figure out errors later
return (0,11,0)
else:
print 'WARNING: platform "%s" not supported' % platform.system()
return None
Expand All @@ -197,6 +202,8 @@ def _set_config_dict_value(self, value):
else:
if value == '<PATHS>':
executable = self.find_executable('caffe')
if not executable:
executable = self.find_executable('caffe.exe')
else:
executable = os.path.join(value, 'build', 'tools', 'caffe')

Expand Down Expand Up @@ -238,7 +245,7 @@ def apply(self):
print 'Did you forget to "make pycaffe"?'
raise

if platform.system() == 'Darwin':
if platform.system() == 'Darwin' or platform.system() == 'Windows':
# Strange issue with protocol buffers and pickle - see issue #32
sys.path.insert(0, os.path.join(
os.path.dirname(caffe.__file__), 'proto'))
Expand Down
4 changes: 2 additions & 2 deletions digits/dataset/images/classification/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __setstate__(self, state):
import numpy as np

old_blob = caffe_pb2.BlobProto()
with open(task.path(task.mean_file)) as infile:
with open(task.path(task.mean_file),'rb') as infile:
old_blob.ParseFromString(infile.read())
data = np.array(old_blob.data).reshape(
old_blob.channels,
Expand All @@ -48,7 +48,7 @@ def __setstate__(self, state):
new_blob.num = 1
new_blob.channels, new_blob.height, new_blob.width = data.shape
new_blob.data.extend(data.astype(float).flat)
with open(task.path(task.mean_file), 'w') as outfile:
with open(task.path(task.mean_file), 'wb') as outfile:
outfile.write(new_blob.SerializeToString())
else:
print '\tSetting "%s" status to ERROR because it was created with RGB channels' % self.name()
Expand Down
2 changes: 1 addition & 1 deletion digits/dataset/images/generic/test_lmdb_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def _save_mean(mean, filename):
blob.channels = 1
blob.height, blob.width = mean.shape
blob.data.extend(mean.astype(float).flat)
with open(filename, 'w') as outfile:
with open(filename, 'wb') as outfile:
outfile.write(blob.SerializeToString())

elif filename.endswith(('.jpg', '.jpeg', '.png')):
Expand Down
18 changes: 16 additions & 2 deletions digits/device_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,24 @@ def get_library(name):
return ctypes.cdll.LoadLibrary('%s.so' % name)
elif platform.system() == 'Darwin':
return ctypes.cdll.LoadLibrary('%s.dylib' % name)
elif platform.system() == 'Windows':
return ctypes.windll.LoadLibrary('%s.dll' % name)
except OSError:
pass
return None

devices = None

def get_cudart():
if not platform.system() == 'Windows':
return get_library('libcudart')

arch = platform.architecture()[0]
for ver in range(90,50,-5):
cudart = get_library('cudart%s_%d' % (arch[:2], ver))
if not cudart is None:
return cudart

def get_devices(force_reload=False):
"""
Returns a list of c_cudaDeviceProp's
Expand All @@ -131,7 +143,7 @@ def get_devices(force_reload=False):
return devices
devices = []

cudart = get_library('libcudart')
cudart = get_cudart()
if cudart is None:
return []

Expand Down Expand Up @@ -180,7 +192,9 @@ def get_nvml_info(device_id):

nvml = get_library('libnvidia-ml')
if nvml is None:
return None
nvml = get_library('nvml')
if nvml is None:
return None

rc = nvml.nvmlInit()
if rc != 0:
Expand Down
2 changes: 1 addition & 1 deletion digits/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def path(self, filename, relative=False):
path = os.path.join(self._dir, filename)
if relative:
path = os.path.relpath(path, config_value('jobs_dir'))
return str(path)
return str(path).replace("\\","/")

def path_is_local(self, path):
"""assert that a path is local to _dir"""
Expand Down
4 changes: 2 additions & 2 deletions digits/model/images/classification/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def test_classify_one(self):
category = self.imageset_paths.keys()[0]
image_path = self.imageset_paths[category][0]
image_path = os.path.join(self.imageset_folder, image_path)
with open(image_path) as infile:
with open(image_path,'rb') as infile:
# StringIO wrapping is needed to simulate POST file upload.
image_upload = (StringIO(infile.read()), 'image.png')

Expand All @@ -360,7 +360,7 @@ def test_classify_one_json(self):
category = self.imageset_paths.keys()[0]
image_path = self.imageset_paths[category][0]
image_path = os.path.join(self.imageset_folder, image_path)
with open(image_path) as infile:
with open(image_path,'rb') as infile:
# StringIO wrapping is needed to simulate POST file upload.
image_upload = (StringIO(infile.read()), 'image.png')

Expand Down
9 changes: 6 additions & 3 deletions digits/model/images/classification/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from forms import ImageClassificationModelForm
from job import ImageClassificationModelJob
from digits.status import Status
import platform

NAMESPACE = '/models/images/classification'

Expand Down Expand Up @@ -246,9 +247,11 @@ def image_classification_model_classify_one():
if 'image_url' in flask.request.form and flask.request.form['image_url']:
image = utils.image.load_image(flask.request.form['image_url'])
elif 'image_file' in flask.request.files and flask.request.files['image_file']:
with tempfile.NamedTemporaryFile() as outfile:
flask.request.files['image_file'].save(outfile.name)
image = utils.image.load_image(outfile.name)
outfile = tempfile.mkstemp(suffix='.bin')
flask.request.files['image_file'].save(outfile[1])
image = utils.image.load_image(outfile[1])
os.close(outfile[0])
os.remove(outfile[1])
else:
raise werkzeug.exceptions.BadRequest('must provide image_url or image_file')

Expand Down
4 changes: 2 additions & 2 deletions digits/model/images/generic/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def test_model_json(self):

def test_infer_one(self):
image_path = os.path.join(self.imageset_folder, self.test_image)
with open(image_path) as infile:
with open(image_path,'rb') as infile:
# StringIO wrapping is needed to simulate POST file upload.
image_upload = (StringIO(infile.read()), 'image.png')

Expand All @@ -355,7 +355,7 @@ def test_infer_one(self):

def test_infer_one_json(self):
image_path = os.path.join(self.imageset_folder, self.test_image)
with open(image_path) as infile:
with open(image_path,'rb') as infile:
# StringIO wrapping is needed to simulate POST file upload.
image_upload = (StringIO(infile.read()), 'image.png')

Expand Down
9 changes: 6 additions & 3 deletions digits/model/images/generic/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from forms import GenericImageModelForm
from job import GenericImageModelJob
from digits.status import Status
import platform

NAMESPACE = '/models/images/generic'

Expand Down Expand Up @@ -222,9 +223,11 @@ def generic_image_model_infer_one():
if 'image_url' in flask.request.form and flask.request.form['image_url']:
image = utils.image.load_image(flask.request.form['image_url'])
elif 'image_file' in flask.request.files and flask.request.files['image_file']:
with tempfile.NamedTemporaryFile() as outfile:
flask.request.files['image_file'].save(outfile.name)
image = utils.image.load_image(outfile.name)
outfile = tempfile.mkstemp(suffix='.bin')
flask.request.files['image_file'].save(outfile[1])
image = utils.image.load_image(outfile[1])
os.close(outfile[0])
os.remove(outfile[1])
else:
raise werkzeug.exceptions.BadRequest('must provide image_url or image_file')

Expand Down
6 changes: 3 additions & 3 deletions digits/model/tasks/caffe_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def save_files_classification(self):
val_data_layer.data_param.backend = caffe_pb2.DataParameter.LMDB
if self.use_mean:
mean_pixel = None
with open(self.dataset.path(self.dataset.train_db_task().mean_file)) as f:
with open(self.dataset.path(self.dataset.train_db_task().mean_file),'rb') as f:
blob = caffe_pb2.BlobProto()
blob.MergeFromString(f.read())
mean = np.reshape(blob.data,
Expand Down Expand Up @@ -1312,7 +1312,7 @@ def get_transformer(self):
channel_swap = (2,1,0)

if self.use_mean:
with open(self.dataset.path(self.dataset.train_db_task().mean_file)) as infile:
with open(self.dataset.path(self.dataset.train_db_task().mean_file),'rb') as infile:
blob = caffe_pb2.BlobProto()
blob.MergeFromString(infile.read())
mean_pixel = np.reshape(blob.data,
Expand All @@ -1331,7 +1331,7 @@ def get_transformer(self):
channel_swap = (2,1,0)

if self.dataset.mean_file:
with open(self.dataset.path(self.dataset.mean_file)) as infile:
with open(self.dataset.path(self.dataset.mean_file),'rb') as infile:
blob = caffe_pb2.BlobProto()
blob.MergeFromString(infile.read())
mean_pixel = np.reshape(blob.data,
Expand Down
6 changes: 4 additions & 2 deletions digits/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from config import config_value
from status import Status, StatusCls

import platform

# NOTE: Increment this everytime the pickled version changes
PICKLE_VERSION = 1

Expand Down Expand Up @@ -128,7 +130,7 @@ def path(self, filename, relative=False):
path = os.path.join(self.job_dir, filename)
if relative:
path = os.path.relpath(path, config_value('jobs_dir'))
return str(path)
return str(path).replace("\\","/")

def ready_to_queue(self):
"""
Expand Down Expand Up @@ -193,7 +195,7 @@ def run(self, resources):
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
cwd=self.job_dir,
close_fds=True,
close_fds=False if platform.system() == 'Windows' else True,
)

try:
Expand Down
21 changes: 15 additions & 6 deletions digits/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,23 @@

import os
import math
import fcntl
import locale
from random import uniform
from urlparse import urlparse
from io import BlockingIOError
import inspect
import platform


if not platform.system() == 'Windows':
import fcntl
else:
import gevent.os

HTTP_TIMEOUT = 6.05

def is_url(url):
return url is not None and urlparse(url).scheme != ""
return url is not None and urlparse(url).scheme != "" and not os.path.exists(url)

def wait_time():
"""Wait a random number of seconds"""
Expand All @@ -27,22 +33,25 @@ def nonblocking_readlines(f):
Newlines are normalized to the Unix standard.
"""
fd = f.fileno()
fl = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
if not platform.system() == 'Windows':
fl = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
enc = locale.getpreferredencoding(False)

buf = bytearray()
while True:
try:
block = os.read(fd, 8192)
if not platform.system() == 'Windows':
block = os.read(fd, 8192)
else:
block = gevent.os.tp_read(fd, 8192)
except (BlockingIOError, OSError):
yield ""
continue

if not block:
if buf:
yield buf.decode(enc)
buf.clear()
break

buf.extend(block)
Expand Down
27 changes: 22 additions & 5 deletions digits/utils/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import mock
import PIL.Image
import numpy as np
import os
import platform

from . import image as _, errors

Expand Down Expand Up @@ -52,9 +54,19 @@ def check_good_file(self, args):
orig_mode, suffix, pixel, new_mode = args

orig = PIL.Image.new(orig_mode, (10,10), pixel)
with tempfile.NamedTemporaryFile(suffix='.' + suffix) as tmp:
orig.save(tmp.name)
new = _.load_image(tmp.name)

# temp files cause permission errors so just generate the name
tmp = tempfile.mkstemp(suffix='.' + suffix)
orig.save(tmp[1])
new = _.load_image(tmp[1])
try:
# sometimes on windows the file is not closed yet
# which can cause an exception
os.close(tmp[0])
os.remove(tmp[1])
except:
pass

assert new is not None, 'load_image should never return None'
assert new.mode == new_mode, 'Image mode should be "%s", not "%s\nargs - %s' % (new_mode, new.mode, args)

Expand Down Expand Up @@ -94,16 +106,21 @@ def test_corrupted_file(self):
corrupted = encoded[:size/2] + encoded[size/2:][::-1]

# Save the corrupted image to a temporary file.
f = tempfile.NamedTemporaryFile(delete=False)
fname = tempfile.mkstemp(suffix='.bin')
f = os.fdopen(fname[0],'wb')
fname = fname[1]

f.write(corrupted)
f.close()

assert_raises(
errors.LoadImageError,
_.load_image,
f.name,
fname,
)

os.remove(fname)

class TestResizeImage():

@classmethod
Expand Down
Loading