Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update for Python3 #2184

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions __main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2014-2017, NVIDIA CORPORATION. All rights reserved.

import argparse
import os.path
import sys


# Update PATH to include the local DIGITS directory
PARENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
found_parent_dir = False
for p in sys.path:
if os.path.abspath(p) == PARENT_DIR:
found_parent_dir = True
break
if not found_parent_dir:
sys.path.insert(0, PARENT_DIR)


def main():
parser = argparse.ArgumentParser(description='DIGITS server')
parser.add_argument(
'-p', '--port',
type=int,
default=5000,
help='Port to run app on (default 5000)'
)
parser.add_argument(
'-d', '--debug',
action='store_true',
help=('Run the application in debug mode (reloads when the source '
'changes and gives more detailed error messages)')
)
parser.add_argument(
'--version',
action='store_true',
help='Print the version number and exit'
)

args = vars(parser.parse_args())

import digits

if args['version']:
print(digits.__version__)
sys.exit()

print(' ___ ___ ___ ___ _____ ___')
print(' | \_ _/ __|_ _|_ _/ __|')
print(' | |) | | (_ || | | | \__ \\')
print(' |___/___\___|___| |_| |___/', digits.__version__)
print()

import digits.config
import digits.log
import digits.webapp

try:
if not digits.webapp.scheduler.start():
print('ERROR: Scheduler would not start')
else:
digits.webapp.app.debug = args['debug']
digits.webapp.socketio.run(digits.webapp.app, '0.0.0.0', args['port'])
except KeyboardInterrupt:
pass
finally:
digits.webapp.scheduler.stop()


if __name__ == '__main__':
main()
14 changes: 7 additions & 7 deletions digits/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,22 @@ def main():
import digits

if args['version']:
print digits.__version__
print (digits.__version__)
sys.exit()

print ' ___ ___ ___ ___ _____ ___'
print ' | \_ _/ __|_ _|_ _/ __|'
print ' | |) | | (_ || | | | \__ \\'
print ' |___/___\___|___| |_| |___/', digits.__version__
print
print(' ___ ___ ___ ___ _____ ___')
print(' | \_ _/ __|_ _|_ _/ __|')
print(' | |) | | (_ || | | | \__ \\')
print(' |___/___\___|___| |_| |___/', digits.__version__)
print()

import digits.config
import digits.log
import digits.webapp

try:
if not digits.webapp.scheduler.start():
print 'ERROR: Scheduler would not start'
print('ERROR: Scheduler would not start')
else:
digits.webapp.app.debug = args['debug']
digits.webapp.socketio.run(digits.webapp.app, '0.0.0.0', args['port'])
Expand Down
2 changes: 2 additions & 0 deletions digits/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
url_prefix,
)

if option_list['caffe']['loaded'] is False and option_list['tensorflow']['enabled']:
option_list['caffe']['multi_gpu'] = True

def config_value(option):
"""
Expand Down
49 changes: 24 additions & 25 deletions digits/config/caffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@ def load_from_envvar(envvar):
import_pycaffe(python_dir)
version, flavor = get_version_and_flavor(executable)
except:
print ('"%s" from %s does not point to a valid installation of Caffe.'
% (value, envvar))
print 'Use the envvar CAFFE_ROOT to indicate a valid installation.'
raise
raise ('"%s" from %s does not point to a valid installation of Caffe. \nUse the envvar CAFFE_ROOT to indicate a valid installation.'% (value, envvar))
return executable, version, flavor


Expand All @@ -57,9 +54,7 @@ def load_from_path():
import_pycaffe()
version, flavor = get_version_and_flavor(executable)
except:
print 'A valid Caffe installation was not found on your system.'
print 'Use the envvar CAFFE_ROOT to indicate a valid installation.'
raise
raise ('A valid Caffe installation was not found on your system. \nUse the envvar CAFFE_ROOT to indicate a valid installation.')
return executable, version, flavor


Expand Down Expand Up @@ -125,8 +120,7 @@ def import_pycaffe(dirname=None):
try:
import caffe
except ImportError:
print 'Did you forget to "make pycaffe"?'
raise
raise ('Did you forget to "make pycaffe"?')

# Strange issue with protocol buffers and pickle - see issue #32
sys.path.insert(0, os.path.join(
Expand Down Expand Up @@ -181,7 +175,7 @@ def get_version_from_cmdline(executable):
command = [executable, '-version']
p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if p.wait():
print p.stderr.read().strip()
print(p.stderr.read().strip())
raise RuntimeError('"%s" returned error code %s' % (command, p.returncode))

pattern = 'version'
Expand All @@ -195,7 +189,7 @@ def get_version_from_soname(executable):
command = ['ldd', executable]
p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if p.wait():
print p.stderr.read().strip()
print(p.stderr.read().strip())
raise RuntimeError('"%s" returned error code %s' % (command, p.returncode))

# Search output for caffe library
Expand All @@ -222,17 +216,22 @@ def get_version_from_soname(executable):
return None


if 'CAFFE_ROOT' in os.environ:
executable, version, flavor = load_from_envvar('CAFFE_ROOT')
elif 'CAFFE_HOME' in os.environ:
executable, version, flavor = load_from_envvar('CAFFE_HOME')
else:
executable, version, flavor = load_from_path()

option_list['caffe'] = {
'executable': executable,
'version': version,
'flavor': flavor,
'multi_gpu': (flavor == 'BVLC' or parse_version(version) >= parse_version(0, 12)),
'cuda_enabled': (len(device_query.get_devices()) > 0),
}
try:
if 'CAFFE_ROOT' in os.environ:
executable, version, flavor = load_from_envvar('CAFFE_ROOT')
elif 'CAFFE_HOME' in os.environ:
executable, version, flavor = load_from_envvar('CAFFE_HOME')
else:
executable, version, flavor = load_from_path()
option_list['caffe'] = {
'loaded': True,
'executable': executable,
'version': version,
'flavor': flavor,
'multi_gpu': (flavor == 'BVLC' or parse_version(version) >= parse_version(0, 12)),
'cuda_enabled': (len(device_query.get_devices()) > 0),
}
except (Exception, ImportError, RuntimeError) as e:
print("Caffe support disabled.")
# print("Reason: {}".format(e.message))
option_list['caffe'] = {'loaded': False, 'multi_gpu': False, 'cuda_enabled': False}
2 changes: 1 addition & 1 deletion digits/config/gpu_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
import digits.device_query


option_list['gpu_list'] = ','.join([str(x) for x in xrange(len(digits.device_query.get_devices()))])
option_list['gpu_list'] = ','.join([str(x) for x in range(len(digits.device_query.get_devices()))])
4 changes: 2 additions & 2 deletions digits/config/jobs_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
if not os.path.exists(value):
os.makedirs(value)
except:
print '"%s" is not a valid value for jobs_dir.' % value
print 'Set the envvar DIGITS_JOBS_DIR to fix your configuration.'
print('"%s" is not a valid value for jobs_dir.' % value)
print('Set the envvar DIGITS_JOBS_DIR to fix your configuration.')
raise


Expand Down
4 changes: 2 additions & 2 deletions digits/config/log_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def load_logfile_filename():
pass
except:
if throw_error:
print '"%s" is not a valid value for logfile_filename.' % filename
print 'Set the envvar DIGITS_LOGFILE_FILENAME to fix your configuration.'
print('"%s" is not a valid value for logfile_filename.' % filename)
print('Set the envvar DIGITS_LOGFILE_FILENAME to fix your configuration.')
raise
else:
filename = None
Expand Down
4 changes: 2 additions & 2 deletions digits/config/store_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import absolute_import

import os
from urlparse import urlparse
from urllib import parse

from . import option_list

Expand All @@ -14,7 +14,7 @@ def validate(value):
if isinstance(value, str):
url_list = value.split(',')
for url in url_list:
if url is not None and urlparse(url).scheme != "" and not os.path.exists(url):
if url is not None and parse.urlparse(url).scheme != "" and not os.path.exists(url):
valid_url_list.append(url)
else:
raise ValueError('"%s" is not a valid URL' % url)
Expand Down
2 changes: 1 addition & 1 deletion digits/config/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_tf_import():
try:
import tensorflow # noqa
return True
except (ImportError, TypeError):
except ImportError:
return False

tf_enabled = test_tf_import()
Expand Down
24 changes: 12 additions & 12 deletions digits/dataset/generic/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,16 @@ def create_dataset(cls, **kwargs):
if rv.status_code != 200:
raise RuntimeError(
'Dataset creation failed with %s' % rv.status_code)
return json.loads(rv.data)['id']
return json.loads(rv.get_data(as_text=True))['id']

# expect a redirect
if not 300 <= rv.status_code <= 310:
s = BeautifulSoup(rv.data, 'html.parser')
div = s.select('div.alert-danger')
if div:
print div[0]
print(div[0])
else:
print rv.data
print(rv.get_data(as_text=True))
raise RuntimeError(
'Failed to create dataset - status %s' % rv.status_code)

Expand All @@ -112,7 +112,7 @@ def create_dataset(cls, **kwargs):
def get_dataset_json(cls):
rv = cls.app.get('/datasets/%s/json' % cls.dataset_id)
assert rv.status_code == 200, 'page load failed with %s' % rv.status_code
return json.loads(rv.data)
return json.loads(rv.get_data(as_text=True))

@classmethod
def get_entry_count(cls, stage):
Expand All @@ -138,7 +138,7 @@ def create_random_imageset(cls, **kwargs):
if not hasattr(cls, 'imageset_folder'):
# create a temporary folder
cls.imageset_folder = tempfile.mkdtemp()
for i in xrange(num_images):
for i in range(num_images):
x = np.random.randint(
low=0,
high=256,
Expand All @@ -162,7 +162,7 @@ def create_variable_size_random_imageset(cls, **kwargs):
if not hasattr(cls, 'imageset_folder'):
# create a temporary folder
cls.imageset_folder = tempfile.mkdtemp()
for i in xrange(num_images):
for i in range(num_images):
image_width = np.random.randint(low=8, high=32)
image_height = np.random.randint(low=8, high=32)
x = np.random.randint(
Expand Down Expand Up @@ -207,9 +207,9 @@ def setUpClass(cls, **kwargs):

def test_page_dataset_new(self):
rv = self.app.get('/datasets/generic/new/%s' % self.EXTENSION_ID)
print rv.data
print(rv.get_data(as_text=True))
assert rv.status_code == 200, 'page load failed with %s' % rv.status_code
assert extensions.data.get_extension(self.EXTENSION_ID).get_title() in rv.data, 'unexpected page format'
assert extensions.data.get_extension(self.EXTENSION_ID).get_title() in rv.get_data(as_text=True), 'unexpected page format'

def test_nonexistent_dataset(self):
assert not self.dataset_exists('foo'), "dataset shouldn't exist"
Expand Down Expand Up @@ -264,7 +264,7 @@ def test_clone(self):
assert self.dataset_wait_completion(job1_id) == 'Done', 'first job failed'
rv = self.app.get('/datasets/%s/json' % job1_id)
assert rv.status_code == 200, 'json load failed with %s' % rv.status_code
content1 = json.loads(rv.data)
content1 = json.loads(rv.get_data(as_text=True))

# Clone job1 as job2
options_2 = {
Expand All @@ -275,7 +275,7 @@ def test_clone(self):
assert self.dataset_wait_completion(job2_id) == 'Done', 'second job failed'
rv = self.app.get('/datasets/%s/json' % job2_id)
assert rv.status_code == 200, 'json load failed with %s' % rv.status_code
content2 = json.loads(rv.data)
content2 = json.loads(rv.get_data(as_text=True))

# These will be different
content1.pop('id')
Expand All @@ -298,7 +298,7 @@ class GenericCreatedTest(BaseViewsTestWithDataset):
def test_index_json(self):
rv = self.app.get('/index/json')
assert rv.status_code == 200, 'page load failed with %s' % rv.status_code
content = json.loads(rv.data)
content = json.loads(rv.get_data(as_text=True))
found = False
for d in content['datasets']:
if d['id'] == self.dataset_id:
Expand All @@ -318,7 +318,7 @@ def test_edit_name(self):
assert status == 200, 'failed with %s' % status
rv = self.app.get('/datasets/summary?job_id=%s' % self.dataset_id)
assert rv.status_code == 200
assert 'new name' in rv.data
assert 'new name' in rv.get_data(as_text=True)

def test_edit_notes(self):
status = self.edit_job(
Expand Down
14 changes: 8 additions & 6 deletions digits/dataset/generic/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO

import caffe_pb2
from io import StringIO
from io import BytesIO
# import caffe_pb2
import flask
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import base64

from .forms import GenericDatasetForm
from .job import GenericDatasetJob
from digits import extensions, utils
from digits.dataset import dataset_pb2
from digits.utils.constants import COLOR_PALETTE_ATTRIBUTE
from digits.utils.routing import request_wants_json, job_from_request
from digits.utils.lmdbreader import DbReader
Expand Down Expand Up @@ -172,16 +174,16 @@ def explore():
min_page = max(0, page - 5)
total_entries = reader.total_entries

max_page = min((total_entries - 1) / size, page + 5)
max_page = min((total_entries - 1) // size, page + 5)
pages = range(min_page, max_page + 1)
for key, value in reader.entries():
if count >= page * size:
datum = caffe_pb2.Datum()
datum = dataset_pb2.Datum()
datum.ParseFromString(value)
if not datum.encoded:
raise RuntimeError("Expected encoded database")
s = StringIO()
s.write(datum.data)
s.write(datum.data.decode())
s.seek(0)
img = PIL.Image.open(s)
if cmap and img.mode in ['L', '1']:
Expand Down
Loading