diff --git a/.pep8speaks.yml b/.pep8speaks.yml new file mode 100644 index 00000000..8daa7cd4 --- /dev/null +++ b/.pep8speaks.yml @@ -0,0 +1,4 @@ +# File : .pep8speaks.yml + +flake8: + max-line-length: 88 diff --git a/setup.py b/setup.py index 47553eb7..dc1265c9 100755 --- a/setup.py +++ b/setup.py @@ -1,15 +1,15 @@ -'''Web server Tableau uses to run Python scripts. +"""Web server Tableau uses to run Python scripts. TabPy (the Tableau Python Server) is an external service implementation which expands Tableau's capabilities by allowing users to execute Python scripts and saved functions via Tableau's table calculations. -''' +""" import os from setuptools import setup, find_packages -DOCLINES = (__doc__ or '').split('\n') +DOCLINES = (__doc__ or "").split("\n") def setup_package(): @@ -17,89 +17,88 @@ def read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() setup( - name='tabpy', - version=read('tabpy/VERSION'), + name="tabpy", + version=read("tabpy/VERSION"), description=DOCLINES[0], - long_description='\n'.join(DOCLINES[1:]) + '\n' + read('CHANGELOG'), - long_description_content_type='text/markdown', - url='https://github.com/tableau/TabPy', - author='Tableau', - author_email='github@tableau.com', - maintainer='Tableau', - maintainer_email='github@tableau.com', - download_url='https://pypi.org/project/tabpy', + long_description="\n".join(DOCLINES[1:]) + "\n" + read("CHANGELOG"), + long_description_content_type="text/markdown", + url="https://github.com/tableau/TabPy", + author="Tableau", + author_email="github@tableau.com", + maintainer="Tableau", + maintainer_email="github@tableau.com", + download_url="https://pypi.org/project/tabpy", project_urls={ "Bug Tracker": "https://github.com/tableau/TabPy/issues", "Documentation": "https://tableau.github.io/TabPy/", "Source Code": "https://github.com/tableau/TabPy", }, classifiers=[ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.6', - 'Topic :: Scientific/Engineering', - 'Topic :: Scientific/Engineering :: Information Analysis', - 'Operating System :: Microsoft :: Windows', - 'Operating System :: POSIX', - 'Operating System :: Unix', - 'Operating System :: MacOS' + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.6", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Information Analysis", + "Operating System :: Microsoft :: Windows", + "Operating System :: POSIX", + "Operating System :: Unix", + "Operating System :: MacOS", ], - platforms=['Windows', 'Linux', 'Mac OS-X', 'Unix'], - keywords=['tabpy tableau'], - packages=find_packages( - exclude=['docs', 'misc', 'tests']), + platforms=["Windows", "Linux", "Mac OS-X", "Unix"], + keywords=["tabpy tableau"], + packages=find_packages(exclude=["docs", "misc", "tests"]), package_data={ - 'tabpy': [ - 'VERSION', - 'tabpy_server/state.ini.template', - 'tabpy_server/static', - 'tabpy_server/common/default.conf' + "tabpy": [ + "VERSION", + "tabpy_server/state.ini.template", + "tabpy_server/static", + "tabpy_server/common/default.conf", ] }, - python_requires='>=3.6', - license='MIT', + python_requires=">=3.6", + license="MIT", # Note: many of these required packages are included in base python # but are listed here because different linux distros use custom # python installations. And users can remove packages at any point install_requires=[ - 'backports_abc', - 'cloudpickle', - 'configparser', - 'decorator', - 'future', - 'genson', - 'jsonschema', - 'pyopenssl', - 'python-dateutil', - 'requests', - 'singledispatch', - 'six', - 'tornado', - 'urllib3<1.25,>=1.21.1' + "backports_abc", + "cloudpickle", + "configparser", + "decorator", + "future", + "genson", + "jsonschema", + "pyopenssl", + "python-dateutil", + "requests", + "singledispatch", + "six", + "tornado", + "urllib3<1.25,>=1.21.1", ], entry_points={ - 'console_scripts': [ - 'tabpy=tabpy.tabpy:main', - 'tabpy-deploy-models=tabpy.models.deploy_models:main', - 'tabpy-user-management=tabpy.utils.user_management:main' + "console_scripts": [ + "tabpy=tabpy.tabpy:main", + "tabpy-deploy-models=tabpy.models.deploy_models:main", + "tabpy-user-management=tabpy.utils.user_management:main", ], }, - setup_requires=['pytest-runner'], + setup_requires=["pytest-runner"], tests_require=[ - 'mock', - 'nltk', - 'numpy', - 'pandas', - 'pytest', - 'scipy', - 'sklearn', - 'textblob' + "mock", + "nltk", + "numpy", + "pandas", + "pytest", + "scipy", + "sklearn", + "textblob", ], - test_suite='pytest' + test_suite="pytest", ) -if __name__ == '__main__': +if __name__ == "__main__": setup_package() diff --git a/tabpy/models/deploy_models.py b/tabpy/models/deploy_models.py index a2d6b63c..a026afb8 100644 --- a/tabpy/models/deploy_models.py +++ b/tabpy/models/deploy_models.py @@ -14,40 +14,38 @@ def install_dependencies(packages): - pip_arg = ['install'] + packages + ['--no-cache-dir'] - if hasattr(pip, 'main'): + pip_arg = ["install"] + packages + ["--no-cache-dir"] + if hasattr(pip, "main"): pip.main(pip_arg) else: pip._internal.main(pip_arg) def main(): - install_dependencies(['sklearn', 'pandas', 'numpy', - 'textblob', 'nltk', 'scipy']) - print('==================================================================') + install_dependencies(["sklearn", "pandas", "numpy", "textblob", "nltk", "scipy"]) + print("==================================================================") # Determine if we run python or python3 - if platform.system() == 'Windows': - py = 'python' + if platform.system() == "Windows": + py = "python" else: - py = 'python3' + py = "python3" if len(sys.argv) > 1: config_file_path = sys.argv[1] else: config_file_path = setup_utils.get_default_config_file_path() - print(f'Using config file at {config_file_path}') + print(f"Using config file at {config_file_path}") port, auth_on, prefix = setup_utils.parse_config(config_file_path) if auth_on: auth_args = setup_utils.get_creds() else: auth_args = [] - directory = str(Path(__file__).resolve().parent / 'scripts') + directory = str(Path(__file__).resolve().parent / "scripts") # Deploy each model in the scripts directory for filename in os.listdir(directory): - subprocess.run([py, f'{directory}/{filename}', config_file_path] - + auth_args) + subprocess.run([py, f"{directory}/{filename}", config_file_path] + auth_args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tabpy/models/scripts/ANOVA.py b/tabpy/models/scripts/ANOVA.py index b151b086..4cf90c2b 100644 --- a/tabpy/models/scripts/ANOVA.py +++ b/tabpy/models/scripts/ANOVA.py @@ -3,11 +3,11 @@ def anova(_arg1, _arg2, *_argN): - ''' + """ ANOVA is a statistical hypothesis test that is used to compare two or more group means for equality.For more information on the function and how to use it please refer to tabpy-tools.md - ''' + """ cols = [_arg1, _arg2] + list(_argN) for col in cols: @@ -18,8 +18,5 @@ def anova(_arg1, _arg2, *_argN): return p_value -if __name__ == '__main__': - setup_utils.deploy_model( - 'anova', - anova, - 'Returns the p-value form an ANOVA test') +if __name__ == "__main__": + setup_utils.deploy_model("anova", anova, "Returns the p-value form an ANOVA test") diff --git a/tabpy/models/scripts/PCA.py b/tabpy/models/scripts/PCA.py index f9f5f492..df23632c 100644 --- a/tabpy/models/scripts/PCA.py +++ b/tabpy/models/scripts/PCA.py @@ -8,16 +8,16 @@ def PCA(component, _arg1, _arg2, *_argN): - ''' + """ Principal Component Analysis is a technique that extracts the key distinct components from a high dimensional space whie attempting to capture as much of the variance as possible. For more information on the function and how to use it please refer to tabpy-tools.md - ''' + """ cols = [_arg1, _arg2] + list(_argN) encodedCols = [] labelEncoder = LabelEncoder() - oneHotEncoder = OneHotEncoder(categories='auto', sparse=False) + oneHotEncoder = OneHotEncoder(categories="auto", sparse=False) for col in cols: if isinstance(col[0], (int, float)): @@ -27,8 +27,10 @@ def PCA(component, _arg1, _arg2, *_argN): encodedCols.append(intCol.astype(int)) else: if len(set(col)) > 25: - print('ERROR: Non-numeric arguments cannot have more than ' - '25 unique values') + print( + "ERROR: Non-numeric arguments cannot have more than " + "25 unique values" + ) raise ValueError integerEncoded = labelEncoder.fit_transform(array(col)) integerEncoded = integerEncoded.reshape(len(col), 1) @@ -38,11 +40,10 @@ def PCA(component, _arg1, _arg2, *_argN): dataDict = {} for i in range(len(encodedCols)): - dataDict[f'col{1 + i}'] = list(encodedCols[i]) + dataDict[f"col{1 + i}"] = list(encodedCols[i]) if component <= 0 or component > len(dataDict): - print('ERROR: Component specified must be >= 0 and ' - '<= number of arguments') + print("ERROR: Component specified must be >= 0 and " "<= number of arguments") raise ValueError df = pd.DataFrame(data=dataDict, dtype=float) @@ -55,8 +56,5 @@ def PCA(component, _arg1, _arg2, *_argN): return pcaComponents[:, component - 1].tolist() -if __name__ == '__main__': - setup_utils.deploy_model( - 'PCA', - PCA, - 'Returns the specified principal component') +if __name__ == "__main__": + setup_utils.deploy_model("PCA", PCA, "Returns the specified principal component") diff --git a/tabpy/models/scripts/SentimentAnalysis.py b/tabpy/models/scripts/SentimentAnalysis.py index 4a978b9f..ed4e0c7e 100644 --- a/tabpy/models/scripts/SentimentAnalysis.py +++ b/tabpy/models/scripts/SentimentAnalysis.py @@ -5,47 +5,48 @@ import ssl + _ctx = ssl._create_unverified_context ssl._create_default_https_context = _ctx -nltk.download('vader_lexicon') -nltk.download('punkt') +nltk.download("vader_lexicon") +nltk.download("punkt") -def SentimentAnalysis(_arg1, library='nltk'): - ''' +def SentimentAnalysis(_arg1, library="nltk"): + """ Sentiment Analysis is a procedure that assigns a score from -1 to 1 for a piece of text with -1 being negative and 1 being positive. For more information on the function and how to use it please refer to tabpy-tools.md - ''' + """ if not (isinstance(_arg1[0], str)): raise TypeError - supportedLibraries = {'nltk', 'textblob'} + supportedLibraries = {"nltk", "textblob"} library = library.lower() if library not in supportedLibraries: raise ValueError scores = [] - if library == 'nltk': + if library == "nltk": sid = SentimentIntensityAnalyzer() for text in _arg1: sentimentResults = sid.polarity_scores(text) - score = sentimentResults['compound'] + score = sentimentResults["compound"] scores.append(score) - elif library == 'textblob': + elif library == "textblob": for text in _arg1: currScore = TextBlob(text) scores.append(currScore.sentiment.polarity) return scores -if __name__ == '__main__': +if __name__ == "__main__": setup_utils.deploy_model( - 'Sentiment Analysis', + "Sentiment Analysis", SentimentAnalysis, - 'Returns a sentiment score between -1 and 1 for ' - 'a given string') + "Returns a sentiment score between -1 and 1 for " "a given string", + ) diff --git a/tabpy/models/scripts/tTest.py b/tabpy/models/scripts/tTest.py index 8fffee37..433a4750 100644 --- a/tabpy/models/scripts/tTest.py +++ b/tabpy/models/scripts/tTest.py @@ -3,12 +3,12 @@ def ttest(_arg1, _arg2): - ''' + """ T-Test is a statistical hypothesis test that is used to compare two sample means or a sample’s mean against a known population mean. For more information on the function and how to use it please refer to tabpy-tools.md - ''' + """ # one sample test with mean if len(_arg2) == 1: test_stat, p_value = stats.ttest_1samp(_arg1, _arg2) @@ -35,8 +35,5 @@ def ttest(_arg1, _arg2): return p_value -if __name__ == '__main__': - setup_utils.deploy_model( - 'ttest', - ttest, - 'Returns the p-value form a t-test') +if __name__ == "__main__": + setup_utils.deploy_model("ttest", ttest, "Returns the p-value form a t-test") diff --git a/tabpy/models/utils/setup_utils.py b/tabpy/models/utils/setup_utils.py index 65468153..d1876714 100644 --- a/tabpy/models/utils/setup_utils.py +++ b/tabpy/models/utils/setup_utils.py @@ -8,25 +8,27 @@ def get_default_config_file_path(): import tabpy + pkg_path = os.path.dirname(tabpy.__file__) - config_file_path = os.path.join( - pkg_path, 'tabpy_server', 'common', 'default.conf') + config_file_path = os.path.join(pkg_path, "tabpy_server", "common", "default.conf") return config_file_path def parse_config(config_file_path): config = configparser.ConfigParser() config.read(config_file_path) - tabpy_config = config['TabPy'] + tabpy_config = config["TabPy"] port = 9004 - if 'TABPY_PORT' in tabpy_config: - port = tabpy_config['TABPY_PORT'] + if "TABPY_PORT" in tabpy_config: + port = tabpy_config["TABPY_PORT"] - auth_on = 'TABPY_PWD_FILE' in tabpy_config - ssl_on = 'TABPY_TRANSFER_PROTOCOL' in tabpy_config and \ - 'TABPY_CERTIFICATE_FILE' in tabpy_config and \ - 'TABPY_KEY_FILE' in tabpy_config + auth_on = "TABPY_PWD_FILE" in tabpy_config + ssl_on = ( + "TABPY_TRANSFER_PROTOCOL" in tabpy_config + and "TABPY_CERTIFICATE_FILE" in tabpy_config + and "TABPY_KEY_FILE" in tabpy_config + ) prefix = "https" if ssl_on else "http" return port, auth_on, prefix @@ -49,7 +51,7 @@ def deploy_model(funcName, func, funcDescription): config_file_path = get_default_config_file_path() port, auth_on, prefix = parse_config(config_file_path) - connection = Client(f'{prefix}://localhost:{port}/') + connection = Client(f"{prefix}://localhost:{port}/") if auth_on: # credentials are passed in from setup.py @@ -61,4 +63,4 @@ def deploy_model(funcName, func, funcDescription): connection.set_credentials(user, passwd) connection.deploy(funcName, func, funcDescription, override=True) - print(f'Successfully deployed {funcName}') + print(f"Successfully deployed {funcName}") diff --git a/tabpy/tabpy.py b/tabpy/tabpy.py index b03f16cf..3ee0df84 100755 --- a/tabpy/tabpy.py +++ b/tabpy/tabpy.py @@ -1,24 +1,25 @@ -''' +""" TabPy application. This file main() function is an entry point for 'tabpy' command. -''' +""" import os from pathlib import Path def read_version(): - ver = 'unknown' + ver = "unknown" import tabpy + pkg_path = os.path.dirname(tabpy.__file__) - ver_file_path = os.path.join(pkg_path, 'VERSION') + ver_file_path = os.path.join(pkg_path, "VERSION") if Path(ver_file_path).exists(): with open(ver_file_path) as f: ver = f.read().strip() else: - ver = f'Version Unknown, (file {ver_file_path} not found)' + ver = f"Version Unknown, (file {ver_file_path} not found)" return ver @@ -28,9 +29,10 @@ def read_version(): def main(): from tabpy.tabpy_server.app.app import TabPyApp + app = TabPyApp() app.run() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tabpy/tabpy_server/app/ConfigParameters.py b/tabpy/tabpy_server/app/ConfigParameters.py index 14abbca2..3feca9b7 100644 --- a/tabpy/tabpy_server/app/ConfigParameters.py +++ b/tabpy/tabpy_server/app/ConfigParameters.py @@ -1,16 +1,17 @@ class ConfigParameters: - ''' + """ Configuration settings names - ''' - TABPY_PWD_FILE = 'TABPY_PWD_FILE' - TABPY_PORT = 'TABPY_PORT' - TABPY_QUERY_OBJECT_PATH = 'TABPY_QUERY_OBJECT_PATH' - TABPY_STATE_PATH = 'TABPY_STATE_PATH' - TABPY_TRANSFER_PROTOCOL = 'TABPY_TRANSFER_PROTOCOL' - TABPY_CERTIFICATE_FILE = 'TABPY_CERTIFICATE_FILE' - TABPY_KEY_FILE = 'TABPY_KEY_FILE' - TABPY_PWD_FILE = 'TABPY_PWD_FILE' - TABPY_LOG_DETAILS = 'TABPY_LOG_DETAILS' - TABPY_STATIC_PATH = 'TABPY_STATIC_PATH' - TABPY_MAX_REQUEST_SIZE_MB = 'TABPY_MAX_REQUEST_SIZE_MB' - TABPY_EVALUATE_TIMEOUT = 'TABPY_EVALUATE_TIMEOUT' + """ + + TABPY_PWD_FILE = "TABPY_PWD_FILE" + TABPY_PORT = "TABPY_PORT" + TABPY_QUERY_OBJECT_PATH = "TABPY_QUERY_OBJECT_PATH" + TABPY_STATE_PATH = "TABPY_STATE_PATH" + TABPY_TRANSFER_PROTOCOL = "TABPY_TRANSFER_PROTOCOL" + TABPY_CERTIFICATE_FILE = "TABPY_CERTIFICATE_FILE" + TABPY_KEY_FILE = "TABPY_KEY_FILE" + TABPY_PWD_FILE = "TABPY_PWD_FILE" + TABPY_LOG_DETAILS = "TABPY_LOG_DETAILS" + TABPY_STATIC_PATH = "TABPY_STATIC_PATH" + TABPY_MAX_REQUEST_SIZE_MB = "TABPY_MAX_REQUEST_SIZE_MB" + TABPY_EVALUATE_TIMEOUT = "TABPY_EVALUATE_TIMEOUT" diff --git a/tabpy/tabpy_server/app/SettingsParameters.py b/tabpy/tabpy_server/app/SettingsParameters.py index a455fdaa..45fb128a 100755 --- a/tabpy/tabpy_server/app/SettingsParameters.py +++ b/tabpy/tabpy_server/app/SettingsParameters.py @@ -1,16 +1,17 @@ class SettingsParameters: - ''' + """ Application (TabPyApp) settings names - ''' - TransferProtocol = 'transfer_protocol' - Port = 'port' - ServerVersion = 'server_version' - UploadDir = 'upload_dir' - CertificateFile = 'certificate_file' - KeyFile = 'key_file' - StateFilePath = 'state_file_path' - ApiVersions = 'versions' - LogRequestContext = 'log_request_context' - StaticPath = 'static_path' - MaxRequestSizeInMb = 'max_request_size_in_mb' - EvaluateTimeout = 'evaluate_timeout' + """ + + TransferProtocol = "transfer_protocol" + Port = "port" + ServerVersion = "server_version" + UploadDir = "upload_dir" + CertificateFile = "certificate_file" + KeyFile = "key_file" + StateFilePath = "state_file_path" + ApiVersions = "versions" + LogRequestContext = "log_request_context" + StaticPath = "static_path" + MaxRequestSizeInMb = "max_request_size_in_mb" + EvaluateTimeout = "evaluate_timeout" diff --git a/tabpy/tabpy_server/app/app.py b/tabpy/tabpy_server/app/app.py index 2df69620..341a2f67 100644 --- a/tabpy/tabpy_server/app/app.py +++ b/tabpy/tabpy_server/app/app.py @@ -14,15 +14,17 @@ from tabpy.tabpy_server.app.util import parse_pwd_file from tabpy.tabpy_server.management.state import TabPyState from tabpy.tabpy_server.management.util import _get_state_from_file -from tabpy.tabpy_server.psws.callbacks\ - import (init_model_evaluator, init_ps_server) -from tabpy.tabpy_server.psws.python_service\ - import (PythonService, PythonServiceHandler) -from tabpy.tabpy_server.handlers\ - import (EndpointHandler, EndpointsHandler, - EvaluationPlaneHandler, QueryPlaneHandler, - ServiceInfoHandler, StatusHandler, - UploadDestinationHandler) +from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server +from tabpy.tabpy_server.psws.python_service import PythonService, PythonServiceHandler +from tabpy.tabpy_server.handlers import ( + EndpointHandler, + EndpointsHandler, + EvaluationPlaneHandler, + QueryPlaneHandler, + ServiceInfoHandler, + StatusHandler, + UploadDestinationHandler, +) import tornado @@ -30,9 +32,9 @@ class TabPyApp: - ''' + """ TabPy application class for keeping context like settings, state, etc. - ''' + """ settings = {} subdirectory = "" @@ -46,14 +48,13 @@ def __init__(self, config_file=None): if cli_args.config is not None: config_file = cli_args.config else: - config_file = os.path.join(os.path.dirname(__file__), - os.path.pardir, 'common', - 'default.conf') + config_file = os.path.join( + os.path.dirname(__file__), os.path.pardir, "common", "default.conf" + ) if os.path.isfile(config_file): try: - logging.config.fileConfig( - config_file, disable_existing_loggers=False) + logging.config.fileConfig(config_file, disable_existing_loggers=False) except KeyError: logging.basicConfig(level=logging.DEBUG) @@ -61,25 +62,22 @@ def __init__(self, config_file=None): def run(self): application = self._create_tornado_web_app() - max_request_size =\ - int(self.settings[SettingsParameters.MaxRequestSizeInMb]) *\ - 1024 * 1024 - logger.info(f'Setting max request size to {max_request_size} bytes') + max_request_size = ( + int(self.settings[SettingsParameters.MaxRequestSizeInMb]) * 1024 * 1024 + ) + logger.info(f"Setting max request size to {max_request_size} bytes") - init_model_evaluator( - self.settings, - self.tabpy_state, - self.python_service) + init_model_evaluator(self.settings, self.tabpy_state, self.python_service) protocol = self.settings[SettingsParameters.TransferProtocol] ssl_options = None - if protocol == 'https': + if protocol == "https": ssl_options = { - 'certfile': self.settings[SettingsParameters.CertificateFile], - 'keyfile': self.settings[SettingsParameters.KeyFile] + "certfile": self.settings[SettingsParameters.CertificateFile], + "keyfile": self.settings[SettingsParameters.KeyFile], } - elif protocol != 'http': - msg = f'Unsupported transfer protocol {protocol}.' + elif protocol != "http": + msg = f"Unsupported transfer protocol {protocol}." logger.critical(msg) raise RuntimeError(msg) @@ -87,11 +85,13 @@ def run(self): self.settings[SettingsParameters.Port], ssl_options=ssl_options, max_buffer_size=max_request_size, - max_body_size=max_request_size) + max_body_size=max_request_size, + ) logger.info( - 'Web service listening on port ' - f'{str(self.settings[SettingsParameters.Port])}') + "Web service listening on port " + f"{str(self.settings[SettingsParameters.Port])}" + ) tornado.ioloop.IOLoop.instance().start() def _create_tornado_web_app(self): @@ -99,48 +99,65 @@ class TabPyTornadoApp(tornado.web.Application): is_closing = False def signal_handler(self, signal): - logger.critical(f'Exiting on signal {signal}...') + logger.critical(f"Exiting on signal {signal}...") self.is_closing = True def try_exit(self): if self.is_closing: tornado.ioloop.IOLoop.instance().stop() - logger.info('Shutting down TabPy...') + logger.info("Shutting down TabPy...") - logger.info('Initializing TabPy...') + logger.info("Initializing TabPy...") tornado.ioloop.IOLoop.instance().run_sync( - lambda: init_ps_server(self.settings, self.tabpy_state)) - logger.info('Done initializing TabPy.') + lambda: init_ps_server(self.settings, self.tabpy_state) + ) + logger.info("Done initializing TabPy.") executor = concurrent.futures.ThreadPoolExecutor( - max_workers=multiprocessing.cpu_count()) + max_workers=multiprocessing.cpu_count() + ) # initialize Tornado application - application = TabPyTornadoApp([ - # skip MainHandler to use StaticFileHandler .* page requests and - # default to index.html - # (r"/", MainHandler), - (self.subdirectory + r'/query/([^/]+)', QueryPlaneHandler, - dict(app=self)), - (self.subdirectory + r'/status', StatusHandler, - dict(app=self)), - (self.subdirectory + r'/info', ServiceInfoHandler, - dict(app=self)), - (self.subdirectory + r'/endpoints', EndpointsHandler, - dict(app=self)), - (self.subdirectory + r'/endpoints/([^/]+)?', EndpointHandler, - dict(app=self)), - (self.subdirectory + r'/evaluate', EvaluationPlaneHandler, - dict(executor=executor, - app=self)), - (self.subdirectory + - r'/configurations/endpoint_upload_destination', - UploadDestinationHandler, - dict(app=self)), - (self.subdirectory + r'/(.*)', tornado.web.StaticFileHandler, - dict(path=self.settings[SettingsParameters.StaticPath], - default_filename="index.html")), - ], debug=False, **self.settings) + application = TabPyTornadoApp( + [ + # skip MainHandler to use StaticFileHandler .* page requests and + # default to index.html + # (r"/", MainHandler), + ( + self.subdirectory + r"/query/([^/]+)", + QueryPlaneHandler, + dict(app=self), + ), + (self.subdirectory + r"/status", StatusHandler, dict(app=self)), + (self.subdirectory + r"/info", ServiceInfoHandler, dict(app=self)), + (self.subdirectory + r"/endpoints", EndpointsHandler, dict(app=self)), + ( + self.subdirectory + r"/endpoints/([^/]+)?", + EndpointHandler, + dict(app=self), + ), + ( + self.subdirectory + r"/evaluate", + EvaluationPlaneHandler, + dict(executor=executor, app=self), + ), + ( + self.subdirectory + r"/configurations/endpoint_upload_destination", + UploadDestinationHandler, + dict(app=self), + ), + ( + self.subdirectory + r"/(.*)", + tornado.web.StaticFileHandler, + dict( + path=self.settings[SettingsParameters.StaticPath], + default_filename="index.html", + ), + ), + ], + debug=False, + **self.settings, + ) signal.signal(signal.SIGINT, application.signal_handler) tornado.ioloop.PeriodicCallback(application.try_exit, 500).start() @@ -148,12 +165,12 @@ def try_exit(self): return application def _parse_cli_arguments(self): - ''' + """ Parse command line arguments. Expected arguments: * --config: string - ''' - parser = ArgumentParser(description='Run TabPy Server.') - parser.add_argument('--config', help='Path to a config file.') + """ + parser = ArgumentParser(description="Run TabPy Server.") + parser.add_argument("--config", help="Path to a config file.") return parser.parse_args() def _parse_config(self, config_file): @@ -190,187 +207,213 @@ def _parse_config(self, config_file): parser.read_string(f.read()) else: logger.warning( - f'Unable to find config file at {config_file}, ' - 'using default settings.') + f"Unable to find config file at {config_file}, " + "using default settings." + ) - def set_parameter(settings_key, - config_key, - default_val=None): + def set_parameter(settings_key, config_key, default_val=None): key_is_set = False - if config_key is not None and\ - parser.has_section('TabPy') and\ - parser.has_option('TabPy', config_key): - self.settings[settings_key] = parser.get('TabPy', config_key) + if ( + config_key is not None + and parser.has_section("TabPy") + and parser.has_option("TabPy", config_key) + ): + self.settings[settings_key] = parser.get("TabPy", config_key) key_is_set = True logger.debug( - f'Parameter {settings_key} set to ' + f"Parameter {settings_key} set to " f'"{self.settings[settings_key]}" ' - 'from config file or environment variable') + "from config file or environment variable" + ) if not key_is_set and default_val is not None: self.settings[settings_key] = default_val key_is_set = True logger.debug( - f'Parameter {settings_key} set to ' + f"Parameter {settings_key} set to " f'"{self.settings[settings_key]}" ' - 'from default value') + "from default value" + ) if not key_is_set: - logger.debug( - f'Parameter {settings_key} is not set') + logger.debug(f"Parameter {settings_key} is not set") - set_parameter(SettingsParameters.Port, ConfigParameters.TABPY_PORT, - default_val=9004) - set_parameter(SettingsParameters.ServerVersion, None, - default_val=__version__) + set_parameter( + SettingsParameters.Port, ConfigParameters.TABPY_PORT, default_val=9004 + ) + set_parameter(SettingsParameters.ServerVersion, None, default_val=__version__) - set_parameter(SettingsParameters.EvaluateTimeout, - ConfigParameters.TABPY_EVALUATE_TIMEOUT, - default_val=30) + set_parameter( + SettingsParameters.EvaluateTimeout, + ConfigParameters.TABPY_EVALUATE_TIMEOUT, + default_val=30, + ) try: self.settings[SettingsParameters.EvaluateTimeout] = float( - self.settings[SettingsParameters.EvaluateTimeout]) + self.settings[SettingsParameters.EvaluateTimeout] + ) except ValueError: logger.warning( - 'Evaluate timeout must be a float type. Defaulting ' - 'to evaluate timeout of 30 seconds.') + "Evaluate timeout must be a float type. Defaulting " + "to evaluate timeout of 30 seconds." + ) self.settings[SettingsParameters.EvaluateTimeout] = 30 pkg_path = os.path.dirname(tabpy.__file__) - set_parameter(SettingsParameters.UploadDir, - ConfigParameters.TABPY_QUERY_OBJECT_PATH, - default_val=os.path.join(pkg_path, - 'tmp', 'query_objects')) + set_parameter( + SettingsParameters.UploadDir, + ConfigParameters.TABPY_QUERY_OBJECT_PATH, + default_val=os.path.join(pkg_path, "tmp", "query_objects"), + ) if not os.path.exists(self.settings[SettingsParameters.UploadDir]): os.makedirs(self.settings[SettingsParameters.UploadDir]) # set and validate transfer protocol - set_parameter(SettingsParameters.TransferProtocol, - ConfigParameters.TABPY_TRANSFER_PROTOCOL, - default_val='http') - self.settings[SettingsParameters.TransferProtocol] =\ - self.settings[SettingsParameters.TransferProtocol].lower() - - set_parameter(SettingsParameters.CertificateFile, - ConfigParameters.TABPY_CERTIFICATE_FILE) - set_parameter(SettingsParameters.KeyFile, - ConfigParameters.TABPY_KEY_FILE) + set_parameter( + SettingsParameters.TransferProtocol, + ConfigParameters.TABPY_TRANSFER_PROTOCOL, + default_val="http", + ) + self.settings[SettingsParameters.TransferProtocol] = self.settings[ + SettingsParameters.TransferProtocol + ].lower() + + set_parameter( + SettingsParameters.CertificateFile, ConfigParameters.TABPY_CERTIFICATE_FILE + ) + set_parameter(SettingsParameters.KeyFile, ConfigParameters.TABPY_KEY_FILE) self._validate_transfer_protocol_settings() # if state.ini does not exist try and create it - remove # last dependence on batch/shell script - set_parameter(SettingsParameters.StateFilePath, - ConfigParameters.TABPY_STATE_PATH, - default_val=os.path.join(pkg_path, 'tabpy_server')) + set_parameter( + SettingsParameters.StateFilePath, + ConfigParameters.TABPY_STATE_PATH, + default_val=os.path.join(pkg_path, "tabpy_server"), + ) self.settings[SettingsParameters.StateFilePath] = os.path.realpath( os.path.normpath( - os.path.expanduser( - self.settings[SettingsParameters.StateFilePath]))) + os.path.expanduser(self.settings[SettingsParameters.StateFilePath]) + ) + ) state_file_dir = self.settings[SettingsParameters.StateFilePath] - state_file_path = os.path.join(state_file_dir, 'state.ini') + state_file_path = os.path.join(state_file_dir, "state.ini") if not os.path.isfile(state_file_path): state_file_template_path = os.path.join( - pkg_path, 'tabpy_server', 'state.ini.template') - logger.debug(f'File {state_file_path} not found, creating from ' - f'template {state_file_template_path}...') + pkg_path, "tabpy_server", "state.ini.template" + ) + logger.debug( + f"File {state_file_path} not found, creating from " + f"template {state_file_template_path}..." + ) shutil.copy(state_file_template_path, state_file_path) - logger.info(f'Loading state from state file {state_file_path}') + logger.info(f"Loading state from state file {state_file_path}") tabpy_state = _get_state_from_file(state_file_dir) - self.tabpy_state = TabPyState( - config=tabpy_state, settings=self.settings) + self.tabpy_state = TabPyState(config=tabpy_state, settings=self.settings) self.python_service = PythonServiceHandler(PythonService()) - self.settings['compress_response'] = True - set_parameter(SettingsParameters.StaticPath, - ConfigParameters.TABPY_STATIC_PATH, - default_val='./') - self.settings[SettingsParameters.StaticPath] =\ - os.path.abspath(self.settings[SettingsParameters.StaticPath]) - logger.debug(f'Static pages folder set to ' - f'"{self.settings[SettingsParameters.StaticPath]}"') + self.settings["compress_response"] = True + set_parameter( + SettingsParameters.StaticPath, + ConfigParameters.TABPY_STATIC_PATH, + default_val="./", + ) + self.settings[SettingsParameters.StaticPath] = os.path.abspath( + self.settings[SettingsParameters.StaticPath] + ) + logger.debug( + f"Static pages folder set to " + f'"{self.settings[SettingsParameters.StaticPath]}"' + ) # Set subdirectory from config if applicable if tabpy_state.has_option("Service Info", "Subdirectory"): - self.subdirectory = "/" + \ - tabpy_state.get("Service Info", "Subdirectory") + self.subdirectory = "/" + tabpy_state.get("Service Info", "Subdirectory") # If passwords file specified load credentials - set_parameter(ConfigParameters.TABPY_PWD_FILE, - ConfigParameters.TABPY_PWD_FILE) + set_parameter(ConfigParameters.TABPY_PWD_FILE, ConfigParameters.TABPY_PWD_FILE) if ConfigParameters.TABPY_PWD_FILE in self.settings: if not self._parse_pwd_file(): - msg = ('Failed to read passwords file ' - f'{self.settings[ConfigParameters.TABPY_PWD_FILE]}') + msg = ( + "Failed to read passwords file " + f"{self.settings[ConfigParameters.TABPY_PWD_FILE]}" + ) logger.critical(msg) raise RuntimeError(msg) else: logger.info( - "Password file is not specified: " - "Authentication is not enabled") + "Password file is not specified: " "Authentication is not enabled" + ) features = self._get_features() - self.settings[SettingsParameters.ApiVersions] =\ - {'v1': {'features': features}} + self.settings[SettingsParameters.ApiVersions] = {"v1": {"features": features}} - set_parameter(SettingsParameters.LogRequestContext, - ConfigParameters.TABPY_LOG_DETAILS, - default_val='false') + set_parameter( + SettingsParameters.LogRequestContext, + ConfigParameters.TABPY_LOG_DETAILS, + default_val="false", + ) self.settings[SettingsParameters.LogRequestContext] = ( - self.settings[SettingsParameters.LogRequestContext].lower() != - 'false') - call_context_state =\ - 'enabled' if self.settings[SettingsParameters.LogRequestContext]\ - else 'disabled' - logger.info(f'Call context logging is {call_context_state}') - - set_parameter(SettingsParameters.MaxRequestSizeInMb, - ConfigParameters.TABPY_MAX_REQUEST_SIZE_MB, - default_val=100) + self.settings[SettingsParameters.LogRequestContext].lower() != "false" + ) + call_context_state = ( + "enabled" + if self.settings[SettingsParameters.LogRequestContext] + else "disabled" + ) + logger.info(f"Call context logging is {call_context_state}") + + set_parameter( + SettingsParameters.MaxRequestSizeInMb, + ConfigParameters.TABPY_MAX_REQUEST_SIZE_MB, + default_val=100, + ) def _validate_transfer_protocol_settings(self): if SettingsParameters.TransferProtocol not in self.settings: - msg = 'Missing transfer protocol information.' + msg = "Missing transfer protocol information." logger.critical(msg) raise RuntimeError(msg) protocol = self.settings[SettingsParameters.TransferProtocol] - if protocol == 'http': + if protocol == "http": return - if protocol != 'https': - msg = f'Unsupported transfer protocol: {protocol}' + if protocol != "https": + msg = f"Unsupported transfer protocol: {protocol}" logger.critical(msg) raise RuntimeError(msg) self._validate_cert_key_state( - 'The parameter(s) {} must be set.', + "The parameter(s) {} must be set.", SettingsParameters.CertificateFile in self.settings, - SettingsParameters.KeyFile in self.settings) + SettingsParameters.KeyFile in self.settings, + ) cert = self.settings[SettingsParameters.CertificateFile] self._validate_cert_key_state( - 'The parameter(s) {} must point to ' - 'an existing file.', + "The parameter(s) {} must point to " "an existing file.", os.path.isfile(cert), - os.path.isfile(self.settings[SettingsParameters.KeyFile])) + os.path.isfile(self.settings[SettingsParameters.KeyFile]), + ) tabpy.tabpy_server.app.util.validate_cert(cert) @staticmethod def _validate_cert_key_state(msg, cert_valid, key_valid): cert_and_key_param = ( - f'{ConfigParameters.TABPY_CERTIFICATE_FILE} and ' - f'{ConfigParameters.TABPY_KEY_FILE}') - https_error = 'Error using HTTPS: ' + f"{ConfigParameters.TABPY_CERTIFICATE_FILE} and " + f"{ConfigParameters.TABPY_KEY_FILE}" + ) + https_error = "Error using HTTPS: " err = None if not cert_valid and not key_valid: err = https_error + msg.format(cert_and_key_param) elif not cert_valid: - err = https_error + \ - msg.format(ConfigParameters.TABPY_CERTIFICATE_FILE) + err = https_error + msg.format(ConfigParameters.TABPY_CERTIFICATE_FILE) elif not key_valid: err = https_error + msg.format(ConfigParameters.TABPY_KEY_FILE) @@ -380,10 +423,11 @@ def _validate_cert_key_state(msg, cert_valid, key_valid): def _parse_pwd_file(self): succeeded, self.credentials = parse_pwd_file( - self.settings[ConfigParameters.TABPY_PWD_FILE]) + self.settings[ConfigParameters.TABPY_PWD_FILE] + ) if succeeded and len(self.credentials) == 0: - logger.error('No credentials found') + logger.error("No credentials found") succeeded = False return succeeded @@ -393,8 +437,9 @@ def _get_features(self): # Check for auth if ConfigParameters.TABPY_PWD_FILE in self.settings: - features['authentication'] = { - 'required': True, 'methods': { - 'basic-auth': {}}} + features["authentication"] = { + "required": True, + "methods": {"basic-auth": {}}, + } return features diff --git a/tabpy/tabpy_server/app/util.py b/tabpy/tabpy_server/app/util.py index f95ebfe8..944b5997 100644 --- a/tabpy/tabpy_server/app/util.py +++ b/tabpy/tabpy_server/app/util.py @@ -9,33 +9,29 @@ def validate_cert(cert_file_path): - with open(cert_file_path, 'r') as f: + with open(cert_file_path, "r") as f: cert_buf = f.read() cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_buf) - date_format, encoding = '%Y%m%d%H%M%SZ', 'ascii' - not_before = datetime.strptime( - cert.get_notBefore().decode(encoding), date_format) - not_after = datetime.strptime( - cert.get_notAfter().decode(encoding), date_format) + date_format, encoding = "%Y%m%d%H%M%SZ", "ascii" + not_before = datetime.strptime(cert.get_notBefore().decode(encoding), date_format) + not_after = datetime.strptime(cert.get_notAfter().decode(encoding), date_format) now = datetime.now() - https_error = 'Error using HTTPS: ' + https_error = "Error using HTTPS: " if now < not_before: - msg = (https_error + - f'The certificate provided is not valid until {not_before}.') + msg = https_error + f"The certificate provided is not valid until {not_before}." logger.critical(msg) raise RuntimeError(msg) if now > not_after: - msg = (https_error + - f'The certificate provided expired on {not_after}.') + msg = https_error + f"The certificate provided expired on {not_after}." logger.critical(msg) raise RuntimeError(msg) def parse_pwd_file(pwd_file_name): - ''' + """ Parses passwords file and returns set of credentials. Parameters @@ -51,42 +47,41 @@ def parse_pwd_file(pwd_file_name): credentials : dict Credentials from the file. Empty if succeeded is False. - ''' - logger.info(f'Parsing passwords file {pwd_file_name}...') + """ + logger.info(f"Parsing passwords file {pwd_file_name}...") if not os.path.isfile(pwd_file_name): - logger.critical(f'Passwords file {pwd_file_name} not found') + logger.critical(f"Passwords file {pwd_file_name} not found") return False, {} credentials = {} with open(pwd_file_name) as pwd_file: - pwd_file_reader = csv.reader(pwd_file, delimiter=' ') + pwd_file_reader = csv.reader(pwd_file, delimiter=" ") for row in pwd_file_reader: # skip empty lines if len(row) == 0: continue # skip commented lines - if row[0][0] == '#': + if row[0][0] == "#": continue if len(row) != 2: - logger.error( - f'Incorrect entry "{row}" in password file') + logger.error(f'Incorrect entry "{row}" in password file') return False, {} login = row[0].lower() if login in credentials: logger.error( - f'Multiple entries for username {login} ' - 'in password file') + f"Multiple entries for username {login} in password file" + ) return False, {} - if(len(row[1]) > 0): + if len(row[1]) > 0: credentials[login] = row[1] - logger.debug(f'Found username {login}') + logger.debug(f"Found username {login}") else: - logger.warning(f'Found username {row[0]} but no password') + logger.warning(f"Found username {row[0]} but no password") return False, {} logger.info("Authentication is enabled") diff --git a/tabpy/tabpy_server/common/endpoint_file_mgr.py b/tabpy/tabpy_server/common/endpoint_file_mgr.py index 65efb97c..6b7fed00 100644 --- a/tabpy/tabpy_server/common/endpoint_file_mgr.py +++ b/tabpy/tabpy_server/common/endpoint_file_mgr.py @@ -1,4 +1,4 @@ -''' +""" This module provides functionality required for managing endpoint objects in TabPy. It provides a way to download endpoint files from remote and then properly cleanup local the endpoint files on update/remove of endpoint @@ -7,40 +7,42 @@ The local temporary files for TabPy will by default located at /tmp/query_objects -''' +""" import logging import os import shutil from re import compile as _compile -_name_checker = _compile(r'^[a-zA-Z0-9-_\s]+$') +_name_checker = _compile(r"^[a-zA-Z0-9-_\s]+$") def _check_endpoint_name(name, logger=logging.getLogger(__name__)): """Checks that the endpoint name is valid by comparing it with an RE and checking that it is not reserved.""" if not isinstance(name, str): - msg = 'Endpoint name must be a string' + msg = "Endpoint name must be a string" logger.log(logging.CRITICAL, msg) raise TypeError(msg) - if name == '': - msg = 'Endpoint name cannot be empty' + if name == "": + msg = "Endpoint name cannot be empty" logger.log(logging.CRITICAL, msg) raise ValueError(msg) if not _name_checker.match(name): - msg = ('Endpoint name can only contain: a-z, A-Z, 0-9,' - ' underscore, hyphens and spaces.') + msg = ( + "Endpoint name can only contain: a-z, A-Z, 0-9," + " underscore, hyphens and spaces." + ) logger.log(logging.CRITICAL, msg) raise ValueError(msg) def grab_files(directory): - ''' + """ Generator that returns all files in a directory. - ''' + """ if not os.path.isdir(directory): return else: @@ -53,10 +55,10 @@ def grab_files(directory): yield full_path -def cleanup_endpoint_files(name, query_path, - logger=logging.getLogger(__name__), - retain_versions=None): - ''' +def cleanup_endpoint_files( + name, query_path, logger=logging.getLogger(__name__), retain_versions=None +): + """ Cleanup the disk space a certain endpiont uses. Parameters @@ -68,7 +70,7 @@ def cleanup_endpoint_files(name, query_path, If given, then all files for this endpoint are removed except the folder for the given version, otherwise, all files for that endpoint are removed. - ''' + """ _check_endpoint_name(name, logger=logger) local_dir = os.path.join(query_path, name) @@ -81,12 +83,12 @@ def cleanup_endpoint_files(name, query_path, if not retain_versions: shutil.rmtree(local_dir) else: - retain_folders = [os.path.join(local_dir, str(version)) - for version in retain_versions] - logger.log(logging.INFO, f'Retain folders: {retain_folders}') + retain_folders = [ + os.path.join(local_dir, str(version)) for version in retain_versions + ] + logger.log(logging.INFO, f"Retain folders: {retain_folders}") for file_or_dir in os.listdir(local_dir): candidate_dir = os.path.join(local_dir, file_or_dir) - if os.path.isdir(candidate_dir) and ( - candidate_dir not in retain_folders): + if os.path.isdir(candidate_dir) and (candidate_dir not in retain_folders): shutil.rmtree(candidate_dir) diff --git a/tabpy/tabpy_server/common/messages.py b/tabpy/tabpy_server/common/messages.py index 25c05d14..ad684319 100644 --- a/tabpy/tabpy_server/common/messages.py +++ b/tabpy/tabpy_server/common/messages.py @@ -16,13 +16,14 @@ class Msg: operator (*) that we inherit from namedtuple is also convenient. We empty __slots__ to avoid unnecessary overhead. """ + __metaclass__ = ABCMeta @abc.abstractmethod def for_json(self): d = self._asdict() type_str = self.__class__.__name__ - d.update({'type': type_str}) + d.update({"type": type_str}) return d @abc.abstractmethod @@ -32,130 +33,140 @@ def to_json(self): @staticmethod def from_json(str): d = json.loads(str) - type_str = d['type'] - del d['type'] + type_str = d["type"] + del d["type"] return eval(type_str)(**d) -class LoadSuccessful(namedtuple('LoadSuccessful', [ - 'uri', 'path', 'version', 'is_update', 'endpoint_type']), Msg): +class LoadSuccessful( + namedtuple( + "LoadSuccessful", ["uri", "path", "version", "is_update", "endpoint_type"] + ), + Msg, +): __slots__ = () -class LoadFailed(namedtuple('LoadFailed', [ - 'uri', 'version', 'error_msg']), Msg): +class LoadFailed(namedtuple("LoadFailed", ["uri", "version", "error_msg"]), Msg): __slots__ = () -class LoadInProgress(namedtuple('LoadInProgress', [ - 'uri', 'path', 'version', 'is_update', 'endpoint_type']), Msg): +class LoadInProgress( + namedtuple( + "LoadInProgress", ["uri", "path", "version", "is_update", "endpoint_type"] + ), + Msg, +): __slots__ = () -class Query(namedtuple('Query', ['uri', 'params']), Msg): +class Query(namedtuple("Query", ["uri", "params"]), Msg): __slots__ = () -class QuerySuccessful(namedtuple( - 'QuerySuccessful', ['uri', 'version', 'response']), Msg): +class QuerySuccessful( + namedtuple("QuerySuccessful", ["uri", "version", "response"]), Msg +): __slots__ = () -class LoadObject(namedtuple('LoadObject', [ - 'uri', 'url', 'version', 'is_update', 'endpoint_type']), Msg): +class LoadObject( + namedtuple("LoadObject", ["uri", "url", "version", "is_update", "endpoint_type"]), + Msg, +): __slots__ = () -class DeleteObjects(namedtuple('DeleteObjects', ['uris']), Msg): +class DeleteObjects(namedtuple("DeleteObjects", ["uris"]), Msg): __slots__ = () # Used for testing to flush out objects -class FlushObjects(namedtuple('FlushObjects', []), Msg): +class FlushObjects(namedtuple("FlushObjects", []), Msg): __slots__ = () -class ObjectsDeleted(namedtuple('ObjectsDeleted', ['uris']), Msg): +class ObjectsDeleted(namedtuple("ObjectsDeleted", ["uris"]), Msg): __slots__ = () -class ObjectsFlushed(namedtuple( - 'ObjectsFlushed', ['n_before', 'n_after']), Msg): +class ObjectsFlushed(namedtuple("ObjectsFlushed", ["n_before", "n_after"]), Msg): __slots__ = () -class CountObjects(namedtuple('CountObjects', []), Msg): +class CountObjects(namedtuple("CountObjects", []), Msg): __slots__ = () -class ObjectCount(namedtuple('ObjectCount', ['count']), Msg): +class ObjectCount(namedtuple("ObjectCount", ["count"]), Msg): __slots__ = () -class ListObjects(namedtuple('ListObjects', []), Msg): +class ListObjects(namedtuple("ListObjects", []), Msg): __slots__ = () -class ObjectList(namedtuple('ObjectList', ['objects']), Msg): +class ObjectList(namedtuple("ObjectList", ["objects"]), Msg): __slots__ = () -class UnknownURI(namedtuple('UnknownURI', ['uri']), Msg): +class UnknownURI(namedtuple("UnknownURI", ["uri"]), Msg): __slots__ = () -class UnknownMessage(namedtuple('UnknownMessage', ['msg']), Msg): +class UnknownMessage(namedtuple("UnknownMessage", ["msg"]), Msg): __slots__ = () -class DownloadSkipped(namedtuple('DownloadSkipped', [ - 'uri', 'version', 'msg', 'host']), Msg): +class DownloadSkipped( + namedtuple("DownloadSkipped", ["uri", "version", "msg", "host"]), Msg +): __slots__ = () -class QueryFailed(namedtuple('QueryFailed', ['uri', 'error']), Msg): +class QueryFailed(namedtuple("QueryFailed", ["uri", "error"]), Msg): __slots__ = () -class QueryError(namedtuple('QueryError', ['uri', 'error']), Msg): +class QueryError(namedtuple("QueryError", ["uri", "error"]), Msg): __slots__ = () -class CheckHealth(namedtuple('CheckHealth', []), Msg): +class CheckHealth(namedtuple("CheckHealth", []), Msg): __slots__ = () -class Healthy(namedtuple('Healthy', []), Msg): +class Healthy(namedtuple("Healthy", []), Msg): __slots__ = () -class Unhealthy(namedtuple('Unhealthy', []), Msg): +class Unhealthy(namedtuple("Unhealthy", []), Msg): __slots__ = () -class Ping(namedtuple('Ping', ['id']), Msg): +class Ping(namedtuple("Ping", ["id"]), Msg): __slots__ = () -class Pong(namedtuple('Pong', ['id']), Msg): +class Pong(namedtuple("Pong", ["id"]), Msg): __slots__ = () -class Listening(namedtuple('Listening', []), Msg): +class Listening(namedtuple("Listening", []), Msg): __slots__ = () -class EngineFailure(namedtuple('EngineFailure', ['error']), Msg): +class EngineFailure(namedtuple("EngineFailure", ["error"]), Msg): __slots__ = () -class FlushLogs(namedtuple('FlushLogs', []), Msg): +class FlushLogs(namedtuple("FlushLogs", []), Msg): __slots__ = () -class LogsFlushed(namedtuple('LogsFlushed', []), Msg): +class LogsFlushed(namedtuple("LogsFlushed", []), Msg): __slots__ = () -class ServiceError(namedtuple('ServiceError', ['error']), Msg): +class ServiceError(namedtuple("ServiceError", ["error"]), Msg): __slots__ = () diff --git a/tabpy/tabpy_server/common/util.py b/tabpy/tabpy_server/common/util.py index be53df94..c731450a 100644 --- a/tabpy/tabpy_server/common/util.py +++ b/tabpy/tabpy_server/common/util.py @@ -1,3 +1,3 @@ def format_exception(e, context): - err_msg = f'{e.__class__.__name__} : {str(e)}' + err_msg = f"{e.__class__.__name__} : {str(e)}" return err_msg diff --git a/tabpy/tabpy_server/handlers/__init__.py b/tabpy/tabpy_server/handlers/__init__.py index c73ea4d8..0c00cde6 100644 --- a/tabpy/tabpy_server/handlers/__init__.py +++ b/tabpy/tabpy_server/handlers/__init__.py @@ -4,10 +4,10 @@ from tabpy.tabpy_server.handlers.endpoint_handler import EndpointHandler from tabpy.tabpy_server.handlers.endpoints_handler import EndpointsHandler -from tabpy.tabpy_server.handlers.evaluation_plane_handler\ - import EvaluationPlaneHandler +from tabpy.tabpy_server.handlers.evaluation_plane_handler import EvaluationPlaneHandler from tabpy.tabpy_server.handlers.query_plane_handler import QueryPlaneHandler from tabpy.tabpy_server.handlers.service_info_handler import ServiceInfoHandler from tabpy.tabpy_server.handlers.status_handler import StatusHandler -from tabpy.tabpy_server.handlers.upload_destination_handler\ - import UploadDestinationHandler +from tabpy.tabpy_server.handlers.upload_destination_handler import ( + UploadDestinationHandler, +) diff --git a/tabpy/tabpy_server/handlers/base_handler.py b/tabpy/tabpy_server/handlers/base_handler.py index 2a13b36f..cc749227 100644 --- a/tabpy/tabpy_server/handlers/base_handler.py +++ b/tabpy/tabpy_server/handlers/base_handler.py @@ -13,9 +13,10 @@ class ContextLoggerWrapper: - ''' + """ This class appends request context to logged messages. - ''' + """ + @staticmethod def _generate_call_id(): return str(uuid.uuid4()) @@ -29,21 +30,21 @@ def __init__(self, request: tornado.httputil.HTTPServerRequest): self.request_context_logged = False def set_request(self, request: tornado.httputil.HTTPServerRequest): - ''' + """ Set HTTP(S) request for logger. Headers will be used to append request data as client information, Tableau user name, etc. - ''' + """ self.remote_ip = request.remote_ip self.method = request.method self.url = request.full_url() - if 'TabPy-Client' in request.headers: - self.client = request.headers['TabPy-Client'] + if "TabPy-Client" in request.headers: + self.client = request.headers["TabPy-Client"] else: self.client = None - if 'TabPy-User' in request.headers: - self.tableau_username = request.headers['TabPy-User'] + if "TabPy-User" in request.headers: + self.tableau_username = request.headers["TabPy-User"] else: self.tableau_username = None @@ -51,7 +52,7 @@ def set_tabpy_username(self, tabpy_username: str): self.tabpy_username = tabpy_username def enable_context_logging(self, enable: bool): - ''' + """ Enable/disable request context information logging. Parameters @@ -60,38 +61,38 @@ def enable_context_logging(self, enable: bool): If True request context information will be logged and every log entry for a request handler will have call ID with it. - ''' + """ self.log_request_context = enable def _log_context_info(self): if not self.log_request_context: return - context = f'Call ID: {self.call_id}' + context = f"Call ID: {self.call_id}" if self.remote_ip is not None: - context += f', Caller: {self.remote_ip}' + context += f", Caller: {self.remote_ip}" if self.method is not None: - context += f', Method: {self.method}' + context += f", Method: {self.method}" if self.url is not None: - context += f', URL: {self.url}' + context += f", URL: {self.url}" if self.client is not None: - context += f', Client: {self.client}' + context += f", Client: {self.client}" if self.tableau_username is not None: - context += f', Tableau user: {self.tableau_username}' + context += f", Tableau user: {self.tableau_username}" if self.tabpy_username is not None: - context += f', TabPy user: {self.tabpy_username}' + context += f", TabPy user: {self.tabpy_username}" logging.getLogger(__name__).log(logging.INFO, context) self.request_context_logged = True def log(self, level: int, msg: str): - ''' + """ Log message with or without call ID. If call context is logged and call ID added to any log entry is specified by if context logging is enabled (see CallContext.enable_context_logging for more details). @@ -109,13 +110,13 @@ def log(self, level: int, msg: str): kwargs Same as kwargs in Logger.debug(). - ''' + """ extended_msg = msg if self.log_request_context: if not self.request_context_logged: self._log_context_info() - extended_msg += f', <>' + extended_msg += f", <>" logging.getLogger(__name__).log(level, extended_msg) @@ -135,16 +136,14 @@ def initialize(self, app): self.logger = ContextLoggerWrapper(self.request) self.logger.enable_context_logging( - app.settings[SettingsParameters.LogRequestContext]) - self.logger.log( - logging.DEBUG, - 'Checking if need to handle authentication') + app.settings[SettingsParameters.LogRequestContext] + ) + self.logger.log(logging.DEBUG, "Checking if need to handle authentication") self.not_authorized = not self.handle_authentication("v1") def error_out(self, code, log_message, info=None): self.set_status(code) - self.write(json.dumps( - {'message': log_message, 'info': info or {}})) + self.write(json.dumps({"message": log_message, "info": info or {}})) # We want to duplicate error message in console for # loggers are misconfigured or causing the failure @@ -152,8 +151,10 @@ def error_out(self, code, log_message, info=None): print(info) self.logger.log( logging.ERROR, - 'Responding with status={}, message="{}", info="{}"'. - format(code, log_message, info)) + 'Responding with status={}, message="{}", info="{}"'.format( + code, log_message, info + ), + ) self.finish() def options(self): @@ -169,23 +170,20 @@ def _add_CORS_header(self): origin = self.tabpy_state.get_access_control_allow_origin() if len(origin) > 0: self.set_header("Access-Control-Allow-Origin", origin) - self.logger.log(logging.DEBUG, - f'Access-Control-Allow-Origin:{origin}') + self.logger.log(logging.DEBUG, f"Access-Control-Allow-Origin:{origin}") headers = self.tabpy_state.get_access_control_allow_headers() if len(headers) > 0: self.set_header("Access-Control-Allow-Headers", headers) - self.logger.log(logging.DEBUG, - f'Access-Control-Allow-Headers:{headers}') + self.logger.log(logging.DEBUG, f"Access-Control-Allow-Headers:{headers}") methods = self.tabpy_state.get_access_control_allow_methods() if len(methods) > 0: self.set_header("Access-Control-Allow-Methods", methods) - self.logger.log(logging.DEBUG, - f'Access-Control-Allow-Methods:{methods}') + self.logger.log(logging.DEBUG, f"Access-Control-Allow-Methods:{methods}") def _get_auth_method(self, api_version) -> (bool, str): - ''' + """ Finds authentication method if provided. Parameters @@ -205,49 +203,51 @@ def _get_auth_method(self, api_version) -> (bool, str): (True, '') as result of this function means authentication is not needed. - ''' + """ if api_version not in self.settings[SettingsParameters.ApiVersions]: - self.logger.log(logging.CRITICAL, - f'Unknown API version "{api_version}"') - return False, '' - - version_settings =\ - self.settings[SettingsParameters.ApiVersions][api_version] - if 'features' not in version_settings: - self.logger.log(logging.INFO, - f'No features configured for API "{api_version}"') - return True, '' - - features = version_settings['features'] - if 'authentication' not in features or\ - not features['authentication']['required']: + self.logger.log(logging.CRITICAL, f'Unknown API version "{api_version}"') + return False, "" + + version_settings = self.settings[SettingsParameters.ApiVersions][api_version] + if "features" not in version_settings: + self.logger.log( + logging.INFO, f'No features configured for API "{api_version}"' + ) + return True, "" + + features = version_settings["features"] + if ( + "authentication" not in features + or not features["authentication"]["required"] + ): self.logger.log( logging.INFO, - 'Authentication is not a required feature for API ' - f'"{api_version}"') - return True, '' + "Authentication is not a required feature for API " f'"{api_version}"', + ) + return True, "" - auth_feature = features['authentication'] - if 'methods' not in auth_feature: + auth_feature = features["authentication"] + if "methods" not in auth_feature: self.logger.log( logging.INFO, - 'Authentication method is not configured for API ' - f'"{api_version}"') + "Authentication method is not configured for API " f'"{api_version}"', + ) - methods = auth_feature['methods'] - if 'basic-auth' in auth_feature['methods']: - return True, 'basic-auth' + methods = auth_feature["methods"] + if "basic-auth" in auth_feature["methods"]: + return True, "basic-auth" # Add new methods here... # No known methods were found self.logger.log( logging.CRITICAL, f'Unknown authentication method(s) "{methods}" are configured ' - f'for API "{api_version}"') - return False, '' + f'for API "{api_version}"', + ) + return False, "" def _get_basic_auth_credentials(self) -> bool: - ''' + """ Find credentials for basic access authentication method. Credentials if found stored in Credentials.username and Credentials.password. @@ -256,32 +256,31 @@ def _get_basic_auth_credentials(self) -> bool: bool True if valid credentials were found. False otherwise. - ''' - self.logger.log(logging.DEBUG, - 'Checking request headers for authentication data') - if 'Authorization' not in self.request.headers: - self.logger.log(logging.INFO, 'Authorization header not found') + """ + self.logger.log( + logging.DEBUG, "Checking request headers for authentication data" + ) + if "Authorization" not in self.request.headers: + self.logger.log(logging.INFO, "Authorization header not found") return False - auth_header = self.request.headers['Authorization'] - auth_header_list = auth_header.split(' ') - if len(auth_header_list) != 2 or\ - auth_header_list[0] != 'Basic': - self.logger.log(logging.ERROR, - f'Unknown authentication method "{auth_header}"') + auth_header = self.request.headers["Authorization"] + auth_header_list = auth_header.split(" ") + if len(auth_header_list) != 2 or auth_header_list[0] != "Basic": + self.logger.log( + logging.ERROR, f'Unknown authentication method "{auth_header}"' + ) return False try: - cred = base64.b64decode(auth_header_list[1]).decode('utf-8') + cred = base64.b64decode(auth_header_list[1]).decode("utf-8") except (binascii.Error, UnicodeDecodeError) as ex: - self.logger.log(logging.CRITICAL, - f'Cannot decode credentials: {str(ex)}') + self.logger.log(logging.CRITICAL, f"Cannot decode credentials: {str(ex)}") return False - login_pwd = cred.split(':') + login_pwd = cred.split(":") if len(login_pwd) != 2: - self.logger.log(logging.ERROR, - 'Invalid string in encoded credentials') + self.logger.log(logging.ERROR, "Invalid string in encoded credentials") return False self.username = login_pwd[0] @@ -290,7 +289,7 @@ def _get_basic_auth_credentials(self) -> bool: return True def _get_credentials(self, method) -> bool: - ''' + """ Find credentials for specified authentication method. Credentials if found stored in self.username and self.password. @@ -304,8 +303,8 @@ def _get_credentials(self, method) -> bool: bool True if valid credentials were found. False otherwise. - ''' - if method == 'basic-auth': + """ + if method == "basic-auth": return self._get_basic_auth_credentials() # Add new methods here... @@ -313,11 +312,12 @@ def _get_credentials(self, method) -> bool: self.logger.log( logging.CRITICAL, f'Unknown authentication method(s) "{method}" are configured ' - f'for API "{api_version}"') + f'for API "{api_version}"', + ) return False def _validate_basic_auth_credentials(self) -> bool: - ''' + """ Validates username:pwd if they are the same as stored credentials. @@ -327,25 +327,26 @@ def _validate_basic_auth_credentials(self) -> bool: True if credentials has key login and credentials[login] equal SHA3(pwd), False otherwise. - ''' + """ login = self.username.lower() - self.logger.log(logging.DEBUG, - f'Validating credentials for user name "{login}"') + self.logger.log( + logging.DEBUG, f'Validating credentials for user name "{login}"' + ) if login not in self.credentials: - self.logger.log(logging.ERROR, - f'User name "{self.username}" not found') + self.logger.log(logging.ERROR, f'User name "{self.username}" not found') return False hashed_pwd = hash_password(login, self.password) if self.credentials[login].lower() != hashed_pwd.lower(): - self.logger.log(logging.ERROR, - f'Wrong password for user name "{self.username}"') + self.logger.log( + logging.ERROR, f'Wrong password for user name "{self.username}"' + ) return False return True def _validate_credentials(self, method) -> bool: - ''' + """ Validates credentials according to specified methods if they are what expected. @@ -359,8 +360,8 @@ def _validate_credentials(self, method) -> bool: bool True if credentials are valid. False otherwise. - ''' - if method == 'basic-auth': + """ + if method == "basic-auth": return self._validate_basic_auth_credentials() # Add new methods here... @@ -368,11 +369,12 @@ def _validate_credentials(self, method) -> bool: self.logger.log( logging.CRITICAL, f'Unknown authentication method(s) "{method}" are configured ' - f'for API "{api_version}"') + f'for API "{api_version}"', + ) return False def handle_authentication(self, api_version) -> bool: - ''' + """ If authentication feature is configured checks provided credentials. @@ -388,13 +390,13 @@ def handle_authentication(self, api_version) -> bool: True if authentication is required and valid credentials provided. False otherwise. - ''' - self.logger.log(logging.DEBUG, 'Handling authentication') + """ + self.logger.log(logging.DEBUG, "Handling authentication") found, method = self._get_auth_method(api_version) if not found: return False - if method == '': + if method == "": # Do not validate credentials return True @@ -404,7 +406,7 @@ def handle_authentication(self, api_version) -> bool: return self._validate_credentials(method) def should_fail_with_not_authorized(self): - ''' + """ Checks if authentication is required: - if it is not returns false, None - if it is required validates provided credentials @@ -415,20 +417,18 @@ def should_fail_with_not_authorized(self): False if authentication is not required or is required and validation for credentials passes. True if validation for credentials failed. - ''' + """ return self.not_authorized def fail_with_not_authorized(self): - ''' + """ Prepares server 401 response. - ''' - self.logger.log( - logging.ERROR, - 'Failing with 401 for unauthorized request') + """ + self.logger.log(logging.ERROR, "Failing with 401 for unauthorized request") self.set_status(401) - self.set_header('WWW-Authenticate', - f'Basic realm="{self.tabpy_state.name}"') + self.set_header("WWW-Authenticate", f'Basic realm="{self.tabpy_state.name}"') self.error_out( 401, info="Unauthorized request.", - log_message="Invalid credentials provided.") + log_message="Invalid credentials provided.", + ) diff --git a/tabpy/tabpy_server/handlers/endpoint_handler.py b/tabpy/tabpy_server/handlers/endpoint_handler.py index 20b6b334..d15689e0 100644 --- a/tabpy/tabpy_server/handlers/endpoint_handler.py +++ b/tabpy/tabpy_server/handlers/endpoint_handler.py @@ -1,10 +1,10 @@ -''' +""" HTTP handeler to serve specific endpoint request like http://myserver:9004/endpoints/mymodel For how generic endpoints requests is served look at endpoints_handler.py -''' +""" import concurrent import json @@ -28,19 +28,20 @@ def get(self, endpoint_name): self.fail_with_not_authorized() return - self.logger.log(logging.DEBUG, - f'Processing GET for /endpoints/{endpoint_name}') + self.logger.log(logging.DEBUG, f"Processing GET for /endpoints/{endpoint_name}") self._add_CORS_header() if not endpoint_name: self.write(json.dumps(self.tabpy_state.get_endpoints())) else: if endpoint_name in self.tabpy_state.get_endpoints(): - self.write(json.dumps( - self.tabpy_state.get_endpoints()[endpoint_name])) + self.write(json.dumps(self.tabpy_state.get_endpoints()[endpoint_name])) else: - self.error_out(404, 'Unknown endpoint', - info=f'Endpoint {endpoint_name} is not found') + self.error_out( + 404, + "Unknown endpoint", + info=f"Endpoint {endpoint_name} is not found", + ) @gen.coroutine def put(self, name): @@ -48,8 +49,7 @@ def put(self, name): self.fail_with_not_authorized() return - self.logger.log(logging.DEBUG, - f'Processing PUT for /endpoints/{name}') + self.logger.log(logging.DEBUG, f"Processing PUT for /endpoints/{name}") try: if not self.request.body: @@ -57,30 +57,26 @@ def put(self, name): self.finish() return try: - request_data = json.loads( - self.request.body.decode('utf-8')) + request_data = json.loads(self.request.body.decode("utf-8")) except BaseException as ex: self.error_out( - 400, - log_message="Failed to decode input body", - info=str(ex)) + 400, log_message="Failed to decode input body", info=str(ex) + ) self.finish() return # check if endpoint exists endpoints = self.tabpy_state.get_endpoints(name) if len(endpoints) == 0: - self.error_out(404, - f'endpoint {name} does not exist.') + self.error_out(404, f"endpoint {name} does not exist.") self.finish() return - new_version = int(endpoints[name]['version']) + 1 - self.logger.log( - logging.INFO, - f'Endpoint info: {request_data}') + new_version = int(endpoints[name]["version"]) + 1 + self.logger.log(logging.INFO, f"Endpoint info: {request_data}") err_msg = yield self._add_or_update_endpoint( - 'update', name, new_version, request_data) + "update", name, new_version, request_data + ) if err_msg: self.error_out(400, err_msg) self.finish() @@ -89,7 +85,7 @@ def put(self, name): self.finish() except Exception as e: - err_msg = format_exception(e, 'update_endpoint') + err_msg = format_exception(e, "update_endpoint") self.error_out(500, err_msg) self.finish() @@ -99,15 +95,12 @@ def delete(self, name): self.fail_with_not_authorized() return - self.logger.log( - logging.DEBUG, - f'Processing DELETE for /endpoints/{name}') + self.logger.log(logging.DEBUG, f"Processing DELETE for /endpoints/{name}") try: endpoints = self.tabpy_state.get_endpoints(name) if len(endpoints) == 0: - self.error_out(404, - f'endpoint {name} does not exist.') + self.error_out(404, f"endpoint {name} does not exist.") self.finish() return @@ -115,20 +108,19 @@ def delete(self, name): try: endpoint_info = self.tabpy_state.delete_endpoint(name) except Exception as e: - self.error_out(400, - f'Error when removing endpoint: {e.message}') + self.error_out(400, f"Error when removing endpoint: {e.message}") self.finish() return # delete files - if endpoint_info['type'] != 'alias': + if endpoint_info["type"] != "alias": delete_path = get_query_object_path( - self.settings['state_file_path'], name, None) + self.settings["state_file_path"], name, None + ) try: yield self._delete_po_future(delete_path) except Exception as e: - self.error_out(400, - f'Error while deleting: {e}') + self.error_out(400, f"Error while deleting: {e}") self.finish() return @@ -136,12 +128,13 @@ def delete(self, name): self.finish() except Exception as e: - err_msg = format_exception(e, 'delete endpoint') + err_msg = format_exception(e, "delete endpoint") self.error_out(500, err_msg) self.finish() - on_state_change(self.settings, self.tabpy_state, self.python_service, - self.logger) + on_state_change( + self.settings, self.tabpy_state, self.python_service, self.logger + ) @gen.coroutine def _delete_po_future(self, delete_path): diff --git a/tabpy/tabpy_server/handlers/endpoints_handler.py b/tabpy/tabpy_server/handlers/endpoints_handler.py index bd54311b..bd269d3a 100644 --- a/tabpy/tabpy_server/handlers/endpoints_handler.py +++ b/tabpy/tabpy_server/handlers/endpoints_handler.py @@ -1,10 +1,10 @@ -''' +""" HTTP handeler to serve general endpoints request, specifically http://myserver:9004/endpoints For how individual endpoint requests are served look at endpoint_handler.py -''' +""" import json import logging @@ -39,48 +39,38 @@ def post(self): return try: - request_data = json.loads( - self.request.body.decode('utf-8')) + request_data = json.loads(self.request.body.decode("utf-8")) except Exception as ex: - self.error_out( - 400, - "Failed to decode input body", - str(ex)) + self.error_out(400, "Failed to decode input body", str(ex)) self.finish() return - if 'name' not in request_data: - self.error_out(400, - "name is required to add an endpoint.") + if "name" not in request_data: + self.error_out(400, "name is required to add an endpoint.") self.finish() return - name = request_data['name'] + name = request_data["name"] # check if endpoint already exist if name in self.tabpy_state.get_endpoints(): - self.error_out(400, f'endpoint {name} already exists.') + self.error_out(400, f"endpoint {name} already exists.") self.finish() return - self.logger.log( - logging.DEBUG, - f'Adding endpoint "{name}"') - err_msg = yield self._add_or_update_endpoint('add', name, 1, - request_data) + self.logger.log(logging.DEBUG, f'Adding endpoint "{name}"') + err_msg = yield self._add_or_update_endpoint("add", name, 1, request_data) if err_msg: self.error_out(400, err_msg) else: - self.logger.log( - logging.DEBUG, - f'Endpoint {name} successfully added') + self.logger.log(logging.DEBUG, f"Endpoint {name} successfully added") self.set_status(201) self.write(self.tabpy_state.get_endpoints(name)) self.finish() return except Exception as e: - err_msg = format_exception(e, '/add_endpoint') + err_msg = format_exception(e, "/add_endpoint") self.error_out(500, "error adding endpoint", err_msg) self.finish() return diff --git a/tabpy/tabpy_server/handlers/evaluation_plane_handler.py b/tabpy/tabpy_server/handlers/evaluation_plane_handler.py index f3d799f6..ae4a55f3 100644 --- a/tabpy/tabpy_server/handlers/evaluation_plane_handler.py +++ b/tabpy/tabpy_server/handlers/evaluation_plane_handler.py @@ -15,27 +15,29 @@ def __init__(self, protocol, port, logger, timeout): self.timeout = timeout def query(self, name, *args, **kwargs): - url = f'{self.protocol}://localhost:{self.port}/query/{name}' - self.logger.log(logging.DEBUG, f'Querying {url}...') - internal_data = {'data': args or kwargs} + url = f"{self.protocol}://localhost:{self.port}/query/{name}" + self.logger.log(logging.DEBUG, f"Querying {url}...") + internal_data = {"data": args or kwargs} data = json.dumps(internal_data) - headers = {'content-type': 'application/json'} - response = requests.post(url=url, data=data, headers=headers, - timeout=self.timeout, - verify=False) + headers = {"content-type": "application/json"} + response = requests.post( + url=url, data=data, headers=headers, timeout=self.timeout, verify=False + ) return response.json() class EvaluationPlaneHandler(BaseHandler): - ''' + """ EvaluationPlaneHandler is responsible for running arbitrary python scripts. - ''' + """ def initialize(self, executor, app): super(EvaluationPlaneHandler, self).initialize(app) self.executor = executor - self._error_message_timeout = f'User defined script timed out. ' \ - f'Timeout is set to {self.eval_timeout} s.' + self._error_message_timeout = ( + f"User defined script timed out. " + f"Timeout is set to {self.eval_timeout} s." + ) @gen.coroutine def post(self): @@ -45,88 +47,90 @@ def post(self): self._add_CORS_header() try: - body = json.loads(self.request.body.decode('utf-8')) - if 'script' not in body: - self.error_out(400, 'Script is empty.') + body = json.loads(self.request.body.decode("utf-8")) + if "script" not in body: + self.error_out(400, "Script is empty.") return # Transforming user script into a proper function. - user_code = body['script'] + user_code = body["script"] arguments = None - arguments_str = '' - if 'data' in body: - arguments = body['data'] + arguments_str = "" + if "data" in body: + arguments = body["data"] if arguments is not None: if not isinstance(arguments, dict): - self.error_out(400, 'Script parameters need to be ' - 'provided as a dictionary.') + self.error_out( + 400, "Script parameters need to be provided as a dictionary." + ) return else: arguments_expected = [] for i in range(1, len(arguments.keys()) + 1): - arguments_expected.append('_arg' + str(i)) + arguments_expected.append("_arg" + str(i)) if sorted(arguments_expected) == sorted(arguments.keys()): - arguments_str = ', ' + ', '.join(arguments.keys()) + arguments_str = ", " + ", ".join(arguments.keys()) else: - self.error_out(400, 'Variables names should follow ' - 'the format _arg1, _arg2, _argN') + self.error_out( + 400, + "Variables names should follow " + "the format _arg1, _arg2, _argN", + ) return - function_to_evaluate = f'def _user_script(tabpy{arguments_str}):\n' + function_to_evaluate = f"def _user_script(tabpy{arguments_str}):\n" for u in user_code.splitlines(): - function_to_evaluate += ' ' + u + '\n' + function_to_evaluate += " " + u + "\n" self.logger.log( - logging.INFO, - f'function to evaluate={function_to_evaluate}') + logging.INFO, f"function to evaluate={function_to_evaluate}" + ) try: - result = yield self._call_subprocess(function_to_evaluate, - arguments) - except (gen.TimeoutError, - requests.exceptions.ConnectTimeout, - requests.exceptions.ReadTimeout): + result = yield self._call_subprocess(function_to_evaluate, arguments) + except ( + gen.TimeoutError, + requests.exceptions.ConnectTimeout, + requests.exceptions.ReadTimeout, + ): self.logger.log(logging.ERROR, self._error_message_timeout) self.error_out(408, self._error_message_timeout) return if result is None: - self.error_out(400, 'Error running script. No return value') + self.error_out(400, "Error running script. No return value") else: self.write(json.dumps(result)) self.finish() except Exception as e: - err_msg = f'{e.__class__.__name__} : {str(e)}' + err_msg = f"{e.__class__.__name__} : {str(e)}" if err_msg != "KeyError : 'response'": - err_msg = format_exception(e, 'POST /evaluate') - self.error_out(500, 'Error processing script', info=err_msg) + err_msg = format_exception(e, "POST /evaluate") + self.error_out(500, "Error processing script", info=err_msg) else: self.error_out( 404, - 'Error processing script', + "Error processing script", info="The endpoint you're " "trying to query did not respond. Please make sure the " "endpoint exists and the correct set of arguments are " - "provided.") + "provided.", + ) @gen.coroutine def _call_subprocess(self, function_to_evaluate, arguments): restricted_tabpy = RestrictedTabPy( - self.protocol, - self.port, - self.logger, - self.eval_timeout) + self.protocol, self.port, self.logger, self.eval_timeout + ) # Exec does not run the function, so it does not block. exec(function_to_evaluate, globals()) if arguments is None: future = self.executor.submit(_user_script, restricted_tabpy) else: - future = self.executor.submit(_user_script, restricted_tabpy, - **arguments) + future = self.executor.submit(_user_script, restricted_tabpy, **arguments) - ret = yield gen.with_timeout(timedelta(seconds=self.eval_timeout), - future) + ret = yield gen.with_timeout(timedelta(seconds=self.eval_timeout), future) raise gen.Return(ret) diff --git a/tabpy/tabpy_server/handlers/main_handler.py b/tabpy/tabpy_server/handlers/main_handler.py index 9961da2e..dbf2680b 100644 --- a/tabpy/tabpy_server/handlers/main_handler.py +++ b/tabpy/tabpy_server/handlers/main_handler.py @@ -4,4 +4,4 @@ class MainHandler(BaseHandler): def get(self): self._add_CORS_header() - self.render('/static/index.html') + self.render("/static/index.html") diff --git a/tabpy/tabpy_server/handlers/management_handler.py b/tabpy/tabpy_server/handlers/management_handler.py index e7a0b3c0..5b056657 100644 --- a/tabpy/tabpy_server/handlers/management_handler.py +++ b/tabpy/tabpy_server/handlers/management_handler.py @@ -42,82 +42,84 @@ def initialize(self, app): self.port = self.settings[SettingsParameters.Port] def _get_protocol(self): - return 'http://' + return "http://" @gen.coroutine def _add_or_update_endpoint(self, action, name, version, request_data): - ''' + """ Add or update an endpoint - ''' - self.logger.log(logging.DEBUG, f'Adding/updating model {name}...') + """ + self.logger.log(logging.DEBUG, f"Adding/updating model {name}...") - _name_checker = _compile(r'^[a-zA-Z0-9-_\s]+$') + _name_checker = _compile(r"^[a-zA-Z0-9-_\s]+$") if not isinstance(name, str): - msg = 'Endpoint name must be a string' + msg = "Endpoint name must be a string" self.logger.log(logging.CRITICAL, msg) raise TypeError(msg) if not _name_checker.match(name): - raise gen.Return('endpoint name can only contain: a-z, A-Z, 0-9,' - ' underscore, hyphens and spaces.') - - if self.settings.get('add_or_updating_endpoint'): - msg = ('Another endpoint update is already in progress' - ', please wait a while and try again') + raise gen.Return( + "endpoint name can only contain: a-z, A-Z, 0-9," + " underscore, hyphens and spaces." + ) + + if self.settings.get("add_or_updating_endpoint"): + msg = ( + "Another endpoint update is already in progress" + ", please wait a while and try again" + ) self.logger.log(logging.CRITICAL, msg) raise RuntimeError(msg) request_uuid = random_uuid() - self.settings['add_or_updating_endpoint'] = request_uuid + self.settings["add_or_updating_endpoint"] = request_uuid try: - description = (request_data['description'] - if 'description' in request_data else None) - if 'docstring' in request_data: - docstring = str(bytes(request_data['docstring'], - "utf-8").decode('unicode_escape')) + description = ( + request_data["description"] if "description" in request_data else None + ) + if "docstring" in request_data: + docstring = str( + bytes(request_data["docstring"], "utf-8").decode("unicode_escape") + ) else: docstring = None - endpoint_type = (request_data['type'] if 'type' in request_data - else None) - methods = (request_data['methods'] if 'methods' in request_data - else []) - dependencies = (request_data['dependencies'] - if 'dependencies' in request_data else None) - target = (request_data['target'] - if 'target' in request_data else None) - schema = (request_data['schema'] if 'schema' in request_data - else None) - - src_path = (request_data['src_path'] if 'src_path' in request_data - else None) + endpoint_type = request_data["type"] if "type" in request_data else None + methods = request_data["methods"] if "methods" in request_data else [] + dependencies = ( + request_data["dependencies"] if "dependencies" in request_data else None + ) + target = request_data["target"] if "target" in request_data else None + schema = request_data["schema"] if "schema" in request_data else None + + src_path = request_data["src_path"] if "src_path" in request_data else None target_path = get_query_object_path( - self.settings[SettingsParameters.StateFilePath], name, version) - self.logger.log(logging.DEBUG, - f'Checking source path {src_path}...') - _path_checker = _compile(r'^[\\\:a-zA-Z0-9-_~\s/\.\(\)]+$') + self.settings[SettingsParameters.StateFilePath], name, version + ) + self.logger.log(logging.DEBUG, f"Checking source path {src_path}...") + _path_checker = _compile(r"^[\\\:a-zA-Z0-9-_~\s/\.\(\)]+$") # copy from staging if src_path: - if not isinstance(request_data['src_path'], str): + if not isinstance(request_data["src_path"], str): raise gen.Return("src_path must be a string.") if not _path_checker.match(src_path): raise gen.Return( - 'Endpoint source path name can only contain: ' - 'a-z, A-Z, 0-9, underscore, hyphens and spaces.') + "Endpoint source path name can only contain: " + "a-z, A-Z, 0-9, underscore, hyphens and spaces." + ) yield self._copy_po_future(src_path, target_path) - elif endpoint_type != 'alias': - raise gen.Return("src_path is required to add/update an " - "endpoint.") + elif endpoint_type != "alias": + raise gen.Return("src_path is required to add/update an " "endpoint.") # alias special logic: - if endpoint_type == 'alias': + if endpoint_type == "alias": if not target: - raise gen.Return('Target is required for alias endpoint.') + raise gen.Return("Target is required for alias endpoint.") dependencies = [target] # update local config try: - if action == 'add': + if action == "add": self.tabpy_state.add_endpoint( name=name, description=description, @@ -126,7 +128,8 @@ def _add_or_update_endpoint(self, action, name, version, request_data): methods=methods, dependencies=dependencies, target=target, - schema=schema) + schema=schema, + ) else: self.tabpy_state.update_endpoint( name=name, @@ -137,22 +140,23 @@ def _add_or_update_endpoint(self, action, name, version, request_data): dependencies=dependencies, target=target, schema=schema, - version=version) + version=version, + ) except Exception as e: - raise gen.Return(f'Error when changing TabPy state: {e}') + raise gen.Return(f"Error when changing TabPy state: {e}") - on_state_change(self.settings, - self.tabpy_state, - self.python_service, - self.logger) + on_state_change( + self.settings, self.tabpy_state, self.python_service, self.logger + ) finally: - self.settings['add_or_updating_endpoint'] = None + self.settings["add_or_updating_endpoint"] = None @gen.coroutine def _copy_po_future(self, src_path, target_path): - future = STAGING_THREAD.submit(copy_from_local, src_path, - target_path, is_dir=True) + future = STAGING_THREAD.submit( + copy_from_local, src_path, target_path, is_dir=True + ) ret = yield future raise gen.Return(ret) diff --git a/tabpy/tabpy_server/handlers/query_plane_handler.py b/tabpy/tabpy_server/handlers/query_plane_handler.py index 61657a39..603968b9 100644 --- a/tabpy/tabpy_server/handlers/query_plane_handler.py +++ b/tabpy/tabpy_server/handlers/query_plane_handler.py @@ -2,7 +2,11 @@ import logging import time from tabpy.tabpy_server.common.messages import ( - Query, QuerySuccessful, QueryError, UnknownURI) + Query, + QuerySuccessful, + QueryError, + UnknownURI, +) from hashlib import md5 import uuid import json @@ -45,22 +49,19 @@ def _query(self, po_name, data, uid, qry): as a dictionary, and the time in seconds that it took to complete the request. """ - self.logger.log(logging.DEBUG, - f'Collecting query info for {po_name}...') + self.logger.log(logging.DEBUG, f"Collecting query info for {po_name}...") start_time = time.time() response = self.python_service.ps.query(po_name, data, uid) gls_time = time.time() - start_time - self.logger.log(logging.DEBUG, f'Query info: {response}') + self.logger.log(logging.DEBUG, f"Query info: {response}") if isinstance(response, QuerySuccessful): response_json = response.to_json() - md5_tag = md5(response_json.encode('utf-8')).hexdigest() + md5_tag = md5(response_json.encode("utf-8")).hexdigest() self.set_header("Etag", f'"{md5_tag}"') return (QuerySuccessful, response.for_json(), gls_time) else: - self.logger.log( - logging.ERROR, - f'Failed query, response: {response}') + self.logger.log(logging.ERROR, f"Failed query, response: {response}") return (type(response), response.for_json(), gls_time) # handle HTTP Options requests to support CORS @@ -71,43 +72,45 @@ def options(self, pred_name): self.fail_with_not_authorized() return - self.logger.log( - logging.DEBUG, - f'Processing OPTIONS for /query/{pred_name}') + self.logger.log(logging.DEBUG, f"Processing OPTIONS for /query/{pred_name}") # add CORS headers if TabPy has a cors_origin specified self._add_CORS_header() self.write({}) def _handle_result(self, po_name, data, qry, uid): - (response_type, response, gls_time) = \ - self._query(po_name, data, uid, qry) + (response_type, response, gls_time) = self._query(po_name, data, uid, qry) if response_type == QuerySuccessful: result_dict = { - 'response': response['response'], - 'version': response['version'], - 'model': po_name, - 'uuid': uid + "response": response["response"], + "version": response["version"], + "model": po_name, + "uuid": uid, } self.write(result_dict) self.finish() - return (gls_time, response['response']) + return (gls_time, response["response"]) else: if response_type == UnknownURI: - self.error_out(404, 'UnknownURI', - info=('No query object has been registered' - f' with the name "{po_name}"')) + self.error_out( + 404, + "UnknownURI", + info=( + "No query object has been registered" + f' with the name "{po_name}"' + ), + ) elif response_type == QueryError: - self.error_out(400, 'QueryError', info=response) + self.error_out(400, "QueryError", info=response) else: - self.error_out(500, 'Error querying GLS', info=response) + self.error_out(500, "Error querying GLS", info=response) return (None, None) def _sanitize_request_data(self, data): if not isinstance(data, dict): - msg = 'Input data must be a dictionary' + msg = "Input data must be a dictionary" self.logger.log(logging.CRITICAL, msg) raise RuntimeError(msg) @@ -121,8 +124,7 @@ def _sanitize_request_data(self, data): raise RuntimeError(msg) def _process_query(self, endpoint_name, start): - self.logger.log(logging.DEBUG, - f'Processing query {endpoint_name}...') + self.logger.log(logging.DEBUG, f"Processing query {endpoint_name}...") try: self._add_CORS_header() @@ -130,7 +132,7 @@ def _process_query(self, endpoint_name, start): self.request.body = {} # extract request data explicitly for caching purpose - request_json = self.request.body.decode('utf-8') + request_json = self.request.body.decode("utf-8") # Sanitize input data data = self._sanitize_request_data(json.loads(request_json)) @@ -141,29 +143,28 @@ def _process_query(self, endpoint_name, start): return try: - (po_name, _) = self._get_actual_model( - endpoint_name) + (po_name, _) = self._get_actual_model(endpoint_name) # po_name is None if self.python_service.ps.query_objects.get( # endpoint_name) is None if not po_name: self.error_out( - 404, - 'UnknownURI', - info=f'Endpoint "{endpoint_name}" does not exist') + 404, "UnknownURI", info=f'Endpoint "{endpoint_name}" does not exist' + ) return po_obj = self.python_service.ps.query_objects.get(po_name) if not po_obj: - self.error_out(404, 'UnknownURI', - info=f'Endpoint "{po_name}" does not exist') + self.error_out( + 404, "UnknownURI", info=f'Endpoint "{po_name}" does not exist' + ) return if po_name != endpoint_name: self.logger.log( - logging.INFO, - f'Querying actual model: po_name={po_name}') + logging.INFO, f"Querying actual model: po_name={po_name}" + ) uid = _get_uuid() @@ -179,8 +180,8 @@ def _process_query(self, endpoint_name, start): except Exception as e: self.logger.log(logging.ERROR, str(e)) - err_msg = format_exception(e, 'process query') - self.error_out(500, 'Error processing query', info=err_msg) + err_msg = format_exception(e, "process query") + self.error_out(500, "Error processing query", info=err_msg) return def _get_actual_model(self, endpoint_name): @@ -188,24 +189,24 @@ def _get_actual_model(self, endpoint_name): all_endpoint_names = [] while True: - endpoint_info = self.python_service.ps.query_objects.get( - endpoint_name) + endpoint_info = self.python_service.ps.query_objects.get(endpoint_name) if not endpoint_info: return [None, None] all_endpoint_names.append(endpoint_name) - endpoint_type = endpoint_info.get('type', 'model') + endpoint_type = endpoint_info.get("type", "model") - if endpoint_type == 'alias': - endpoint_name = endpoint_info['endpoint_obj'] - elif endpoint_type == 'model': + if endpoint_type == "alias": + endpoint_name = endpoint_info["endpoint_obj"] + elif endpoint_type == "model": break else: self.error_out( 500, - 'Unknown endpoint type', - info=f'Endpoint type "{endpoint_type}" does not exist') + "Unknown endpoint type", + info=f'Endpoint type "{endpoint_type}" does not exist', + ) return return (endpoint_name, all_endpoint_names) @@ -222,8 +223,7 @@ def get(self, endpoint_name): @gen.coroutine def post(self, endpoint_name): - self.logger.log(logging.DEBUG, - f'Processing POST for /query/{endpoint_name}...') + self.logger.log(logging.DEBUG, f"Processing POST for /query/{endpoint_name}...") if self.should_fail_with_not_authorized(): self.fail_with_not_authorized() diff --git a/tabpy/tabpy_server/handlers/service_info_handler.py b/tabpy/tabpy_server/handlers/service_info_handler.py index 6341c149..6b7060fb 100644 --- a/tabpy/tabpy_server/handlers/service_info_handler.py +++ b/tabpy/tabpy_server/handlers/service_info_handler.py @@ -13,11 +13,10 @@ def get(self): # supported API versions and required features self._add_CORS_header() info = {} - info['description'] = self.tabpy_state.get_description() - info['creation_time'] = self.tabpy_state.creation_time - info['state_path'] = self.settings[SettingsParameters.StateFilePath] - info['server_version'] =\ - self.settings[SettingsParameters.ServerVersion] - info['name'] = self.tabpy_state.name - info['versions'] = self.settings[SettingsParameters.ApiVersions] + info["description"] = self.tabpy_state.get_description() + info["creation_time"] = self.tabpy_state.creation_time + info["state_path"] = self.settings[SettingsParameters.StateFilePath] + info["server_version"] = self.settings[SettingsParameters.ServerVersion] + info["name"] = self.tabpy_state.name + info["versions"] = self.settings[SettingsParameters.ApiVersions] self.write(json.dumps(info)) diff --git a/tabpy/tabpy_server/handlers/status_handler.py b/tabpy/tabpy_server/handlers/status_handler.py index 3b2af815..2f743b3f 100644 --- a/tabpy/tabpy_server/handlers/status_handler.py +++ b/tabpy/tabpy_server/handlers/status_handler.py @@ -17,14 +17,13 @@ def get(self): status_dict = {} for k, v in self.python_service.ps.query_objects.items(): status_dict[k] = { - 'version': v['version'], - 'type': v['type'], - 'status': v['status'], - 'last_error': v['last_error']} + "version": v["version"], + "type": v["type"], + "status": v["status"], + "last_error": v["last_error"], + } - self.logger.log( - logging.DEBUG, - f'Found models: {status_dict}') + self.logger.log(logging.DEBUG, f"Found models: {status_dict}") self.write(json.dumps(status_dict)) self.finish() return diff --git a/tabpy/tabpy_server/handlers/upload_destination_handler.py b/tabpy/tabpy_server/handlers/upload_destination_handler.py index 5211b1e6..5d72f48b 100644 --- a/tabpy/tabpy_server/handlers/upload_destination_handler.py +++ b/tabpy/tabpy_server/handlers/upload_destination_handler.py @@ -4,7 +4,7 @@ import os -_QUERY_OBJECT_STAGING_FOLDER = 'staging' +_QUERY_OBJECT_STAGING_FOLDER = "staging" class UploadDestinationHandler(ManagementHandler): diff --git a/tabpy/tabpy_server/handlers/util.py b/tabpy/tabpy_server/handlers/util.py index e835d7fc..3c959a40 100755 --- a/tabpy/tabpy_server/handlers/util.py +++ b/tabpy/tabpy_server/handlers/util.py @@ -5,7 +5,7 @@ def hash_password(username, pwd): - ''' + """ Hashes password using PKDBF2 method: hash = PKDBF2('sha512', pwd, salt=username, 10000) @@ -25,11 +25,10 @@ def hash_password(username, pwd): str Sting representation (hexidecimal) for PBKDF2 hash for the password. - ''' - salt = f'_$salt@tabpy:{username.lower()}$_' + """ + salt = f"_$salt@tabpy:{username.lower()}$_" - hash = pbkdf2_hmac(hash_name='sha512', - password=pwd.encode(), - salt=salt.encode(), - iterations=10000) + hash = pbkdf2_hmac( + hash_name="sha512", password=pwd.encode(), salt=salt.encode(), iterations=10000 + ) return binascii.hexlify(hash).decode() diff --git a/tabpy/tabpy_server/management/state.py b/tabpy/tabpy_server/management/state.py index 9dbf6d77..796a9bae 100644 --- a/tabpy/tabpy_server/management/state.py +++ b/tabpy/tabpy_server/management/state.py @@ -13,24 +13,25 @@ logger = logging.getLogger(__name__) # State File Config Section Names -_DEPLOYMENT_SECTION_NAME = 'Query Objects Service Versions' -_QUERY_OBJECT_DOCSTRING = 'Query Objects Docstrings' -_SERVICE_INFO_SECTION_NAME = 'Service Info' -_META_SECTION_NAME = 'Meta' +_DEPLOYMENT_SECTION_NAME = "Query Objects Service Versions" +_QUERY_OBJECT_DOCSTRING = "Query Objects Docstrings" +_SERVICE_INFO_SECTION_NAME = "Service Info" +_META_SECTION_NAME = "Meta" # Directory Names -_QUERY_OBJECT_DIR = 'query_objects' +_QUERY_OBJECT_DIR = "query_objects" -''' +""" Lock to change the TabPy State. -''' +""" _PS_STATE_LOCK = Lock() def state_lock(func): - ''' + """ Mutex for changing PS state - ''' + """ + def wrapper(self, *args, **kwargs): try: _PS_STATE_LOCK.acquire() @@ -38,34 +39,33 @@ def wrapper(self, *args, **kwargs): finally: # ALWAYS RELEASE LOCK _PS_STATE_LOCK.release() + return wrapper def _get_root_path(state_path): - if state_path[-1] != '/': - return state_path + '/' + if state_path[-1] != "/": + return state_path + "/" else: return state_path def get_query_object_path(state_file_path, name, version): - ''' + """ Returns the query object path If the version is None, a path without the version will be returned. - ''' + """ root_path = _get_root_path(state_file_path) if version is not None: - full_path = root_path + \ - '/'.join([_QUERY_OBJECT_DIR, name, str(version)]) + full_path = root_path + "/".join([_QUERY_OBJECT_DIR, name, str(version)]) else: - full_path = root_path + \ - '/'.join([_QUERY_OBJECT_DIR, name]) + full_path = root_path + "/".join([_QUERY_OBJECT_DIR, name]) return full_path class TabPyState: - ''' + """ The TabPy state object that stores attributes about this TabPy and perform GET/SET on these attributes. @@ -79,19 +79,18 @@ class TabPyState: When the state object is initialized, the state is saved as a ConfigParser. There is a config to any attribute. - ''' + """ + def __init__(self, settings, config=None): self.settings = settings self.set_config(config, _update=False) @state_lock - def set_config(self, config, - logger=logging.getLogger(__name__), - _update=True): - ''' + def set_config(self, config, logger=logging.getLogger(__name__), _update=True): + """ Set the local ConfigParser manually. This new ConfigParser will be used as current state. - ''' + """ if not isinstance(config, ConfigParser): raise ValueError("Invalid config") self.config = config @@ -99,7 +98,7 @@ def set_config(self, config, self._write_state(logger) def get_endpoints(self, name=None): - ''' + """ Return a dictionary of endpoints Parameters @@ -120,39 +119,49 @@ def get_endpoints(self, name=None): - type - target - ''' + """ endpoints = {} try: - endpoint_names = self._get_config_value( - _DEPLOYMENT_SECTION_NAME, name) + endpoint_names = self._get_config_value(_DEPLOYMENT_SECTION_NAME, name) except Exception as e: - logger.error(f'error in get_endpoints: {str(e)}') + logger.error(f"error in get_endpoints: {str(e)}") return {} if name: endpoint_info = json.loads(endpoint_names) docstring = self._get_config_value(_QUERY_OBJECT_DOCSTRING, name) - endpoint_info['docstring'] = str( - bytes(docstring, "utf-8").decode('unicode_escape')) + endpoint_info["docstring"] = str( + bytes(docstring, "utf-8").decode("unicode_escape") + ) endpoints = {name: endpoint_info} else: for endpoint_name in endpoint_names: - endpoint_info = json.loads(self._get_config_value( - _DEPLOYMENT_SECTION_NAME, endpoint_name)) - docstring = self._get_config_value(_QUERY_OBJECT_DOCSTRING, - endpoint_name, True, '') - endpoint_info['docstring'] = str( - bytes(docstring, "utf-8").decode('unicode_escape')) + endpoint_info = json.loads( + self._get_config_value(_DEPLOYMENT_SECTION_NAME, endpoint_name) + ) + docstring = self._get_config_value( + _QUERY_OBJECT_DOCSTRING, endpoint_name, True, "" + ) + endpoint_info["docstring"] = str( + bytes(docstring, "utf-8").decode("unicode_escape") + ) endpoints[endpoint_name] = endpoint_info - logger.debug(f'Collected endpoints: {endpoints}') + logger.debug(f"Collected endpoints: {endpoints}") return endpoints @state_lock - def add_endpoint(self, name, description=None, - docstring=None, endpoint_type=None, - methods=None, target=None, dependencies=None, - schema=None): - ''' + def add_endpoint( + self, + name, + description=None, + docstring=None, + endpoint_type=None, + methods=None, + target=None, + dependencies=None, + schema=None, + ): + """ Add a new endpoint to the TabPy. Parameters @@ -172,23 +181,21 @@ def add_endpoint(self, name, description=None, The version of this endpoint will be set to 1 since it is a new endpoint. - ''' + """ try: endpoints = self.get_endpoints() - if name is None or not isinstance( - name, str) or len(name) == 0: - raise ValueError( - "name of the endpoint must be a valid string.") + if name is None or not isinstance(name, str) or len(name) == 0: + raise ValueError("name of the endpoint must be a valid string.") elif name in endpoints: - raise ValueError(f'endpoint {name} already exists.') + raise ValueError(f"endpoint {name} already exists.") if description and not isinstance(description, str): raise ValueError("description must be a string.") elif not description: - description = '' + description = "" if docstring and not isinstance(docstring, str): raise ValueError("docstring must be a string.") elif not docstring: - docstring = '-- no docstring found in query function --' + docstring = "-- no docstring found in query function --" if not endpoint_type or not isinstance(endpoint_type, str): raise ValueError("endpoint type must be a string.") if dependencies and not isinstance(dependencies, list): @@ -200,48 +207,61 @@ def add_endpoint(self, name, description=None, elif target and target not in endpoints: raise ValueError("target endpoint is not valid.") - endpoint_info = {"description": description, - "docstring": docstring, - "type": endpoint_type, - "version": 1, - "dependencies": dependencies, - "target": target, - "creation_time": int(time()), - "last_modified_time": int(time()), - "schema": schema} + endpoint_info = { + "description": description, + "docstring": docstring, + "type": endpoint_type, + "version": 1, + "dependencies": dependencies, + "target": target, + "creation_time": int(time()), + "last_modified_time": int(time()), + "schema": schema, + } endpoints[name] = endpoint_info self._add_update_endpoints_config(endpoints) except Exception as e: - logger.error(f'Error in add_endpoint: {e}') + logger.error(f"Error in add_endpoint: {e}") raise def _add_update_endpoints_config(self, endpoints): # save the endpoint info to config - dstring = '' + dstring = "" for endpoint_name in endpoints: try: info = endpoints[endpoint_name] - dstring = str(bytes(info['docstring'], "utf-8").decode( - 'unicode_escape')) - self._set_config_value(_QUERY_OBJECT_DOCSTRING, - endpoint_name, - dstring, - _update_revision=False) - del info['docstring'] - self._set_config_value(_DEPLOYMENT_SECTION_NAME, - endpoint_name, json.dumps(info)) + dstring = str( + bytes(info["docstring"], "utf-8").decode("unicode_escape") + ) + self._set_config_value( + _QUERY_OBJECT_DOCSTRING, + endpoint_name, + dstring, + _update_revision=False, + ) + del info["docstring"] + self._set_config_value( + _DEPLOYMENT_SECTION_NAME, endpoint_name, json.dumps(info) + ) except Exception as e: - logger.error(f'Unable to write endpoints config: {e}') + logger.error(f"Unable to write endpoints config: {e}") raise @state_lock - def update_endpoint(self, name, description=None, - docstring=None, endpoint_type=None, - version=None, methods=None, - target=None, dependencies=None, - schema=None): - ''' + def update_endpoint( + self, + name, + description=None, + docstring=None, + endpoint_type=None, + version=None, + methods=None, + target=None, + dependencies=None, + schema=None, + ): + """ Update an existing endpoint on the TabPy. Parameters @@ -265,37 +285,37 @@ def update_endpoint(self, name, description=None, For those parameters that are not specified, those values will not get changed. - ''' + """ try: endpoints = self.get_endpoints() if not name or not isinstance(name, str): raise ValueError("name of the endpoint must be string.") elif name not in endpoints: - raise ValueError(f'endpoint {name} does not exist.') + raise ValueError(f"endpoint {name} does not exist.") endpoint_info = endpoints[name] if description and not isinstance(description, str): raise ValueError("description must be a string.") elif not description: - description = endpoint_info['description'] + description = endpoint_info["description"] if docstring and not isinstance(docstring, str): raise ValueError("docstring must be a string.") elif not docstring: - docstring = endpoint_info['docstring'] + docstring = endpoint_info["docstring"] if endpoint_type and not isinstance(endpoint_type, str): raise ValueError("endpoint type must be a string.") elif not endpoint_type: - endpoint_type = endpoint_info['type'] + endpoint_type = endpoint_info["type"] if version and not isinstance(version, int): raise ValueError("version must be an int.") elif not version: - version = endpoint_info['version'] + version = endpoint_info["version"] if dependencies and not isinstance(dependencies, list): raise ValueError("dependencies must be a list.") elif not dependencies: - if 'dependencies' in endpoint_info: - dependencies = endpoint_info['dependencies'] + if "dependencies" in endpoint_info: + dependencies = endpoint_info["dependencies"] else: dependencies = [] if target and not isinstance(target, str): @@ -303,26 +323,28 @@ def update_endpoint(self, name, description=None, elif target and target not in endpoints: raise ValueError("target endpoint is not valid.") elif not target: - target = endpoint_info['target'] - endpoint_info = {'description': description, - 'docstring': docstring, - 'type': endpoint_type, - 'version': version, - 'dependencies': dependencies, - 'target': target, - 'creation_time': endpoint_info['creation_time'], - 'last_modified_time': int(time()), - 'schema': schema} + target = endpoint_info["target"] + endpoint_info = { + "description": description, + "docstring": docstring, + "type": endpoint_type, + "version": version, + "dependencies": dependencies, + "target": target, + "creation_time": endpoint_info["creation_time"], + "last_modified_time": int(time()), + "schema": schema, + } endpoints[name] = endpoint_info self._add_update_endpoints_config(endpoints) except Exception as e: - logger.error(f'Error in update_endpoint: {e}') + logger.error(f"Error in update_endpoint: {e}") raise @state_lock def delete_endpoint(self, name): - ''' + """ Delete an existing endpoint on the TabPy Parameters @@ -338,12 +360,12 @@ def delete_endpoint(self, name): Cannot delete this endpoint if other endpoints are currently depending on this endpoint. - ''' - if not name or name == '': + """ + if not name or name == "": raise ValueError("Name of the endpoint must be a valid string.") endpoints = self.get_endpoints() if name not in endpoints: - raise ValueError(f'Endpoint {name} does not exist.') + raise ValueError(f"Endpoint {name} does not exist.") endpoint_to_delete = endpoints[name] @@ -351,167 +373,178 @@ def delete_endpoint(self, name): deps = set() for endpoint_name in endpoints: if endpoint_name != name: - deps_list = endpoints[endpoint_name].get('dependencies', []) + deps_list = endpoints[endpoint_name].get("dependencies", []) if name in deps_list: deps.add(endpoint_name) # check if other endpoints are depending on this endpoint if len(deps) > 0: raise ValueError( - f'Cannot remove endpoint {name}, it is currently ' - f'used by {list(deps)} endpoints.') + f"Cannot remove endpoint {name}, it is currently " + f"used by {list(deps)} endpoints." + ) del endpoints[name] # delete the endpoint from state try: - self._remove_config_option(_QUERY_OBJECT_DOCSTRING, name, - _update_revision=False) + self._remove_config_option( + _QUERY_OBJECT_DOCSTRING, name, _update_revision=False + ) self._remove_config_option(_DEPLOYMENT_SECTION_NAME, name) return endpoint_to_delete except Exception as e: - logger.error(f'Unable to delete endpoint {e}') - raise ValueError(f'Unable to delete endpoint: {e}') + logger.error(f"Unable to delete endpoint {e}") + raise ValueError(f"Unable to delete endpoint: {e}") @property def name(self): - ''' + """ Returns the name of the TabPy service. - ''' + """ name = None try: - name = self._get_config_value(_SERVICE_INFO_SECTION_NAME, 'Name') + name = self._get_config_value(_SERVICE_INFO_SECTION_NAME, "Name") except Exception as e: - logger.error(f'Unable to get name: {e}') + logger.error(f"Unable to get name: {e}") return name @property def creation_time(self): - ''' + """ Returns the creation time of the TabPy service. - ''' + """ creation_time = 0 try: creation_time = self._get_config_value( - _SERVICE_INFO_SECTION_NAME, 'Creation Time') + _SERVICE_INFO_SECTION_NAME, "Creation Time" + ) except Exception as e: - logger.error(f'Unable to get name: {e}') + logger.error(f"Unable to get name: {e}") return creation_time @state_lock def set_name(self, name): - ''' + """ Set the name of this TabPy service. Parameters ---------- name : str Name of TabPy service. - ''' + """ if not isinstance(name, str): raise ValueError("name must be a string.") try: - self._set_config_value(_SERVICE_INFO_SECTION_NAME, 'Name', name) + self._set_config_value(_SERVICE_INFO_SECTION_NAME, "Name", name) except Exception as e: - logger.error(f'Unable to set name: {e}') + logger.error(f"Unable to set name: {e}") def get_description(self): - ''' + """ Returns the description of the TabPy service. - ''' + """ description = None try: description = self._get_config_value( - _SERVICE_INFO_SECTION_NAME, 'Description') + _SERVICE_INFO_SECTION_NAME, "Description" + ) except Exception as e: - logger.error(f'Unable to get description: {e}') + logger.error(f"Unable to get description: {e}") return description @state_lock def set_description(self, description): - ''' + """ Set the description of this TabPy service. Parameters ---------- description : str Description of TabPy service. - ''' + """ if not isinstance(description, str): raise ValueError("Description must be a string.") try: self._set_config_value( - _SERVICE_INFO_SECTION_NAME, 'Description', description) + _SERVICE_INFO_SECTION_NAME, "Description", description + ) except Exception as e: - logger.error(f'Unable to set description: {e}') + logger.error(f"Unable to set description: {e}") def get_revision_number(self): - ''' + """ Returns the revision number of this TabPy service. - ''' + """ rev = -1 try: - rev = int(self._get_config_value( - _META_SECTION_NAME, 'Revision Number')) + rev = int(self._get_config_value(_META_SECTION_NAME, "Revision Number")) except Exception as e: - logger.error(f'Unable to get revision number: {e}') + logger.error(f"Unable to get revision number: {e}") return rev def get_access_control_allow_origin(self): - ''' + """ Returns Access-Control-Allow-Origin of this TabPy service. - ''' - _cors_origin = '' + """ + _cors_origin = "" try: - logger.debug("Collecting Access-Control-Allow-Origin from " - "state file...") + logger.debug("Collecting Access-Control-Allow-Origin from " "state file...") _cors_origin = self._get_config_value( - 'Service Info', 'Access-Control-Allow-Origin') + "Service Info", "Access-Control-Allow-Origin" + ) except Exception as e: logger.error(e) pass return _cors_origin def get_access_control_allow_headers(self): - ''' + """ Returns Access-Control-Allow-Headers of this TabPy service. - ''' - _cors_headers = '' + """ + _cors_headers = "" try: _cors_headers = self._get_config_value( - 'Service Info', 'Access-Control-Allow-Headers') + "Service Info", "Access-Control-Allow-Headers" + ) except Exception: pass return _cors_headers def get_access_control_allow_methods(self): - ''' + """ Returns Access-Control-Allow-Methods of this TabPy service. - ''' - _cors_methods = '' + """ + _cors_methods = "" try: _cors_methods = self._get_config_value( - 'Service Info', 'Access-Control-Allow-Methods') + "Service Info", "Access-Control-Allow-Methods" + ) except Exception: pass return _cors_methods def _set_revision_number(self, revision_number): - ''' + """ Set the revision number of this TabPy service. - ''' + """ if not isinstance(revision_number, int): raise ValueError("revision number must be an int.") try: - self._set_config_value(_META_SECTION_NAME, - 'Revision Number', revision_number) + self._set_config_value( + _META_SECTION_NAME, "Revision Number", revision_number + ) except Exception as e: - logger.error(f'Unable to set revision number: {e}') - - def _remove_config_option(self, section_name, option_name, - logger=logging.getLogger(__name__), - _update_revision=True): + logger.error(f"Unable to set revision number: {e}") + + def _remove_config_option( + self, + section_name, + option_name, + logger=logging.getLogger(__name__), + _update_revision=True, + ): if not self.config: raise ValueError("State configuration not yet loaded.") self.config.remove_option(section_name, option_name) @@ -528,18 +561,22 @@ def _has_config_value(self, section_name, option_name): def _increase_revision_number(self): if not self.config: raise ValueError("State configuration not yet loaded.") - cur_rev = int(self.config.get(_META_SECTION_NAME, 'Revision Number')) - self.config.set(_META_SECTION_NAME, 'Revision Number', - str(cur_rev + 1)) - - def _set_config_value(self, section_name, option_name, option_value, - logger=logging.getLogger(__name__), - _update_revision=True): + cur_rev = int(self.config.get(_META_SECTION_NAME, "Revision Number")) + self.config.set(_META_SECTION_NAME, "Revision Number", str(cur_rev + 1)) + + def _set_config_value( + self, + section_name, + option_name, + option_value, + logger=logging.getLogger(__name__), + _update_revision=True, + ): if not self.config: raise ValueError("State configuration not yet loaded.") if not self.config.has_section(section_name): - logger.log(logging.DEBUG, f'Adding config section {section_name}') + logger.log(logging.DEBUG, f"Adding config section {section_name}") self.config.add_section(section_name) self.config.set(section_name, option_name, option_value) @@ -553,8 +590,9 @@ def _get_config_items(self, section_name): raise ValueError("State configuration not yet loaded.") return self.config.items(section_name) - def _get_config_value(self, section_name, option_name, optional=False, - default_value=None): + def _get_config_value( + self, section_name, option_name, optional=False, default_value=None + ): if not self.config: raise ValueError("State configuration not yet loaded.") @@ -567,12 +605,13 @@ def _get_config_value(self, section_name, option_name, optional=False, return default_value else: raise ValueError( - f'Cannot find option name {option_name} ' - f'under section {section_name}') + f"Cannot find option name {option_name} " + f"under section {section_name}" + ) def _write_state(self, logger=logging.getLogger(__name__)): - ''' + """ Write state (ConfigParser) to Consul - ''' - logger.log(logging.INFO, 'Writing state to config') + """ + logger.log(logging.INFO, "Writing state to config") write_state_config(self.config, self.settings, logger=logger) diff --git a/tabpy/tabpy_server/management/util.py b/tabpy/tabpy_server/management/util.py index 7bc21244..7590461b 100644 --- a/tabpy/tabpy_server/management/util.py +++ b/tabpy/tabpy_server/management/util.py @@ -1,5 +1,6 @@ import logging import os + try: from ConfigParser import ConfigParser as _ConfigParser except ImportError: @@ -13,24 +14,24 @@ def write_state_config(state, settings, logger=logging.getLogger(__name__)): if SettingsParameters.StateFilePath in settings: state_path = settings[SettingsParameters.StateFilePath] else: - msg = f'{ConfigParameters.TABPY_STATE_PATH} is not set' + msg = f"{ConfigParameters.TABPY_STATE_PATH} is not set" logger.log(logging.CRITICAL, msg) raise ValueError(msg) - logger.log(logging.DEBUG, f'State path is {state_path}') - state_key = os.path.join(state_path, 'state.ini') + logger.log(logging.DEBUG, f"State path is {state_path}") + state_key = os.path.join(state_path, "state.ini") tmp_state_file = state_key - with open(tmp_state_file, 'w') as f: + with open(tmp_state_file, "w") as f: state.write(f) def _get_state_from_file(state_path, logger=logging.getLogger(__name__)): - state_key = os.path.join(state_path, 'state.ini') + state_key = os.path.join(state_path, "state.ini") tmp_state_file = state_key if not os.path.exists(tmp_state_file): - msg = f'Missing config file at {tmp_state_file}' + msg = f"Missing config file at {tmp_state_file}" logger.log(logging.CRITICAL, msg) raise ValueError(msg) @@ -38,11 +39,9 @@ def _get_state_from_file(state_path, logger=logging.getLogger(__name__)): config.optionxform = str config.read(tmp_state_file) - if not config.has_section('Service Info'): - msg = ('Config error: Expected [Service Info] section in ' - f'{tmp_state_file}') + if not config.has_section("Service Info"): + msg = "Config error: Expected [Service Info] section in " f"{tmp_state_file}" logger.log(logging.CRITICAL, msg) raise ValueError(msg) return config - diff --git a/tabpy/tabpy_server/psws/callbacks.py b/tabpy/tabpy_server/psws/callbacks.py index d6c38b4f..6949f47b 100644 --- a/tabpy/tabpy_server/psws/callbacks.py +++ b/tabpy/tabpy_server/psws/callbacks.py @@ -1,12 +1,15 @@ import logging import sys from tabpy.tabpy_server.app.SettingsParameters import SettingsParameters -from tabpy.tabpy_server.common.messages\ - import (LoadObject, DeleteObjects, ListObjects, ObjectList) +from tabpy.tabpy_server.common.messages import ( + LoadObject, + DeleteObjects, + ListObjects, + ObjectList, +) from tabpy.tabpy_server.common.endpoint_file_mgr import cleanup_endpoint_files from tabpy.tabpy_server.common.util import format_exception -from tabpy.tabpy_server.management.state\ - import TabPyState, get_query_object_path +from tabpy.tabpy_server.management.state import TabPyState, get_query_object_path from tabpy.tabpy_server.management import util from time import sleep from tornado import gen @@ -16,21 +19,20 @@ def wait_for_endpoint_loaded(python_service, object_uri): - ''' + """ This method waits for the object to be loaded. - ''' - logger.info('Waiting for object to be loaded...') + """ + logger.info("Waiting for object to be loaded...") while True: msg = ListObjects() list_object_msg = python_service.manage_request(msg) if not isinstance(list_object_msg, ObjectList): - logger.error( - f'Error loading endpoint {object_uri}: {list_object_msg}') + logger.error(f"Error loading endpoint {object_uri}: {list_object_msg}") return - for (uri, info) in (list_object_msg.objects.items()): + for (uri, info) in list_object_msg.objects.items(): if uri == object_uri: - if info['status'] != 'LoadInProgress': + if info["status"] != "LoadInProgress": logger.info(f'Object load status: {info["status"]}') return @@ -41,54 +43,55 @@ def wait_for_endpoint_loaded(python_service, object_uri): def init_ps_server(settings, tabpy_state): logger.info("Initializing TabPy Server...") existing_pos = tabpy_state.get_endpoints() - for (object_name, obj_info) in (existing_pos.items()): + for (object_name, obj_info) in existing_pos.items(): try: - object_version = obj_info['version'] + object_version = obj_info["version"] get_query_object_path( - settings[SettingsParameters.StateFilePath], - object_name, object_version) + settings[SettingsParameters.StateFilePath], object_name, object_version + ) except Exception as e: logger.error( - f'Exception encounted when downloading object: {object_name}' - f', error: {e}') + f"Exception encounted when downloading object: {object_name}" + f", error: {e}" + ) @gen.coroutine def init_model_evaluator(settings, tabpy_state, python_service): - ''' + """ This will go through all models that the service currently have and initialize them. - ''' + """ logger.info("Initializing models...") existing_pos = tabpy_state.get_endpoints() - for (object_name, obj_info) in (existing_pos.items()): - object_version = obj_info['version'] - object_type = obj_info['type'] + for (object_name, obj_info) in existing_pos.items(): + object_version = obj_info["version"] + object_type = obj_info["type"] object_path = get_query_object_path( - settings[SettingsParameters.StateFilePath], - object_name, object_version) + settings[SettingsParameters.StateFilePath], object_name, object_version + ) logger.info( - f'Load endpoint: {object_name}, ' - f'version: {object_version}, ' - f'type: {object_type}') - if object_type == 'alias': - msg = LoadObject(object_name, obj_info['target'], - object_version, False, 'alias') + f"Load endpoint: {object_name}, " + f"version: {object_version}, " + f"type: {object_type}" + ) + if object_type == "alias": + msg = LoadObject( + object_name, obj_info["target"], object_version, False, "alias" + ) else: local_path = object_path - msg = LoadObject(object_name, local_path, object_version, - False, object_type) + msg = LoadObject( + object_name, local_path, object_version, False, object_type + ) python_service.manage_request(msg) -def _get_latest_service_state(settings, - tabpy_state, - new_ps_state, - python_service): - ''' +def _get_latest_service_state(settings, tabpy_state, new_ps_state, python_service): + """ Update the endpoints from the latest remote state file. Returns @@ -96,9 +99,9 @@ def _get_latest_service_state(settings, (has_changes, endpoint_diff): has_changes: True or False endpoint_diff: Summary of what has changed, one entry for each changes - ''' + """ # Shortcut when nothing is changed - changes = {'endpoints': {}} + changes = {"endpoints": {}} # update endpoints new_endpoints = new_ps_state.get_endpoints() @@ -106,72 +109,84 @@ def _get_latest_service_state(settings, current_endpoints = python_service.ps.query_objects for (endpoint_name, endpoint_info) in new_endpoints.items(): existing_endpoint = current_endpoints.get(endpoint_name) - if (existing_endpoint is None) or \ - endpoint_info['version'] != existing_endpoint['version']: + if (existing_endpoint is None) or endpoint_info["version"] != existing_endpoint[ + "version" + ]: # Either a new endpoint or new endpoint version path_to_new_version = get_query_object_path( settings[SettingsParameters.StateFilePath], - endpoint_name, endpoint_info['version']) - endpoint_type = endpoint_info.get('type', 'model') - diff[endpoint_name] = (endpoint_type, endpoint_info['version'], - path_to_new_version) + endpoint_name, + endpoint_info["version"], + ) + endpoint_type = endpoint_info.get("type", "model") + diff[endpoint_name] = ( + endpoint_type, + endpoint_info["version"], + path_to_new_version, + ) # add removed models too for (endpoint_name, endpoint_info) in current_endpoints.items(): if endpoint_name not in new_endpoints.keys(): - endpoint_type = current_endpoints[endpoint_name].get( - 'type', 'model') + endpoint_type = current_endpoints[endpoint_name].get("type", "model") diff[endpoint_name] = (endpoint_type, None, None) if diff: - changes['endpoints'] = diff + changes["endpoints"] = diff tabpy_state = new_ps_state return (True, changes) @gen.coroutine -def on_state_change(settings, tabpy_state, python_service, - logger=logging.getLogger(__name__)): +def on_state_change( + settings, tabpy_state, python_service, logger=logging.getLogger(__name__) +): try: logger.log(logging.INFO, "Loading state from state file") config = util._get_state_from_file( - settings[SettingsParameters.StateFilePath], - logger=logger) + settings[SettingsParameters.StateFilePath], logger=logger + ) new_ps_state = TabPyState(config=config, settings=settings) - (has_changes, changes) = _get_latest_service_state(settings, - tabpy_state, - new_ps_state, - python_service) + (has_changes, changes) = _get_latest_service_state( + settings, tabpy_state, new_ps_state, python_service + ) if not has_changes: logger.info("Nothing changed, return.") return new_endpoints = new_ps_state.get_endpoints() - for object_name in changes['endpoints']: - (object_type, object_version, object_path) = changes['endpoints'][ - object_name] + for object_name in changes["endpoints"]: + (object_type, object_version, object_path) = changes["endpoints"][ + object_name + ] if not object_path and not object_version: # removal - logger.info(f'Removing object: URI={object_name}') + logger.info(f"Removing object: URI={object_name}") python_service.manage_request(DeleteObjects([object_name])) - cleanup_endpoint_files(object_name, - settings[SettingsParameters.UploadDir], - logger=logger) + cleanup_endpoint_files( + object_name, settings[SettingsParameters.UploadDir], logger=logger + ) else: endpoint_info = new_endpoints[object_name] is_update = object_version > 1 - if object_type == 'alias': - msg = LoadObject(object_name, endpoint_info['target'], - object_version, is_update, 'alias') + if object_type == "alias": + msg = LoadObject( + object_name, + endpoint_info["target"], + object_version, + is_update, + "alias", + ) else: local_path = object_path - msg = LoadObject(object_name, local_path, object_version, - is_update, object_type) + msg = LoadObject( + object_name, local_path, object_version, is_update, object_type + ) python_service.manage_request(msg) wait_for_endpoint_loaded(python_service, object_name) @@ -179,11 +194,14 @@ def on_state_change(settings, tabpy_state, python_service, # cleanup old version of endpoint files if object_version > 2: cleanup_endpoint_files( - object_name, settings[SettingsParameters.UploadDir], + object_name, + settings[SettingsParameters.UploadDir], logger=logger, - retain_versions=[object_version, object_version - 1]) + retain_versions=[object_version, object_version - 1], + ) except Exception as e: - err_msg = format_exception(e, 'on_state_change') - logger.log(logging.ERROR, - f'Error submitting update model request: error={err_msg}') + err_msg = format_exception(e, "on_state_change") + logger.log( + logging.ERROR, f"Error submitting update model request: error={err_msg}" + ) diff --git a/tabpy/tabpy_server/psws/python_service.py b/tabpy/tabpy_server/psws/python_service.py index 8e98bf77..ba506ce8 100644 --- a/tabpy/tabpy_server/psws/python_service.py +++ b/tabpy/tabpy_server/psws/python_service.py @@ -6,10 +6,23 @@ from tabpy.tabpy_tools.query_object import QueryObject from tabpy.tabpy_server.common.util import format_exception from tabpy.tabpy_server.common.messages import ( - LoadObject, DeleteObjects, FlushObjects, CountObjects, ListObjects, - UnknownMessage, LoadFailed, ObjectsDeleted, ObjectsFlushed, QueryFailed, - QuerySuccessful, UnknownURI, DownloadSkipped, LoadInProgress, ObjectCount, - ObjectList) + LoadObject, + DeleteObjects, + FlushObjects, + CountObjects, + ListObjects, + UnknownMessage, + LoadFailed, + ObjectsDeleted, + ObjectsFlushed, + QueryFailed, + QuerySuccessful, + UnknownURI, + DownloadSkipped, + LoadInProgress, + ObjectCount, + ObjectList, +) logger = logging.getLogger(__name__) @@ -20,12 +33,13 @@ class PythonServiceHandler: A wrapper around PythonService object that receives requests and calls the corresponding methods. """ + def __init__(self, ps): self.ps = ps def manage_request(self, msg): try: - logger.debug(f'Received request {type(msg).__name__}') + logger.debug(f"Received request {type(msg).__name__}") if isinstance(msg, LoadObject): response = self.ps.load_object(*msg) elif isinstance(msg, DeleteObjects): @@ -39,14 +53,14 @@ def manage_request(self, msg): else: response = UnknownMessage(msg) - logger.debug(f'Returning response {response}') + logger.debug(f"Returning response {response}") return response except Exception as e: logger.exception(e) msg = e - if hasattr(e, 'message'): + if hasattr(e, "message"): msg = e.message - logger.error(f'Error processing request: {msg}') + logger.error(f"Error processing request: {msg}") return UnknownMessage(msg) @@ -65,85 +79,103 @@ class PythonService: 'status':} """ - def __init__(self, - query_objects=None): + + def __init__(self, query_objects=None): self.EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=1) self.query_objects = query_objects or {} - def _load_object(self, object_uri, object_url, object_version, is_update, - object_type): + def _load_object( + self, object_uri, object_url, object_version, is_update, object_type + ): try: logger.info( - f'Loading object:, URI={object_uri}, ' - f'URL={object_url}, version={object_version}, ' - f'is_updated={is_update}') - if object_type == 'model': + f"Loading object:, URI={object_uri}, " + f"URL={object_url}, version={object_version}, " + f"is_updated={is_update}" + ) + if object_type == "model": po = QueryObject.load(object_url) - elif object_type == 'alias': + elif object_type == "alias": po = object_url else: - raise RuntimeError(f'Unknown object type: {object_type}') + raise RuntimeError(f"Unknown object type: {object_type}") - self.query_objects[object_uri] = {'version': object_version, - 'type': object_type, - 'endpoint_obj': po, - 'status': 'LoadSuccessful', - 'last_error': None} + self.query_objects[object_uri] = { + "version": object_version, + "type": object_type, + "endpoint_obj": po, + "status": "LoadSuccessful", + "last_error": None, + } except Exception as e: logger.exception(e) - logger.error(f'Unable to load QueryObject: path={object_url}, ' - f'error={str(e)}') + logger.error( + f"Unable to load QueryObject: path={object_url}, " f"error={str(e)}" + ) self.query_objects[object_uri] = { - 'version': object_version, - 'type': object_type, - 'endpoint_obj': None, - 'status': 'LoadFailed', - 'last_error': f'Load failed: {str(e)}'} - - def load_object(self, object_uri, object_url, object_version, is_update, - object_type): + "version": object_version, + "type": object_type, + "endpoint_obj": None, + "status": "LoadFailed", + "last_error": f"Load failed: {str(e)}", + } + + def load_object( + self, object_uri, object_url, object_version, is_update, object_type + ): try: obj_info = self.query_objects.get(object_uri) - if obj_info and obj_info['endpoint_obj'] and ( - obj_info['version'] >= object_version): - logger.info( - "Received load message for object already loaded") + if ( + obj_info + and obj_info["endpoint_obj"] + and (obj_info["version"] >= object_version) + ): + logger.info("Received load message for object already loaded") return DownloadSkipped( - object_uri, obj_info['version'], "Object with greater " - "or equal version already loaded") + object_uri, + obj_info["version"], + "Object with greater " "or equal version already loaded", + ) else: if object_uri not in self.query_objects: self.query_objects[object_uri] = { - 'version': object_version, - 'type': object_type, - 'endpoint_obj': None, - 'status': 'LoadInProgress', - 'last_error': None} + "version": object_version, + "type": object_type, + "endpoint_obj": None, + "status": "LoadInProgress", + "last_error": None, + } else: - self.query_objects[ - object_uri]['status'] = 'LoadInProgress' + self.query_objects[object_uri]["status"] = "LoadInProgress" self.EXECUTOR.submit( - self._load_object, object_uri, object_url, - object_version, is_update, object_type) + self._load_object, + object_uri, + object_url, + object_version, + is_update, + object_type, + ) return LoadInProgress( - object_uri, object_url, object_version, is_update, - object_type) + object_uri, object_url, object_version, is_update, object_type + ) except Exception as e: logger.exception(e) - logger.error(f'Unable to load QueryObject: path={object_url}, ' - f'error={str(e)}') + logger.error( + f"Unable to load QueryObject: path={object_url}, " f"error={str(e)}" + ) self.query_objects[object_uri] = { - 'version': object_version, - 'type': object_type, - 'endpoint_obj': None, - 'status': 'LoadFailed', - 'last_error': str(e)} + "version": object_version, + "type": object_type, + "endpoint_obj": None, + "status": "LoadFailed", + "last_error": str(e), + } return LoadFailed(object_uri, object_version, str(e)) @@ -159,15 +191,18 @@ def delete_objects(self, object_uris): if deleted_obj: return ObjectsDeleted([object_uris]) else: - logger.warning(f'Received message to delete query object ' - f'that doesn\'t exist: ' - f'object_uris={object_uris}') + logger.warning( + f"Received message to delete query object " + f"that doesn't exist: " + f"object_uris={object_uris}" + ) return ObjectsDeleted([]) else: logger.error( - f'Unexpected input to delete objects: input={object_uris}, ' + f"Unexpected input to delete objects: input={object_uris}, " f'info="Input should be list or str. ' - f'Type: {type(object_uris)}"') + f'Type: {type(object_uris)}"' + ) return ObjectsDeleted([]) def flush_objects(self): @@ -180,8 +215,8 @@ def flush_objects(self): def count_objects(self): """Count the number of Loaded QueryObjects stored in memory""" count = 0 - for uri, po in (self.query_objects.items()): - if po['endpoint_obj'] is not None: + for uri, po in self.query_objects.items(): + if po["endpoint_obj"] is not None: count += 1 return ObjectCount(count) @@ -189,37 +224,45 @@ def list_objects(self): """List the objects as (URI, version) pairs""" objects = {} - for (uri, obj_info) in (self.query_objects.items()): - objects[uri] = {'version': obj_info['version'], - 'type': obj_info['type'], - 'status': obj_info['status'], - 'reason': obj_info['last_error']} + for (uri, obj_info) in self.query_objects.items(): + objects[uri] = { + "version": obj_info["version"], + "type": obj_info["type"], + "status": obj_info["status"], + "reason": obj_info["last_error"], + } return ObjectList(objects) def query(self, object_uri, params, uid): """Execute a QueryObject query""" - logger.debug(f'Querying Python service {object_uri}...') + logger.debug(f"Querying Python service {object_uri}...") try: if not isinstance(params, dict) and not isinstance(params, list): return QueryFailed( uri=object_uri, - error=('Query parameter needs to be a dictionary or a list' - f'. Given value is of type {type(params)}')) + error=( + "Query parameter needs to be a dictionary or a list" + f". Given value is of type {type(params)}" + ), + ) obj_info = self.query_objects.get(object_uri) - logger.debug(f'Found object {obj_info}') + logger.debug(f"Found object {obj_info}") if obj_info: - pred_obj = obj_info['endpoint_obj'] - version = obj_info['version'] + pred_obj = obj_info["endpoint_obj"] + version = obj_info["version"] if not pred_obj: return QueryFailed( uri=object_uri, - error=("There is no query object associated to the " - f'endpoint: {object_uri}')) + error=( + "There is no query object associated to the " + f"endpoint: {object_uri}" + ), + ) - logger.debug(f'Querying endpoint with params ({params})...') + logger.debug(f"Querying endpoint with params ({params})...") if isinstance(params, dict): result = pred_obj.query(**params) else: @@ -230,6 +273,6 @@ def query(self, object_uri, params, uid): return UnknownURI(object_uri) except Exception as e: logger.exception(e) - err_msg = format_exception(e, '/query') + err_msg = format_exception(e, "/query") logger.error(err_msg) return QueryFailed(uri=object_uri, error=err_msg) diff --git a/tabpy/tabpy_tools/client.py b/tabpy/tabpy_tools/client.py index a8ec4076..29122107 100755 --- a/tabpy/tabpy_tools/client.py +++ b/tabpy/tabpy_tools/client.py @@ -3,16 +3,9 @@ import sys import requests -from .rest import ( - RequestsNetworkWrapper, - ServiceClient -) +from .rest import RequestsNetworkWrapper, ServiceClient -from .rest_client import ( - RESTServiceClient, - Endpoint, - AliasEndpoint -) +from .rest_client import RESTServiceClient, Endpoint, AliasEndpoint from .custom_query_object import CustomQueryObject @@ -21,27 +14,27 @@ logger = logging.getLogger(__name__) -_name_checker = compile(r'^[\w -]+$') +_name_checker = compile(r"^[\w -]+$") def _check_endpoint_type(name): if not isinstance(name, str): raise TypeError("Endpoint name must be a string") - if name == '': + if name == "": raise ValueError("Endpoint name cannot be empty") def _check_hostname(name): _check_endpoint_type(name) - hostname_checker = compile( - r'^http(s)?://[a-zA-Z0-9-_\.]+(/)?(:[0-9]+)?(/)?$') + hostname_checker = compile(r"^http(s)?://[a-zA-Z0-9-_\.]+(/)?(:[0-9]+)?(/)?$") if not hostname_checker.match(name): raise ValueError( - f'endpoint name {name} should be in http(s)://' - '[:] and hostname may consist only of: ' - 'a-z, A-Z, 0-9, underscore and hyphens.') + f"endpoint name {name} should be in http(s)://" + "[:] and hostname may consist only of: " + "a-z, A-Z, 0-9, underscore and hyphens." + ) def _check_endpoint_name(name): @@ -51,14 +44,13 @@ def _check_endpoint_name(name): if not _name_checker.match(name): raise ValueError( - f'endpoint name {name} can only contain: a-z, A-Z, 0-9,' - ' underscore, hyphens and spaces.') + f"endpoint name {name} can only contain: a-z, A-Z, 0-9," + " underscore, hyphens and spaces." + ) class Client: - def __init__(self, - endpoint, - query_timeout=1000): + def __init__(self, endpoint, query_timeout=1000): """ Connects to a running server. @@ -93,12 +85,17 @@ def __init__(self, def __repr__(self): return ( - "<" + self.__class__.__name__ + - ' object at ' + hex(id(self)) + - ' connected to ' + repr(self._endpoint) + ">") + "<" + + self.__class__.__name__ + + " object at " + + hex(id(self)) + + " connected to " + + repr(self._endpoint) + + ">" + ) def get_status(self): - ''' + """ Gets the status of the deployed endpoints. Returns @@ -118,7 +115,7 @@ def get_status(self): u'type': u'model', }, } - ''' + """ return self._service.get_status() # @@ -193,11 +190,9 @@ def get_endpoints(self, type=None): def _get_endpoint_upload_destination(self): """Returns the endpoint upload destination.""" - return self._service.get_endpoint_upload_destination()['path'] + return self._service.get_endpoint_upload_destination()["path"] - def deploy(self, - name, obj, description='', schema=None, - override=False): + def deploy(self, name, obj, description="", schema=None, override=False): """Deploys a Python function as an endpoint in the server. Parameters @@ -234,9 +229,10 @@ def deploy(self, if endpoint: if not override: raise RuntimeError( - f'An endpoint with that name ({name}) already' + f"An endpoint with that name ({name}) already" ' exists. Use "override = True" to force update ' - 'an existing endpoint.') + "an existing endpoint." + ) version = endpoint.version + 1 else: @@ -251,10 +247,10 @@ def deploy(self, else: self._service.set_endpoint(Endpoint(**obj)) - self._wait_for_endpoint_deployment(obj['name'], obj['version']) + self._wait_for_endpoint_deployment(obj["name"], obj["version"]) def _gen_endpoint(self, name, obj, description, version=1, schema=[]): - '''Generates an endpoint dict. + """Generates an endpoint dict. Parameters ---------- @@ -295,55 +291,48 @@ def _gen_endpoint(self, name, obj, description, version=1, schema=[]): ------ TypeError When obj is not one of the expected types. - ''' + """ # check for invalid PO names _check_endpoint_name(name) if description is None: if isinstance(obj.__doc__, str): # extract doc string - description = obj.__doc__.strip() or '' + description = obj.__doc__.strip() or "" else: - description = '' + description = "" - endpoint_object = CustomQueryObject( - query=obj, - description=description, - ) + endpoint_object = CustomQueryObject(query=obj, description=description,) return { - 'name': name, - 'version': version, - 'description': description, - 'type': 'model', - 'endpoint_obj': endpoint_object, - 'dependencies': endpoint_object.get_dependencies(), - 'methods': endpoint_object.get_methods(), - 'required_files': [], - 'required_packages': [], - 'schema': schema + "name": name, + "version": version, + "description": description, + "type": "model", + "endpoint_obj": endpoint_object, + "dependencies": endpoint_object.get_dependencies(), + "methods": endpoint_object.get_methods(), + "required_files": [], + "required_packages": [], + "schema": schema, } def _upload_endpoint(self, obj): """Sends the endpoint across the wire.""" - endpoint_obj = obj['endpoint_obj'] + endpoint_obj = obj["endpoint_obj"] dest_path = self._get_endpoint_upload_destination() # Upload the endpoint - obj['src_path'] = os.path.join( - dest_path, - 'endpoints', - obj['name'], - str(obj['version'])) - - endpoint_obj.save(obj['src_path']) - - def _wait_for_endpoint_deployment(self, - endpoint_name, - version=1, - interval=1.0, - ): + obj["src_path"] = os.path.join( + dest_path, "endpoints", obj["name"], str(obj["version"]) + ) + + endpoint_obj.save(obj["src_path"]) + + def _wait_for_endpoint_deployment( + self, endpoint_name, version=1, interval=1.0, + ): """ Waits for the endpoint to be deployed by calling get_status() and checking the versions deployed of the endpoint against the expected @@ -351,25 +340,25 @@ def _wait_for_endpoint_deployment(self, expected, then it will return. Uses time.sleep(). """ logger.info( - f'Waiting for endpoint {endpoint_name} to deploy to ' - f'version {version}') + f"Waiting for endpoint {endpoint_name} to deploy to " f"version {version}" + ) start = time.time() while True: ep_status = self.get_status() try: ep = ep_status[endpoint_name] except KeyError: - logger.info(f'Endpoint {endpoint_name} doesn\'t ' - 'exist in endpoints yet') + logger.info( + f"Endpoint {endpoint_name} doesn't " "exist in endpoints yet" + ) else: - logger.info(f'ep={ep}') + logger.info(f"ep={ep}") - if ep['status'] == 'LoadFailed': - raise RuntimeError( - f'LoadFailed: {ep["last_error"]}') + if ep["status"] == "LoadFailed": + raise RuntimeError(f'LoadFailed: {ep["last_error"]}') - elif ep['status'] == 'LoadSuccessful': - if ep['version'] >= version: + elif ep["status"] == "LoadSuccessful": + if ep["version"] >= version: logger.info("LoadSuccessful") break else: @@ -378,11 +367,11 @@ def _wait_for_endpoint_deployment(self, if time.time() - start > 10: raise RuntimeError("Waited more then 10s for deployment") - logger.info(f'Sleeping {interval}...') + logger.info(f"Sleeping {interval}...") time.sleep(interval) def set_credentials(self, username, password): - ''' + """ Set credentials for all the TabPy client-server communication where client is tabpy-tools and server is tabpy-server. @@ -393,5 +382,5 @@ def set_credentials(self, username, password): password : str Password in plain text. - ''' + """ self._service.set_credentials(username, password) diff --git a/tabpy/tabpy_tools/custom_query_object.py b/tabpy/tabpy_tools/custom_query_object.py index a0e0f116..18a149b8 100755 --- a/tabpy/tabpy_tools/custom_query_object.py +++ b/tabpy/tabpy_tools/custom_query_object.py @@ -6,8 +6,8 @@ class CustomQueryObject(_QueryObject): - def __init__(self, query, description=''): - '''Create a new CustomQueryObject. + def __init__(self, query, description=""): + """Create a new CustomQueryObject. Parameters ----------- @@ -20,13 +20,13 @@ def __init__(self, query, description=''): description : str The description of the custom query object - ''' + """ super().__init__(description) self.custom_query = query def query(self, *args, **kwargs): - '''Query the custom defined query method using the given input. + """Query the custom defined query method using the given input. Parameters ---------- @@ -45,30 +45,32 @@ def query(self, *args, **kwargs): See Also -------- QueryObject - ''' + """ # include the dependent files in sys path so that the query can run # correctly try: - logger.debug('Running custom query with arguments ' - f'({args}, {kwargs})...') + logger.debug( + "Running custom query with arguments " f"({args}, {kwargs})..." + ) ret = self.custom_query(*args, **kwargs) except Exception as e: logger.exception( - 'Exception hit when running custom query, error: ' - f'{str(e)}') + "Exception hit when running custom query, error: " f"{str(e)}" + ) raise - logger.debug(f'Received response {ret}') + logger.debug(f"Received response {ret}") try: return self._make_serializable(ret) except Exception as e: - logger.exception('Cannot properly serialize custom query result, ' - f'error: {str(e)}') + logger.exception( + "Cannot properly serialize custom query result, " f"error: {str(e)}" + ) raise def get_doc_string(self): - '''Get doc string from customized query''' + """Get doc string from customized query""" if self.custom_query.__doc__ is not None: return self.custom_query.__doc__ else: @@ -78,4 +80,4 @@ def get_methods(self): return [self.get_query_method()] def get_query_method(self): - return {'method': 'query'} + return {"method": "query"} diff --git a/tabpy/tabpy_tools/query_object.py b/tabpy/tabpy_tools/query_object.py index b795aa50..5ccbc109 100755 --- a/tabpy/tabpy_tools/query_object.py +++ b/tabpy/tabpy_tools/query_object.py @@ -11,17 +11,17 @@ class QueryObject(abc.ABC): - ''' + """ Derived class needs to implement the following interface: * query() -- given input, return query result * get_doc_string() -- returns documentation for the Query Object - ''' + """ - def __init__(self, description=''): + def __init__(self, description=""): self.description = description def get_dependencies(self): - '''All endpoints this endpoint depends on''' + """All endpoints this endpoint depends on""" return [] @abc.abstractmethod @@ -31,11 +31,11 @@ def query(self, input): @abc.abstractmethod def get_doc_string(self): - '''Returns documentation for the query object + """Returns documentation for the query object By default, this method returns the docstring for 'query' method Derived class may overwrite this method to dynamically create docstring - ''' + """ pass def save(self, path): @@ -48,18 +48,20 @@ def save(self, path): """ if os.path.exists(path): logger.warning( - f'Overwriting existing file "{path}" when saving query object') + f'Overwriting existing file "{path}" when saving query object' + ) rm_fn = os.remove if os.path.isfile(path) else shutil.rmtree rm_fn(path) self._save_local(path) def _save_local(self, path): - '''Save current query object to local path - ''' + """Save current query object to local path + """ try: os.makedirs(path) except OSError as e: import errno + if e.errno == errno.EEXIST and os.path.isdir(path): pass else: @@ -75,8 +77,7 @@ def load(cls, path): new_po = None new_po = cls._load_local(path) - logger.info( - f'Loaded query object "{type(new_po).__name__}" successfully') + logger.info(f'Loaded query object "{type(new_po).__name__}" successfully') return new_po @@ -88,15 +89,15 @@ def _load_local(cls, path): @classmethod def _make_serializable(cls, result): - '''Convert a result from object query to python data structure that can + """Convert a result from object query to python data structure that can easily serialize over network - ''' + """ try: json.dumps(result) except TypeError: raise TypeError( - 'Result from object query is not json serializable: ' - f'{result}') + "Result from object query is not json serializable: " f"{result}" + ) return result diff --git a/tabpy/tabpy_tools/rest.py b/tabpy/tabpy_tools/rest.py index 3f5125cc..8959fc11 100755 --- a/tabpy/tabpy_tools/rest.py +++ b/tabpy/tabpy_tools/rest.py @@ -20,17 +20,14 @@ def __init__(self, response): try: r = response.json() - self.info = r['info'] - self.message = response.json()['message'] - except (json.JSONDecodeError, - KeyError): + self.info = r["info"] + self.message = response.json()["message"] + except (json.JSONDecodeError, KeyError): self.info = None self.message = response.text def __str__(self): - return (f'({self.status_code}) ' - f'{self.message} ' - f'{self.info}') + return f"({self.status_code}) " f"{self.message} " f"{self.info}" class RequestsNetworkWrapper: @@ -58,8 +55,9 @@ def __init__(self, session=None): @staticmethod def raise_error(response): logger.error( - f'Error with server response. code={response.status_code}; ' - f'text={response.text}') + f"Error with server response. code={response.status_code}; " + f"text={response.text}" + ) raise ResponseError(response) @@ -82,18 +80,14 @@ def GET(self, url, data, timeout=None): object that is parsed from the response JSON.""" self._remove_nones(data) - logger.info(f'GET {url} with {data}') + logger.info(f"GET {url} with {data}") - response = self.session.get( - url, - params=data, - timeout=timeout, - auth=self.auth) + response = self.session.get(url, params=data, timeout=timeout, auth=self.auth) if response.status_code != 200: self.raise_error(response) - logger.info(f'response={response.text}') + logger.info(f"response={response.text}") - if response.text == '': + if response.text == "": return dict() else: return response.json() @@ -103,15 +97,14 @@ def POST(self, url, data, timeout=None): object that is parsed from the response JSON.""" data = self._encode_request(data) - logger.info(f'POST {url} with {data}') + logger.info(f"POST {url} with {data}") response = self.session.post( url, data=data, - headers={ - 'content-type': 'application/json', - }, + headers={"content-type": "application/json",}, timeout=timeout, - auth=self.auth) + auth=self.auth, + ) if response.status_code not in (200, 201): self.raise_error(response) @@ -123,43 +116,39 @@ def PUT(self, url, data, timeout=None): object that is parsed from the response JSON.""" data = self._encode_request(data) - logger.info(f'PUT {url} with {data}') + logger.info(f"PUT {url} with {data}") response = self.session.put( url, data=data, - headers={ - 'content-type': 'application/json', - }, + headers={"content-type": "application/json",}, timeout=timeout, - auth=self.auth) + auth=self.auth, + ) if response.status_code != 200: self.raise_error(response) return response.json() def DELETE(self, url, data, timeout=None): - ''' + """ Issues a DELETE request to the URL with the data specified. Returns an object that is parsed from the response JSON. - ''' + """ if data is not None: data = json.dumps(data) - logger.info(f'DELETE {url} with {data}') + logger.info(f"DELETE {url} with {data}") - response = self.session.delete( - url, - data=data, - timeout=timeout, - auth=self.auth) + response = self.session.delete(url, data=data, timeout=timeout, auth=self.auth) if response.status_code <= 499 and response.status_code >= 400: raise RuntimeError(response.text) if response.status_code not in (200, 201, 204): raise RuntimeError( - f'Error with server response code: {response.status_code}') + f"Error with server response code: {response.status_code}" + ) def set_credentials(self, username, password): """ @@ -174,7 +163,7 @@ def set_credentials(self, username, password): password : str Password in plain text. """ - logger.info(f'Setting credentials (username: {username})') + logger.info(f"Setting credentials (username: {username})") self.auth = HTTPBasicAuth(username, password) @@ -190,16 +179,14 @@ class ServiceClient: def __init__(self, endpoint, network_wrapper=None): if network_wrapper is None: - network_wrapper = RequestsNetworkWrapper( - session=requests.session()) + network_wrapper = RequestsNetworkWrapper(session=requests.session()) self.network_wrapper = network_wrapper - pattern = compile('.*(:[0-9]+)$') - if not endpoint.endswith('/') and not pattern.match(endpoint): - logger.warning( - f'endpoint {endpoint} does not end with \'/\': appending.') - endpoint = endpoint + '/' + pattern = compile(".*(:[0-9]+)$") + if not endpoint.endswith("/") and not pattern.match(endpoint): + logger.warning(f"endpoint {endpoint} does not end with '/': appending.") + endpoint = endpoint + "/" self.endpoint = endpoint @@ -220,7 +207,7 @@ def DELETE(self, url, data=None, timeout=None): self.network_wrapper.DELETE(self.endpoint + url, data, timeout) def set_credentials(self, username, password): - ''' + """ Set credentials for all the TabPy client-server communication where client is tabpy-tools and server is tabpy-server. @@ -231,15 +218,14 @@ def set_credentials(self, username, password): password : str Password in plain text. - ''' + """ self.network_wrapper.set_credentials(username, password) class RESTProperty: """A descriptor that will control the type of value stored.""" - def __init__(self, type, from_json=lambda x: x, to_json=lambda x: x, - doc=None): + def __init__(self, type, from_json=lambda x: x, to_json=lambda x: x, doc=None): self.__doc__ = doc self.type = type self.from_json = from_json @@ -250,8 +236,7 @@ def __get__(self, instance, _): try: return getattr(instance, self.name) except AttributeError: - raise AttributeError( - f'{self.name} has not been set yet.') + raise AttributeError(f"{self.name} has not been set yet.") else: return self @@ -281,11 +266,11 @@ def __init__(self, name, bases, dict): self.__rest__ = set() for base in bases: - self.__rest__.update(getattr(base, '__rest__', set())) + self.__rest__.update(getattr(base, "__rest__", set())) for k, v in dict.items(): if isinstance(v, RESTProperty): - v.__dict__['name'] = '_' + k + v.__dict__["name"] = "_" + k self.__rest__.add(k) @@ -304,6 +289,7 @@ class RESTObject(MutableMapping, metaclass=_RESTMetaclass): addition RESTProperty. """ + """ __metaclass__ = _RESTMetaclass""" def __init__(self, **kwargs): @@ -317,20 +303,14 @@ def __init__(self, **kwargs): are ignored. """ - logger.info( - f'Initializing {self.__class__.__name__} from {kwargs}') + logger.info(f"Initializing {self.__class__.__name__} from {kwargs}") for attr in self.__rest__: if attr in kwargs: setattr(self, attr, kwargs.pop(attr)) def __repr__(self): return ( - "{" + - ", ".join([ - repr(k) + ": " + repr(v) - for k, v in self.items() - ]) + - "}" + "{" + ", ".join([repr(k) + ": " + repr(v) for k, v in self.items()]) + "}" ) @classmethod @@ -365,17 +345,15 @@ def to_json(self): return result def __eq__(self, other): - return (isinstance(self, type(other)) and - all(( - getattr(self, a) == getattr(other, a) - for a in self.__rest__ - ))) + return isinstance(self, type(other)) and all( + (getattr(self, a) == getattr(other, a) for a in self.__rest__) + ) def __len__(self): - return len([a for a in self.__rest__ if hasattr(self, '_' + a)]) + return len([a for a in self.__rest__ if hasattr(self, "_" + a)]) def __iter__(self): - return iter([a for a in self.__rest__ if hasattr(self, '_' + a)]) + return iter([a for a in self.__rest__ if hasattr(self, "_" + a)]) def __getitem__(self, item): if item not in self.__rest__: @@ -394,7 +372,7 @@ def __delitem__(self, item): if item not in self.__rest__: raise KeyError(item) try: - delattr(self, '_' + item) + delattr(self, "_" + item) except AttributeError: raise KeyError(item) @@ -428,25 +406,18 @@ def enum(*values, **kwargs): """ if len(values) < 1: raise ValueError("At least one value is required.") - enum_type = kwargs.pop('type', str) + enum_type = kwargs.pop("type", str) if kwargs: - raise TypeError( - f'Unexpected parameters: {", ".join(kwargs.keys())}') + raise TypeError(f'Unexpected parameters: {", ".join(kwargs.keys())}') def __new__(cls, value): if value not in cls.values: raise ValueError( - f'{value} is an unexpected value. ' - f'Expected one of {cls.values}') + f"{value} is an unexpected value. " f"Expected one of {cls.values}" + ) return super(enum, cls).__new__(cls, value) - enum = type( - 'Enum', - (enum_type,), - { - 'values': values, - '__new__': __new__, - }) + enum = type("Enum", (enum_type,), {"values": values, "__new__": __new__,}) return enum diff --git a/tabpy/tabpy_tools/rest_client.py b/tabpy/tabpy_tools/rest_client.py index 8bb84cd5..379a3720 100755 --- a/tabpy/tabpy_tools/rest_client.py +++ b/tabpy/tabpy_tools/rest_client.py @@ -35,6 +35,7 @@ class Endpoint(RESTObject): methods : list ??? """ + name = RESTProperty(str) type = RESTProperty(str) version = RESTProperty(int) @@ -49,25 +50,24 @@ class Endpoint(RESTObject): def __new__(cls, **kwargs): """Dispatch to the appropriate class.""" - cls = { - 'alias': AliasEndpoint, - 'model': ModelEndpoint, - }[kwargs['type']] + cls = {"alias": AliasEndpoint, "model": ModelEndpoint,}[kwargs["type"]] """return object.__new__(cls, **kwargs)""" """ modified for Python 3""" return object.__new__(cls) def __eq__(self, other): - return self.name == other.name and \ - self.type == other.type and \ - self.version == other.version and \ - self.description == other.description and \ - self.dependencies == other.dependencies and \ - self.methods == other.methods and \ - self.evaluator == other.evaluator and \ - self.schema_version == other.schema_version and \ - self.schema == other.schema + return ( + self.name == other.name + and self.type == other.type + and self.version == other.version + and self.description == other.description + and self.dependencies == other.dependencies + and self.methods == other.methods + and self.evaluator == other.evaluator + and self.schema_version == other.schema_version + and self.schema == other.schema + ) class ModelEndpoint(Endpoint): @@ -88,6 +88,7 @@ class ModelEndpoint(Endpoint): required packages. """ + src_path = RESTProperty(str) required_files = RESTProperty(list) required_packages = RESTProperty(list) @@ -95,12 +96,14 @@ class ModelEndpoint(Endpoint): def __init__(self, **kwargs): super().__init__(**kwargs) - self.type = 'model' + self.type = "model" def __eq__(self, other): - return super().__eq__(other) and \ - self.required_files == other.required_files and \ - self.required_packages == other.required_packages + return ( + super().__eq__(other) + and self.required_files == other.required_files + and self.required_packages == other.required_packages + ) class AliasEndpoint(Endpoint): @@ -111,33 +114,36 @@ class AliasEndpoint(Endpoint): The endpoint that this is an alias for. """ + target = RESTProperty(str) def __init__(self, **kwargs): super().__init__(**kwargs) - self.type = 'alias' + self.type = "alias" class RESTServiceClient: """A thin client for the REST Service.""" + def __init__(self, service_client): self.service_client = service_client self.query_timeout = None def get_info(self): """Returns the /info""" - return self.service_client.GET('info') + return self.service_client.GET("info") def query(self, name, *args, **kwargs): """Performs a query. Either specify *args or **kwargs, not both. Respects query_timeout.""" if args and kwargs: raise ValueError( - 'Mixing of keyword arguments and positional arguments when ' - 'querying an endpoint is not supported.') - return self.service_client.POST('query/' + name, - data={'data': args or kwargs}, - timeout=self.query_timeout) + "Mixing of keyword arguments and positional arguments when " + "querying an endpoint is not supported." + ) + return self.service_client.POST( + "query/" + name, data={"data": args or kwargs}, timeout=self.query_timeout + ) def get_endpoint_upload_destination(self): """Returns a dict representing where endpoint data should be uploaded. @@ -152,8 +158,7 @@ def get_endpoint_upload_destination(self): Note: At this time, the response should not change over time. """ - return self.service_client.GET( - 'configurations/endpoint_upload_destination') + return self.service_client.GET("configurations/endpoint_upload_destination") def get_endpoints(self, type=None): """Returns endpoints from the management API. @@ -166,9 +171,7 @@ def get_endpoints(self, type=None): Other options are 'model' and 'alias'. """ result = {} - for name, attrs in self.service_client.GET( - 'endpoints', - {'type': type}).items(): + for name, attrs in self.service_client.GET("endpoints", {"type": type}).items(): endpoint = Endpoint.from_json(attrs) endpoint.name = name result[name] = endpoint @@ -184,8 +187,7 @@ def get_endpoint(self, endpoint_name): The name of the endpoint. """ - ((name, attrs),) = self.service_client.GET( - 'endpoints/' + endpoint_name).items() + ((name, attrs),) = self.service_client.GET("endpoints/" + endpoint_name).items() endpoint = Endpoint.from_json(attrs) endpoint.name = name return endpoint @@ -198,7 +200,7 @@ def add_endpoint(self, endpoint): endpoint : Endpoint """ - return self.service_client.POST('endpoints', endpoint.to_json()) + return self.service_client.POST("endpoints", endpoint.to_json()) def set_endpoint(self, endpoint): """Updates an endpoint through the management API. @@ -210,8 +212,7 @@ def set_endpoint(self, endpoint): The endpoint to update. """ - return self.service_client.PUT( - 'endpoints/' + endpoint.name, endpoint.to_json()) + return self.service_client.PUT("endpoints/" + endpoint.name, endpoint.to_json()) def remove_endpoint(self, endpoint_name): """Deletes an endpoint through the management API. @@ -223,7 +224,7 @@ def remove_endpoint(self, endpoint_name): The endpoint to delete. """ - self.service_client.DELETE('endpoints/'+endpoint_name) + self.service_client.DELETE("endpoints/" + endpoint_name) def get_status(self): """Returns the status of the server. @@ -233,10 +234,10 @@ def get_status(self): dict """ - return self.service_client.GET('status') + return self.service_client.GET("status") def set_credentials(self, username, password): - ''' + """ Set credentials for all the TabPy client-server communication where client is tabpy-tools and server is tabpy-server. @@ -247,5 +248,5 @@ def set_credentials(self, username, password): password : str Password in plain text. - ''' + """ self.service_client.set_credentials(username, password) diff --git a/tabpy/tabpy_tools/schema.py b/tabpy/tabpy_tools/schema.py index 6fc32556..ba36bae2 100755 --- a/tabpy/tabpy_tools/schema.py +++ b/tabpy/tabpy_tools/schema.py @@ -7,55 +7,51 @@ def _generate_schema_from_example_and_description(input, description): - ''' + """ With an example input, a schema is automatically generated that conforms to the example in json-schema.org. The description given by the users is then added to the schema. - ''' + """ s = genson.SchemaBuilder(None) s.add_object(input) input_schema = s.to_schema() if description is not None: - if 'properties' in input_schema: + if "properties" in input_schema: # Case for input = {'x':1}, input_description='not a dict' if not isinstance(description, dict): - msg = f'{input} and {description} do not match' + msg = f"{input} and {description} do not match" logger.error(msg) raise Exception(msg) for key in description: # Case for input = {'x':1}, # input_description={'x':'x value', 'y':'y value'} - if key not in input_schema['properties']: - msg = f'{key} not found in {input}' + if key not in input_schema["properties"]: + msg = f"{key} not found in {input}" logger.error(msg) raise Exception(msg) else: - input_schema['properties'][key][ - 'description'] = description[key] + input_schema["properties"][key]["description"] = description[key] else: if isinstance(description, dict): - raise Exception(f'{input} and {description} do not match') + raise Exception(f"{input} and {description} do not match") else: - input_schema['description'] = description + input_schema["description"] = description try: # This should not fail unless there are bugs with either genson or # jsonschema. jsonschema.validate(input, input_schema) except Exception as e: - logger.error(f'Internal error validating schema: {str(e)}') + logger.error(f"Internal error validating schema: {str(e)}") raise return input_schema -def generate_schema(input, - output, - input_description=None, - output_description=None): - ''' +def generate_schema(input, output, input_description=None, output_description=None): + """ Generate schema from a given sample input and output. A generated schema can be passed to a server together with a function to annotate it with information about input and output parameters, and @@ -102,11 +98,11 @@ def generate_schema(input, 'properties': {'y': {'type': 'integer', 'description': 'value of y'}, 'x': {'type': 'integer', 'description': 'value of x'}}}, 'output': {'type': 'integer', 'description': 'x times y'}} - ''' # noqa: E501 + """ # noqa: E501 input_schema = _generate_schema_from_example_and_description( - input, input_description) + input, input_description + ) output_schema = _generate_schema_from_example_and_description( - output, output_description) - return {'input': input_schema, - 'sample': input, - 'output': output_schema} + output, output_description + ) + return {"input": input_schema, "sample": input, "output": output_schema} diff --git a/tabpy/utils/user_management.py b/tabpy/utils/user_management.py index 774af11c..3896ffb3 100755 --- a/tabpy/utils/user_management.py +++ b/tabpy/utils/user_management.py @@ -1,6 +1,6 @@ -''' +""" Utility for managing user names and passwords for TabPy. -''' +""" from argparse import ArgumentParser import logging @@ -16,30 +16,27 @@ def build_cli_parser(): parser = ArgumentParser( description=__doc__, - epilog=''' + epilog=""" For more information about how to configure and use authentication for TabPy read the documentation at https://github.com/tableau/TabPy - ''', + """, argument_default=None, - add_help=True) + add_help=True, + ) + parser.add_argument("command", choices=["add", "update"], help="Command to execute") + parser.add_argument("-u", "--username", help="Username to add to passwords file") parser.add_argument( - 'command', - choices=['add', 'update'], - help='Command to execute') + "-f", "--pwdfile", help="Fully qualified path to passwords file" + ) parser.add_argument( - '-u', - '--username', - help='Username to add to passwords file') - parser.add_argument( - '-f', - '--pwdfile', - help='Fully qualified path to passwords file') - parser.add_argument( - '-p', - '--password', - help=('Password for the username. If not specified a password will ' - 'be generated')) + "-p", + "--password", + help=( + "Password for the username. If not specified a password will " + "be generated" + ), + ) return parser @@ -54,33 +51,37 @@ def generate_password(pwd_len=16): # List of characters to generate password from. # We want to avoid to use similarly looking pairs like # (O, 0), (1, l), etc. - lower_case_letters = 'abcdefghijkmnpqrstuvwxyz' - upper_case_letters = 'ABCDEFGHIJKLMPQRSTUVWXYZ' - digits = '23456789' + lower_case_letters = "abcdefghijkmnpqrstuvwxyz" + upper_case_letters = "ABCDEFGHIJKLMPQRSTUVWXYZ" + digits = "23456789" # and for punctuation we want to exclude some characters # like inverted comma which can be hard to find and/or # type # change this string if you are supporting an # international keyboard with differing keys available - punctuation = '!#$%&()*+,-./:;<=>?@[\\]^_{|}~' + punctuation = "!#$%&()*+,-./:;<=>?@[\\]^_{|}~" # we also want to try to have more letters and digits in # generated password than punctuation marks - password_chars =\ - lower_case_letters + lower_case_letters +\ - upper_case_letters + upper_case_letters +\ - digits + digits +\ - punctuation - pwd = ''.join(secrets.choice(password_chars) for i in range(pwd_len)) + password_chars = ( + lower_case_letters + + lower_case_letters + + upper_case_letters + + upper_case_letters + + digits + + digits + + punctuation + ) + pwd = "".join(secrets.choice(password_chars) for i in range(pwd_len)) logger.info(f'Generated password: "{pwd}"') return pwd def store_passwords_file(pwdfile, credentials): - with open(pwdfile, 'wt') as f: + with open(pwdfile, "wt") as f: for username, pwd in credentials.items(): - f.write(f'{username} {pwd}\n') + f.write(f"{username} {pwd}\n") return True @@ -89,21 +90,23 @@ def add_user(args, credentials): logger.info(f'Adding username "{username}"') if username in credentials: - logger.error(f'Can\'t add username {username} as it is already present' - ' in passwords file. Do you want to run the ' - '"update" command instead?') + logger.error( + f"Can't add username {username} as it is already present" + " in passwords file. Do you want to run the " + '"update" command instead?' + ) return False password = args.password logger.info(f'Adding username "{username}" with password "{password}"...') credentials[username] = hash_password(username, password) - if(store_passwords_file(args.pwdfile, credentials)): + if store_passwords_file(args.pwdfile, credentials): logger.info(f'Added username "{username}" with password "{password}"') else: logger.info( - f'Could not add username "{username}" , ' - f'password "{password}" to file') + f'Could not add username "{username}" , ' f'password "{password}" to file' + ) def update_user(args, credentials): @@ -111,8 +114,10 @@ def update_user(args, credentials): logger.info(f'Updating username "{username}"') if username not in credentials: - logger.error(f'Username "{username}" not found in passwords file. ' - 'Do you want to run "add" command instead?') + logger.error( + f'Username "{username}" not found in passwords file. ' + 'Do you want to run "add" command instead?' + ) return False password = args.password @@ -122,9 +127,9 @@ def update_user(args, credentials): def process_command(args, credentials): - if args.command == 'add': + if args.command == "add": return add_user(args, credentials) - elif args.command == 'update': + elif args.command == "update": return update_user(args, credentials) else: logger.error(f'Unknown command "{args.command}"') @@ -141,7 +146,7 @@ def main(): return succeeded, credentials = parse_pwd_file(args.pwdfile) - if not succeeded and args.command != 'add': + if not succeeded and args.command != "add": return if args.password is None: @@ -151,5 +156,5 @@ def main(): return -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/integration/integ_test_base.py b/tests/integration/integ_test_base.py index 5ed8aa60..3e7104c2 100755 --- a/tests/integration/integ_test_base.py +++ b/tests/integration/integ_test_base.py @@ -13,9 +13,9 @@ class IntegTestBase(unittest.TestCase): - ''' + """ Base class for integration tests. - ''' + """ def __init__(self, methodName="runTest"): super(IntegTestBase, self).__init__(methodName) @@ -23,7 +23,7 @@ def __init__(self, methodName="runTest"): self.delete_temp_folder = True def set_delete_temp_folder(self, delete_temp_folder: bool): - ''' + """ Specify if temporary folder for state, config and log files should be deleted when test is done. By default the folder is deleted. @@ -32,11 +32,11 @@ def set_delete_temp_folder(self, delete_temp_folder: bool): ---------- delete_test_folder: bool If True temp folder will be deleted. - ''' + """ self.delete_temp_folder = delete_temp_folder def _get_state_file_path(self) -> str: - ''' + """ Generates state.ini and returns absolute path to it. Overwrite this function for tests to run against not default state file. @@ -45,29 +45,30 @@ def _get_state_file_path(self) -> str: ------- str Absolute path to state file folder. - ''' - state_file = open(os.path.join(self.tmp_dir, 'state.ini'), 'w+') + """ + state_file = open(os.path.join(self.tmp_dir, "state.ini"), "w+") state_file.write( - '[Service Info]\n' - 'Name = TabPy Serve\n' - 'Description = \n' - 'Creation Time = 0\n' - 'Access-Control-Allow-Origin = \n' - 'Access-Control-Allow-Headers = \n' - 'Access-Control-Allow-Methods = \n' - '\n' - '[Query Objects Service Versions]\n' - '\n' - '[Query Objects Docstrings]\n' - '\n' - '[Meta]\n' - 'Revision Number = 1\n') + "[Service Info]\n" + "Name = TabPy Serve\n" + "Description = \n" + "Creation Time = 0\n" + "Access-Control-Allow-Origin = \n" + "Access-Control-Allow-Headers = \n" + "Access-Control-Allow-Methods = \n" + "\n" + "[Query Objects Service Versions]\n" + "\n" + "[Query Objects Docstrings]\n" + "\n" + "[Meta]\n" + "Revision Number = 1\n" + ) state_file.close() return self.tmp_dir def _get_port(self) -> str: - ''' + """ Returns port TabPy should run on. Default implementation returns '9004'. @@ -75,11 +76,11 @@ def _get_port(self) -> str: ------- str Port number. - ''' - return '9004' + """ + return "9004" def _get_pwd_file(self) -> str: - ''' + """ Returns absolute or relative path to password file. Overwrite to create and/or specify your own file. Default implementation returns None which means @@ -91,11 +92,11 @@ def _get_pwd_file(self) -> str: Absolute or relative path to password file. If None TABPY_PWD_FILE setting won't be added to config. - ''' + """ return None def _get_transfer_protocol(self) -> str: - ''' + """ Returns transfer protocol for configuration file. Default implementation returns None which means TABPY_TRANSFER_PROTOCOL setting won't be added to config. @@ -106,11 +107,11 @@ def _get_transfer_protocol(self) -> str: Transfer protocol (e.g 'http' or 'https'). If None TABPY_TRANSFER_PROTOCOL setting won't be added to config. - ''' + """ return None def _get_certificate_file_name(self) -> str: - ''' + """ Returns absolute or relative certificate file name for configuration file. Default implementation returns None which means @@ -122,11 +123,11 @@ def _get_certificate_file_name(self) -> str: Absolute or relative certificate file name. If None TABPY_CERTIFICATE_FILE setting won't be added to config. - ''' + """ return None def _get_key_file_name(self) -> str: - ''' + """ Returns absolute or relative private key file name for configuration file. Default implementation returns None which means @@ -138,11 +139,11 @@ def _get_key_file_name(self) -> str: Absolute or relative private key file name. If None TABPY_KEY_FILE setting won't be added to config. - ''' + """ return None def _get_evaluate_timeout(self) -> str: - ''' + """ Returns the configured timeout for the /evaluate method. Default implementation returns None, which means that the timeout will default to 30. @@ -153,11 +154,11 @@ def _get_evaluate_timeout(self) -> str: Timeout for calling /evaluate. If None, defaults TABPY_EVALUATE_TIMEOUT setting will default to '30'. - ''' + """ return None def _get_config_file_name(self) -> str: - ''' + """ Generates config file. Overwrite this function for tests to run against not default state file. @@ -165,37 +166,37 @@ def _get_config_file_name(self) -> str: ------- str Absolute path to config file. - ''' - config_file = open(os.path.join(self.tmp_dir, 'test.conf'), 'w+') + """ + config_file = open(os.path.join(self.tmp_dir, "test.conf"), "w+") config_file.write( - '[TabPy]\n' - f'TABPY_QUERY_OBJECT_PATH = ./query_objects\n' - f'TABPY_PORT = {self._get_port()}\n' - f'TABPY_STATE_PATH = {self.tmp_dir}\n') + "[TabPy]\n" + f"TABPY_QUERY_OBJECT_PATH = ./query_objects\n" + f"TABPY_PORT = {self._get_port()}\n" + f"TABPY_STATE_PATH = {self.tmp_dir}\n" + ) pwd_file = self._get_pwd_file() if pwd_file is not None: pwd_file = os.path.abspath(pwd_file) - config_file.write(f'TABPY_PWD_FILE = {pwd_file}\n') + config_file.write(f"TABPY_PWD_FILE = {pwd_file}\n") transfer_protocol = self._get_transfer_protocol() if transfer_protocol is not None: - config_file.write( - f'TABPY_TRANSFER_PROTOCOL = {transfer_protocol}\n') + config_file.write(f"TABPY_TRANSFER_PROTOCOL = {transfer_protocol}\n") cert_file_name = self._get_certificate_file_name() if cert_file_name is not None: cert_file_name = os.path.abspath(cert_file_name) - config_file.write(f'TABPY_CERTIFICATE_FILE = {cert_file_name}\n') + config_file.write(f"TABPY_CERTIFICATE_FILE = {cert_file_name}\n") key_file_name = self._get_key_file_name() if key_file_name is not None: key_file_name = os.path.abspath(key_file_name) - config_file.write(f'TABPY_KEY_FILE = {key_file_name}\n') + config_file.write(f"TABPY_KEY_FILE = {key_file_name}\n") evaluate_timeout = self._get_evaluate_timeout() if evaluate_timeout is not None: - config_file.write(f'TABPY_EVALUATE_TIMEOUT = {evaluate_timeout}\n') + config_file.write(f"TABPY_EVALUATE_TIMEOUT = {evaluate_timeout}\n") config_file.close() @@ -204,42 +205,40 @@ def _get_config_file_name(self) -> str: def setUp(self): super(IntegTestBase, self).setUp() - prefix = 'TabPy_IntegTest_' + prefix = "TabPy_IntegTest_" self.tmp_dir = tempfile.mkdtemp(prefix=prefix) # create temporary state.ini orig_state_file_name = os.path.abspath( - self._get_state_file_path() + '/state.ini') - self.state_file_name = os.path.abspath(self.tmp_dir + '/state.ini') + self._get_state_file_path() + "/state.ini" + ) + self.state_file_name = os.path.abspath(self.tmp_dir + "/state.ini") if orig_state_file_name != self.state_file_name: shutil.copyfile(orig_state_file_name, self.state_file_name) # create config file orig_config_file_name = os.path.abspath(self._get_config_file_name()) self.config_file_name = os.path.abspath( - self.tmp_dir + '/' + - os.path.basename(orig_config_file_name)) + self.tmp_dir + "/" + os.path.basename(orig_config_file_name) + ) if orig_config_file_name != self.config_file_name: shutil.copyfile(orig_config_file_name, self.config_file_name) # Platform specific - for integration tests we want to engage # startup script - with open(self.tmp_dir + '/output.txt', 'w') as outfile: - cmd = ['tabpy', - '--config=' + self.config_file_name] + with open(self.tmp_dir + "/output.txt", "w") as outfile: + cmd = ["tabpy", "--config=" + self.config_file_name] preexec_fn = None - if platform.system() == 'Windows': - self.py = 'python' + if platform.system() == "Windows": + self.py = "python" else: - self.py = 'python3' + self.py = "python3" preexec_fn = os.setsid coverage.process_startup() self.process = subprocess.Popen( - cmd, - preexec_fn=preexec_fn, - stdout=outfile, - stderr=outfile) + cmd, preexec_fn=preexec_fn, stdout=outfile, stderr=outfile + ) # give the app some time to start up... time.sleep(5) @@ -247,9 +246,8 @@ def setUp(self): def tearDown(self): # stop TabPy if self.process is not None: - if platform.system() == 'Windows': - subprocess.call(['taskkill', '/F', '/T', '/PID', - str(self.process.pid)]) + if platform.system() == "Windows": + subprocess.call(["taskkill", "/F", "/T", "/PID", str(self.process.pid)]) else: os.killpg(os.getpgid(self.process.pid), signal.SIGTERM) self.process.kill() @@ -268,9 +266,9 @@ def tearDown(self): def _get_connection(self) -> http.client.HTTPConnection: protocol = self._get_transfer_protocol() - url = 'localhost:' + self._get_port() + url = "localhost:" + self._get_port() - if protocol is not None and protocol.lower() == 'https': + if protocol is not None and protocol.lower() == "https": connection = http.client.HTTPSConnection(url) else: connection = http.client.HTTPConnection(url) @@ -278,26 +276,28 @@ def _get_connection(self) -> http.client.HTTPConnection: return connection def _get_username(self) -> str: - return 'user1' + return "user1" def _get_password(self) -> str: - return 'P@ssw0rd' + return "P@ssw0rd" def deploy_models(self, username: str, password: str): repo_dir = os.path.abspath(os.path.dirname(tabpy.__file__)) - path = os.path.join(repo_dir, 'models', 'deploy_models.py') - with open(self.tmp_dir + '/deploy_models_output.txt', 'w') as outfile: + path = os.path.join(repo_dir, "models", "deploy_models.py") + with open(self.tmp_dir + "/deploy_models_output.txt", "w") as outfile: outfile.write( - f'--<< Running {self.py} {path} ' - f'{self._get_config_file_name()} >>--\n') - input_string = f'{username}\n{password}\n' - outfile.write(f'--<< Input = {input_string} >>--') + f"--<< Running {self.py} {path} " + f"{self._get_config_file_name()} >>--\n" + ) + input_string = f"{username}\n{password}\n" + outfile.write(f"--<< Input = {input_string} >>--") coverage.process_startup() p = subprocess.run( [self.py, path, self._get_config_file_name()], - input=input_string.encode('utf-8'), + input=input_string.encode("utf-8"), stdout=outfile, - stderr=outfile) + stderr=outfile, + ) def _get_process(self): return self.process diff --git a/tests/integration/test_auth.py b/tests/integration/test_auth.py index c730d8ba..a3d495bb 100755 --- a/tests/integration/test_auth.py +++ b/tests/integration/test_auth.py @@ -5,20 +5,19 @@ class TestAuth(integ_test_base.IntegTestBase): def setUp(self): super(TestAuth, self).setUp() - self.payload = ( - '''{ + self.payload = """{ "data": { "_arg1": [1, 2] }, "script": "return [x * 2 for x in _arg1]" - }''') + }""" def _get_pwd_file(self) -> str: - return './tests/integration/resources/pwdfile.txt' + return "./tests/integration/resources/pwdfile.txt" def test_missing_credentials_fails(self): headers = { - 'Content-Type': "application/json", - 'TabPy-Client': "Integration tests for Auth" - } + "Content-Type": "application/json", + "TabPy-Client": "Integration tests for Auth", + } conn = self._get_connection() conn.request("POST", "/evaluate", self.payload, headers) @@ -28,13 +27,11 @@ def test_missing_credentials_fails(self): def test_invalid_password(self): headers = { - 'Content-Type': "application/json", - 'TabPy-Client': "Integration tests for Auth", - 'Authorization': - 'Basic ' + - base64.b64encode('user1:wrong_password'.encode('utf-8')). - decode('utf-8') - } + "Content-Type": "application/json", + "TabPy-Client": "Integration tests for Auth", + "Authorization": "Basic " + + base64.b64encode("user1:wrong_password".encode("utf-8")).decode("utf-8"), + } conn = self._get_connection() conn.request("POST", "/evaluate", self.payload, headers) @@ -44,13 +41,11 @@ def test_invalid_password(self): def test_invalid_username(self): headers = { - 'Content-Type': "application/json", - 'TabPy-Client': "Integration tests for Auth", - 'Authorization': - 'Basic ' + - base64.b64encode('wrong_user:P@ssw0rd'.encode('utf-8')). - decode('utf-8') - } + "Content-Type": "application/json", + "TabPy-Client": "Integration tests for Auth", + "Authorization": "Basic " + + base64.b64encode("wrong_user:P@ssw0rd".encode("utf-8")).decode("utf-8"), + } conn = self._get_connection() conn.request("POST", "/evaluate", self.payload, headers) @@ -60,13 +55,11 @@ def test_invalid_username(self): def test_valid_credentials(self): headers = { - 'Content-Type': "application/json", - 'TabPy-Client': "Integration tests for Auth", - 'Authorization': - 'Basic ' + - base64.b64encode('user1:P@ssw0rd'.encode('utf-8')). - decode('utf-8') - } + "Content-Type": "application/json", + "TabPy-Client": "Integration tests for Auth", + "Authorization": "Basic " + + base64.b64encode("user1:P@ssw0rd".encode("utf-8")).decode("utf-8"), + } conn = self._get_connection() conn.request("POST", "/evaluate", self.payload, headers) diff --git a/tests/integration/test_custom_evaluate_timeout.py b/tests/integration/test_custom_evaluate_timeout.py index dafb0f3b..04bbb655 100644 --- a/tests/integration/test_custom_evaluate_timeout.py +++ b/tests/integration/test_custom_evaluate_timeout.py @@ -3,33 +3,31 @@ class TestCustomEvaluateTimeout(integ_test_base.IntegTestBase): def _get_evaluate_timeout(self) -> str: - return '5' + return "5" def test_custom_evaluate_timeout_with_script(self): - payload = ( - ''' + payload = """ { "data": { "_arg1": 1 }, "script": "import time\\nwhile True:\\n time.sleep(1)\\nreturn 1" } - ''') + """ headers = { - 'Content-Type': - "application/json", - 'TabPy-Client': - "Integration test for testing custom evaluate timeouts with " - "scripts." + "Content-Type": "application/json", + "TabPy-Client": "Integration test for testing custom evaluate timeouts with " + "scripts.", } conn = self._get_connection() - conn.request('POST', '/evaluate', payload, headers) + conn.request("POST", "/evaluate", payload, headers) res = conn.getresponse() - actual_error_message = res.read().decode('utf-8') + actual_error_message = res.read().decode("utf-8") self.assertEqual( '{"message": ' '"User defined script timed out. Timeout is set to 5.0 s.", ' '"info": {}}', - actual_error_message) + actual_error_message, + ) self.assertEqual(408, res.status) diff --git a/tests/integration/test_deploy_and_evaluate_model.py b/tests/integration/test_deploy_and_evaluate_model.py index 30274ebe..90e73c6c 100644 --- a/tests/integration/test_deploy_and_evaluate_model.py +++ b/tests/integration/test_deploy_and_evaluate_model.py @@ -5,10 +5,10 @@ class TestDeployAndEvaluateModel(integ_test_base.IntegTestBase): def _get_config_file_name(self) -> str: - return './tests/integration/resources/deploy_and_evaluate_model.conf' + return "./tests/integration/resources/deploy_and_evaluate_model.conf" def _get_port(self) -> str: - return '9008' + return "9008" def test_deploy_and_evaluate_model(self): # Uncomment the following line to preserve @@ -18,12 +18,11 @@ def test_deploy_and_evaluate_model(self): self.deploy_models(self._get_username(), self._get_password()) - payload = ( - '''{ + payload = """{ "data": { "_arg1": ["happy", "sad", "neutral"] }, "script": "return tabpy.query('Sentiment Analysis',_arg1)['response']" - }''') + }""" conn = self._get_connection() conn.request("POST", "/evaluate", payload) diff --git a/tests/integration/test_deploy_and_evaluate_model_ssl.py b/tests/integration/test_deploy_and_evaluate_model_ssl.py index bb928a75..dd5c7d8f 100755 --- a/tests/integration/test_deploy_and_evaluate_model_ssl.py +++ b/tests/integration/test_deploy_and_evaluate_model_ssl.py @@ -6,16 +6,16 @@ class TestDeployAndEvaluateModelSSL(integ_test_base.IntegTestBase): def _get_port(self): - return '9005' + return "9005" def _get_transfer_protocol(self) -> str: - return 'https' + return "https" def _get_certificate_file_name(self) -> str: - return './tests/integration/resources/2019_04_24_to_3018_08_25.crt' + return "./tests/integration/resources/2019_04_24_to_3018_08_25.crt" def _get_key_file_name(self) -> str: - return './tests/integration/resources/2019_04_24_to_3018_08_25.key' + return "./tests/integration/resources/2019_04_24_to_3018_08_25.key" def test_deploy_and_evaluate_model_ssl(self): # Uncomment the following line to preserve @@ -25,12 +25,11 @@ def test_deploy_and_evaluate_model_ssl(self): self.deploy_models(self._get_username(), self._get_password()) - payload = ( - '''{ + payload = """{ "data": { "_arg1": ["happy", "sad", "neutral"] }, "script": "return tabpy.query('Sentiment%20Analysis',_arg1)['response']" - }''') + }""" session = requests.Session() # Do not verify servers' cert to be signed by trusted CA @@ -38,8 +37,9 @@ def test_deploy_and_evaluate_model_ssl(self): # Do not warn about insecure request requests.packages.urllib3.disable_warnings() response = session.post( - f'{self._get_transfer_protocol()}://' - f'localhost:{self._get_port()}/evaluate', - data=payload) + f"{self._get_transfer_protocol()}://" + f"localhost:{self._get_port()}/evaluate", + data=payload, + ) self.assertEqual(200, response.status_code) diff --git a/tests/integration/test_deploy_model_ssl_off_auth_off.py b/tests/integration/test_deploy_model_ssl_off_auth_off.py index 1aa0d097..e532ae49 100644 --- a/tests/integration/test_deploy_model_ssl_off_auth_off.py +++ b/tests/integration/test_deploy_model_ssl_off_auth_off.py @@ -14,9 +14,9 @@ def test_deploy_ssl_off_auth_off(self): conn = self._get_connection() - models = ['PCA', 'Sentiment%20Analysis', "ttest", "anova"] + models = ["PCA", "Sentiment%20Analysis", "ttest", "anova"] for m in models: - conn.request("GET", f'/endpoints/{m}') + conn.request("GET", f"/endpoints/{m}") m_request = conn.getresponse() self.assertEqual(200, m_request.status) m_request.read() diff --git a/tests/integration/test_deploy_model_ssl_off_auth_on.py b/tests/integration/test_deploy_model_ssl_off_auth_on.py index 7092b4d8..00040934 100644 --- a/tests/integration/test_deploy_model_ssl_off_auth_on.py +++ b/tests/integration/test_deploy_model_ssl_off_auth_on.py @@ -6,25 +6,23 @@ class TestDeployModelSSLOffAuthOn(integ_test_base.IntegTestBase): def _get_pwd_file(self) -> str: - return './tests/integration/resources/pwdfile.txt' + return "./tests/integration/resources/pwdfile.txt" def test_deploy_ssl_off_auth_on(self): self.deploy_models(self._get_username(), self._get_password()) headers = { - 'Content-Type': "application/json", - 'TabPy-Client': "Integration test for deploying models with auth", - 'Authorization': - 'Basic ' + - base64.b64encode('user1:P@ssw0rd'. - encode('utf-8')).decode('utf-8') + "Content-Type": "application/json", + "TabPy-Client": "Integration test for deploying models with auth", + "Authorization": "Basic " + + base64.b64encode("user1:P@ssw0rd".encode("utf-8")).decode("utf-8"), } conn = self._get_connection() - models = ['PCA', 'Sentiment%20Analysis', "ttest", "anova"] + models = ["PCA", "Sentiment%20Analysis", "ttest", "anova"] for m in models: - conn.request("GET", f'/endpoints/{m}', headers=headers) + conn.request("GET", f"/endpoints/{m}", headers=headers) m_request = conn.getresponse() self.assertEqual(200, m_request.status) m_request.read() diff --git a/tests/integration/test_deploy_model_ssl_on_auth_off.py b/tests/integration/test_deploy_model_ssl_on_auth_off.py index 584ce648..87596f20 100644 --- a/tests/integration/test_deploy_model_ssl_on_auth_off.py +++ b/tests/integration/test_deploy_model_ssl_on_auth_off.py @@ -6,13 +6,13 @@ class TestDeployModelSSLOnAuthOff(integ_test_base.IntegTestBase): def _get_transfer_protocol(self) -> str: - return 'https' + return "https" def _get_certificate_file_name(self) -> str: - return './tests/integration/resources/2019_04_24_to_3018_08_25.crt' + return "./tests/integration/resources/2019_04_24_to_3018_08_25.crt" def _get_key_file_name(self) -> str: - return './tests/integration/resources/2019_04_24_to_3018_08_25.key' + return "./tests/integration/resources/2019_04_24_to_3018_08_25.key" def test_deploy_ssl_on_auth_off(self): self.deploy_models(self._get_username(), self._get_password()) @@ -23,8 +23,10 @@ def test_deploy_ssl_on_auth_off(self): # Do not warn about insecure request requests.packages.urllib3.disable_warnings() - models = ['PCA', 'Sentiment%20Analysis', "ttest", "anova"] + models = ["PCA", "Sentiment%20Analysis", "ttest", "anova"] for m in models: - m_response = session.get(url=f'{self._get_transfer_protocol()}://' - f'localhost:9004/endpoints/{m}') + m_response = session.get( + url=f"{self._get_transfer_protocol()}://" + f"localhost:9004/endpoints/{m}" + ) self.assertEqual(200, m_response.status_code) diff --git a/tests/integration/test_deploy_model_ssl_on_auth_on.py b/tests/integration/test_deploy_model_ssl_on_auth_on.py index 36739252..19e17730 100644 --- a/tests/integration/test_deploy_model_ssl_on_auth_on.py +++ b/tests/integration/test_deploy_model_ssl_on_auth_on.py @@ -6,16 +6,16 @@ class TestDeployModelSSLOnAuthOn(integ_test_base.IntegTestBase): def _get_transfer_protocol(self) -> str: - return 'https' + return "https" def _get_certificate_file_name(self) -> str: - return './tests/integration/resources/2019_04_24_to_3018_08_25.crt' + return "./tests/integration/resources/2019_04_24_to_3018_08_25.crt" def _get_key_file_name(self) -> str: - return './tests/integration/resources/2019_04_24_to_3018_08_25.key' + return "./tests/integration/resources/2019_04_24_to_3018_08_25.key" def _get_pwd_file(self) -> str: - return './tests/integration/resources/pwdfile.txt' + return "./tests/integration/resources/pwdfile.txt" def test_deploy_ssl_on_auth_on(self): # Uncomment the following line to preserve @@ -26,10 +26,10 @@ def test_deploy_ssl_on_auth_on(self): self.deploy_models(self._get_username(), self._get_password()) headers = { - 'Content-Type': "application/json", - 'TabPy-Client': "Integration test for deploying models with auth", - 'Authorization': 'Basic ' + - base64.b64encode('user1:P@ssw0rd'.encode('utf-8')).decode('utf-8') + "Content-Type": "application/json", + "TabPy-Client": "Integration test for deploying models with auth", + "Authorization": "Basic " + + base64.b64encode("user1:P@ssw0rd".encode("utf-8")).decode("utf-8"), } session = requests.Session() @@ -38,9 +38,11 @@ def test_deploy_ssl_on_auth_on(self): # Do not warn about insecure request requests.packages.urllib3.disable_warnings() - models = ['PCA', 'Sentiment%20Analysis', "ttest", "anova"] + models = ["PCA", "Sentiment%20Analysis", "ttest", "anova"] for m in models: - m_response = session.get(url=f'{self._get_transfer_protocol()}://' - f'localhost:9004/endpoints/{m}', - headers=headers) + m_response = session.get( + url=f"{self._get_transfer_protocol()}://" + f"localhost:9004/endpoints/{m}", + headers=headers, + ) self.assertEqual(200, m_response.status_code) diff --git a/tests/integration/test_url.py b/tests/integration/test_url.py index 74785d88..070296cb 100755 --- a/tests/integration/test_url.py +++ b/tests/integration/test_url.py @@ -1,6 +1,6 @@ -''' +""" All other misc. URL-related integration tests. -''' +""" import integ_test_base diff --git a/tests/integration/test_url_ssl.py b/tests/integration/test_url_ssl.py index 0f300b46..4ded4d5c 100755 --- a/tests/integration/test_url_ssl.py +++ b/tests/integration/test_url_ssl.py @@ -1,7 +1,7 @@ -''' +""" All other misc. URL-related integration tests for when SSL is turned on for TabPy. -''' +""" import integ_test_base import requests @@ -9,16 +9,16 @@ class TestURL_SSL(integ_test_base.IntegTestBase): def _get_port(self): - return '9005' + return "9005" def _get_transfer_protocol(self) -> str: - return 'https' + return "https" def _get_certificate_file_name(self) -> str: - return './tests/integration/resources/2019_04_24_to_3018_08_25.crt' + return "./tests/integration/resources/2019_04_24_to_3018_08_25.crt" def _get_key_file_name(self) -> str: - return './tests/integration/resources/2019_04_24_to_3018_08_25.key' + return "./tests/integration/resources/2019_04_24_to_3018_08_25.key" def test_notexistent_url(self): session = requests.Session() @@ -26,7 +26,6 @@ def test_notexistent_url(self): session.verify = False # Do not warn about insecure request requests.packages.urllib3.disable_warnings() - response = session.get( - url=f'https://localhost:{self._get_port()}/unicorn') + response = session.get(url=f"https://localhost:{self._get_port()}/unicorn") self.assertEqual(404, response.status_code) diff --git a/tests/unit/server_tests/test_config.py b/tests/unit/server_tests/test_config.py index d665657b..67be3c71 100644 --- a/tests/unit/server_tests/test_config.py +++ b/tests/unit/server_tests/test_config.py @@ -11,34 +11,44 @@ class TestConfigEnvironmentCalls(unittest.TestCase): def test_config_file_does_not_exist(self): - app = TabPyApp('/folder_does_not_exit/file_does_not_exist.conf') - - self.assertEqual(app.settings['port'], 9004) - self.assertEqual(app.settings['server_version'], - open('tabpy/VERSION').read().strip()) - self.assertEqual(app.settings['transfer_protocol'], 'http') - self.assertTrue('certificate_file' not in app.settings) - self.assertTrue('key_file' not in app.settings) - self.assertEqual(app.settings['log_request_context'], False) - self.assertEqual(app.settings['evaluate_timeout'], 30) - - @patch('tabpy.tabpy_server.app.app.TabPyApp._parse_cli_arguments', - return_value=Namespace(config=None)) - @patch('tabpy.tabpy_server.app.app.TabPyState') - @patch('tabpy.tabpy_server.app.app._get_state_from_file') - @patch('tabpy.tabpy_server.app.app.PythonServiceHandler') - @patch('tabpy.tabpy_server.app.app.os.path.exists', return_value=True) - @patch('tabpy.tabpy_server.app.app.os') - def test_no_config_file(self, mock_os, - mock_path_exists, mock_psws, - mock_management_util, mock_tabpy_state, - mock_parse_arguments): + app = TabPyApp("/folder_does_not_exit/file_does_not_exist.conf") + + self.assertEqual(app.settings["port"], 9004) + self.assertEqual( + app.settings["server_version"], open("tabpy/VERSION").read().strip() + ) + self.assertEqual(app.settings["transfer_protocol"], "http") + self.assertTrue("certificate_file" not in app.settings) + self.assertTrue("key_file" not in app.settings) + self.assertEqual(app.settings["log_request_context"], False) + self.assertEqual(app.settings["evaluate_timeout"], 30) + + @patch( + "tabpy.tabpy_server.app.app.TabPyApp._parse_cli_arguments", + return_value=Namespace(config=None), + ) + @patch("tabpy.tabpy_server.app.app.TabPyState") + @patch("tabpy.tabpy_server.app.app._get_state_from_file") + @patch("tabpy.tabpy_server.app.app.PythonServiceHandler") + @patch("tabpy.tabpy_server.app.app.os.path.exists", return_value=True) + @patch("tabpy.tabpy_server.app.app.os") + def test_no_config_file( + self, + mock_os, + mock_path_exists, + mock_psws, + mock_management_util, + mock_tabpy_state, + mock_parse_arguments, + ): pkg_path = os.path.dirname(tabpy.__file__) - obj_path = os.path.join(pkg_path, 'tmp', 'query_objects') - state_path = os.path.join(pkg_path, 'tabpy_server') + obj_path = os.path.join(pkg_path, "tmp", "query_objects") + state_path = os.path.join(pkg_path, "tabpy_server") mock_os.environ = { - 'TABPY_PORT': '9004', 'TABPY_QUERY_OBJECT_PATH': obj_path, - 'TABPY_STATE_PATH': state_path} + "TABPY_PORT": "9004", + "TABPY_QUERY_OBJECT_PATH": obj_path, + "TABPY_STATE_PATH": state_path, + } TabPyApp(None) @@ -48,18 +58,24 @@ def test_no_config_file(self, mock_os, self.assertTrue(len(mock_management_util.mock_calls) > 0) mock_os.makedirs.assert_not_called() - @patch('tabpy.tabpy_server.app.app.TabPyApp._parse_cli_arguments', - return_value=Namespace(config=None)) - @patch('tabpy.tabpy_server.app.app.TabPyState') - @patch('tabpy.tabpy_server.app.app._get_state_from_file') - @patch('tabpy.tabpy_server.app.app.PythonServiceHandler') - @patch('tabpy.tabpy_server.app.app.os.path.exists', return_value=False) - @patch('tabpy.tabpy_server.app.app.os') - def test_no_state_ini_file_or_state_dir(self, mock_os, - mock_path_exists, mock_psws, - mock_management_util, - mock_tabpy_state, - mock_parse_arguments): + @patch( + "tabpy.tabpy_server.app.app.TabPyApp._parse_cli_arguments", + return_value=Namespace(config=None), + ) + @patch("tabpy.tabpy_server.app.app.TabPyState") + @patch("tabpy.tabpy_server.app.app._get_state_from_file") + @patch("tabpy.tabpy_server.app.app.PythonServiceHandler") + @patch("tabpy.tabpy_server.app.app.os.path.exists", return_value=False) + @patch("tabpy.tabpy_server.app.app.os") + def test_no_state_ini_file_or_state_dir( + self, + mock_os, + mock_path_exists, + mock_psws, + mock_management_util, + mock_tabpy_state, + mock_parse_arguments, + ): TabPyApp(None) self.assertEqual(len(mock_os.makedirs.mock_calls), 1) @@ -72,83 +88,92 @@ def tearDown(self): os.remove(self.config_file.name) self.config_file = None - @patch('tabpy.tabpy_server.app.app.TabPyApp._parse_cli_arguments') - @patch('tabpy.tabpy_server.app.app.TabPyState') - @patch('tabpy.tabpy_server.app.app._get_state_from_file') - @patch('tabpy.tabpy_server.app.app.PythonServiceHandler') - @patch('tabpy.tabpy_server.app.app.os.path.exists', return_value=True) - @patch('tabpy.tabpy_server.app.app.os') - def test_config_file_present(self, mock_os, mock_path_exists, - mock_psws, mock_management_util, - mock_tabpy_state, mock_parse_arguments): + @patch("tabpy.tabpy_server.app.app.TabPyApp._parse_cli_arguments") + @patch("tabpy.tabpy_server.app.app.TabPyState") + @patch("tabpy.tabpy_server.app.app._get_state_from_file") + @patch("tabpy.tabpy_server.app.app.PythonServiceHandler") + @patch("tabpy.tabpy_server.app.app.os.path.exists", return_value=True) + @patch("tabpy.tabpy_server.app.app.os") + def test_config_file_present( + self, + mock_os, + mock_path_exists, + mock_psws, + mock_management_util, + mock_tabpy_state, + mock_parse_arguments, + ): self.assertTrue(self.config_file is not None) config_file = self.config_file - config_file.write('[TabPy]\n' - 'TABPY_QUERY_OBJECT_PATH = foo\n' - 'TABPY_STATE_PATH = bar\n'.encode()) + config_file.write( + "[TabPy]\n" + "TABPY_QUERY_OBJECT_PATH = foo\n" + "TABPY_STATE_PATH = bar\n".encode() + ) config_file.close() mock_parse_arguments.return_value = Namespace(config=config_file.name) - mock_os.path.realpath.return_value = 'bar' - mock_os.environ = {'TABPY_PORT': '1234'} + mock_os.path.realpath.return_value = "bar" + mock_os.environ = {"TABPY_PORT": "1234"} app = TabPyApp(config_file.name) - self.assertEqual(app.settings['port'], '1234') - self.assertEqual(app.settings['server_version'], - open('tabpy/VERSION').read().strip()) - self.assertEqual(app.settings['upload_dir'], 'foo') - self.assertEqual(app.settings['state_file_path'], 'bar') - self.assertEqual(app.settings['transfer_protocol'], 'http') - self.assertTrue('certificate_file' not in app.settings) - self.assertTrue('key_file' not in app.settings) - self.assertEqual(app.settings['log_request_context'], False) - self.assertEqual(app.settings['evaluate_timeout'], 30) - - @patch('tabpy.tabpy_server.app.app.os.path.exists', return_value=True) - @patch('tabpy.tabpy_server.app.app._get_state_from_file') - @patch('tabpy.tabpy_server.app.app.TabPyState') - def test_custom_evaluate_timeout_valid(self, mock_state, - mock_get_state_from_file, - mock_path_exists): + self.assertEqual(app.settings["port"], "1234") + self.assertEqual( + app.settings["server_version"], open("tabpy/VERSION").read().strip() + ) + self.assertEqual(app.settings["upload_dir"], "foo") + self.assertEqual(app.settings["state_file_path"], "bar") + self.assertEqual(app.settings["transfer_protocol"], "http") + self.assertTrue("certificate_file" not in app.settings) + self.assertTrue("key_file" not in app.settings) + self.assertEqual(app.settings["log_request_context"], False) + self.assertEqual(app.settings["evaluate_timeout"], 30) + + @patch("tabpy.tabpy_server.app.app.os.path.exists", return_value=True) + @patch("tabpy.tabpy_server.app.app._get_state_from_file") + @patch("tabpy.tabpy_server.app.app.TabPyState") + def test_custom_evaluate_timeout_valid( + self, mock_state, mock_get_state_from_file, mock_path_exists + ): self.assertTrue(self.config_file is not None) config_file = self.config_file - config_file.write('[TabPy]\n' - 'TABPY_EVALUATE_TIMEOUT = 1996'.encode()) + config_file.write("[TabPy]\n" "TABPY_EVALUATE_TIMEOUT = 1996".encode()) config_file.close() app = TabPyApp(self.config_file.name) - self.assertEqual(app.settings['evaluate_timeout'], 1996.0) - - @patch('tabpy.tabpy_server.app.app.os.path.exists', return_value=True) - @patch('tabpy.tabpy_server.app.app._get_state_from_file') - @patch('tabpy.tabpy_server.app.app.TabPyState') - def test_custom_evaluate_timeout_invalid(self, mock_state, - mock_get_state_from_file, - mock_path_exists): + self.assertEqual(app.settings["evaluate_timeout"], 1996.0) + + @patch("tabpy.tabpy_server.app.app.os.path.exists", return_value=True) + @patch("tabpy.tabpy_server.app.app._get_state_from_file") + @patch("tabpy.tabpy_server.app.app.TabPyState") + def test_custom_evaluate_timeout_invalid( + self, mock_state, mock_get_state_from_file, mock_path_exists + ): self.assertTrue(self.config_file is not None) config_file = self.config_file - config_file.write('[TabPy]\n' - 'TABPY_EVALUATE_TIMEOUT = "im not a float"'.encode()) + config_file.write( + "[TabPy]\n" 'TABPY_EVALUATE_TIMEOUT = "im not a float"'.encode() + ) config_file.close() app = TabPyApp(self.config_file.name) - self.assertEqual(app.settings['evaluate_timeout'], 30.0) - - @patch('tabpy.tabpy_server.app.app.os') - @patch('tabpy.tabpy_server.app.app.os.path.exists', return_value=True) - @patch('tabpy.tabpy_server.app.app._get_state_from_file') - @patch('tabpy.tabpy_server.app.app.TabPyState') - def test_env_variables_in_config(self, mock_state, mock_get_state, - mock_path_exists, mock_os): - mock_os.environ = {'foo': 'baz'} + self.assertEqual(app.settings["evaluate_timeout"], 30.0) + + @patch("tabpy.tabpy_server.app.app.os") + @patch("tabpy.tabpy_server.app.app.os.path.exists", return_value=True) + @patch("tabpy.tabpy_server.app.app._get_state_from_file") + @patch("tabpy.tabpy_server.app.app.TabPyState") + def test_env_variables_in_config( + self, mock_state, mock_get_state, mock_path_exists, mock_os + ): + mock_os.environ = {"foo": "baz"} config_file = self.config_file - config_file.write('[TabPy]\n' - 'TABPY_PORT = %(foo)sbar'.encode()) + config_file.write("[TabPy]\n" "TABPY_PORT = %(foo)sbar".encode()) config_file.close() app = TabPyApp(self.config_file.name) - self.assertEqual(app.settings['port'], 'bazbar') + self.assertEqual(app.settings["port"], "bazbar") class TestTransferProtocolValidation(unittest.TestCase): @@ -172,116 +197,128 @@ def __init__(self, *args, **kwargs): self.fp = None def setUp(self): - self.fp = NamedTemporaryFile(mode='w+t', delete=False) + self.fp = NamedTemporaryFile(mode="w+t", delete=False) def tearDown(self): os.remove(self.fp.name) self.fp = None def test_invalid_protocol(self): - self.fp.write("[TabPy]\n" - "TABPY_TRANSFER_PROTOCOL = gopher") + self.fp.write("[TabPy]\n" "TABPY_TRANSFER_PROTOCOL = gopher") self.fp.close() - self.assertTabPyAppRaisesRuntimeError( - 'Unsupported transfer protocol: gopher') + self.assertTabPyAppRaisesRuntimeError("Unsupported transfer protocol: gopher") def test_http(self): - self.fp.write("[TabPy]\n" - "TABPY_TRANSFER_PROTOCOL = http") + self.fp.write("[TabPy]\n" "TABPY_TRANSFER_PROTOCOL = http") self.fp.close() app = TabPyApp(self.fp.name) - self.assertEqual(app.settings['transfer_protocol'], 'http') + self.assertEqual(app.settings["transfer_protocol"], "http") def test_https_without_cert_and_key(self): - self.fp.write("[TabPy]\n" - "TABPY_TRANSFER_PROTOCOL = https") + self.fp.write("[TabPy]\n" "TABPY_TRANSFER_PROTOCOL = https") self.fp.close() - self.assertTabPyAppRaisesRuntimeError('Error using HTTPS: The paramete' - 'r(s) TABPY_CERTIFICATE_FILE and' - ' TABPY_KEY_FILE must be set.') + self.assertTabPyAppRaisesRuntimeError( + "Error using HTTPS: The paramete" + "r(s) TABPY_CERTIFICATE_FILE and" + " TABPY_KEY_FILE must be set." + ) def test_https_without_cert(self): self.fp.write( - "[TabPy]\n" - "TABPY_TRANSFER_PROTOCOL = https\n" - "TABPY_KEY_FILE = foo") + "[TabPy]\n" "TABPY_TRANSFER_PROTOCOL = https\n" "TABPY_KEY_FILE = foo" + ) self.fp.close() self.assertTabPyAppRaisesRuntimeError( - 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE must ' - 'be set.') + "Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE must " "be set." + ) def test_https_without_key(self): - self.fp.write("[TabPy]\n" - "TABPY_TRANSFER_PROTOCOL = https\n" - "TABPY_CERTIFICATE_FILE = foo") + self.fp.write( + "[TabPy]\n" + "TABPY_TRANSFER_PROTOCOL = https\n" + "TABPY_CERTIFICATE_FILE = foo" + ) self.fp.close() self.assertTabPyAppRaisesRuntimeError( - 'Error using HTTPS: The parameter(s) TABPY_KEY_FILE must be set.') + "Error using HTTPS: The parameter(s) TABPY_KEY_FILE must be set." + ) - @patch('tabpy.tabpy_server.app.app.os.path') + @patch("tabpy.tabpy_server.app.app.os.path") def test_https_cert_and_key_file_not_found(self, mock_path): - self.fp.write("[TabPy]\n" - "TABPY_TRANSFER_PROTOCOL = https\n" - "TABPY_CERTIFICATE_FILE = foo\n" - "TABPY_KEY_FILE = bar") + self.fp.write( + "[TabPy]\n" + "TABPY_TRANSFER_PROTOCOL = https\n" + "TABPY_CERTIFICATE_FILE = foo\n" + "TABPY_KEY_FILE = bar" + ) self.fp.close() - mock_path.isfile.side_effect = lambda x: self.mock_isfile( - x, {self.fp.name}) + mock_path.isfile.side_effect = lambda x: self.mock_isfile(x, {self.fp.name}) self.assertTabPyAppRaisesRuntimeError( - 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE and ' - 'TABPY_KEY_FILE must point to an existing file.') + "Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE and " + "TABPY_KEY_FILE must point to an existing file." + ) - @patch('tabpy.tabpy_server.app.app.os.path') + @patch("tabpy.tabpy_server.app.app.os.path") def test_https_cert_file_not_found(self, mock_path): - self.fp.write("[TabPy]\n" - "TABPY_TRANSFER_PROTOCOL = https\n" - "TABPY_CERTIFICATE_FILE = foo\n" - "TABPY_KEY_FILE = bar") + self.fp.write( + "[TabPy]\n" + "TABPY_TRANSFER_PROTOCOL = https\n" + "TABPY_CERTIFICATE_FILE = foo\n" + "TABPY_KEY_FILE = bar" + ) self.fp.close() mock_path.isfile.side_effect = lambda x: self.mock_isfile( - x, {self.fp.name, 'bar'}) + x, {self.fp.name, "bar"} + ) self.assertTabPyAppRaisesRuntimeError( - 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE ' - 'must point to an existing file.') + "Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE " + "must point to an existing file." + ) - @patch('tabpy.tabpy_server.app.app.os.path') + @patch("tabpy.tabpy_server.app.app.os.path") def test_https_key_file_not_found(self, mock_path): - self.fp.write("[TabPy]\n" - "TABPY_TRANSFER_PROTOCOL = https\n" - "TABPY_CERTIFICATE_FILE = foo\n" - "TABPY_KEY_FILE = bar") + self.fp.write( + "[TabPy]\n" + "TABPY_TRANSFER_PROTOCOL = https\n" + "TABPY_CERTIFICATE_FILE = foo\n" + "TABPY_KEY_FILE = bar" + ) self.fp.close() mock_path.isfile.side_effect = lambda x: self.mock_isfile( - x, {self.fp.name, 'foo'}) + x, {self.fp.name, "foo"} + ) self.assertTabPyAppRaisesRuntimeError( - 'Error using HTTPS: The parameter(s) TABPY_KEY_FILE ' - 'must point to an existing file.') + "Error using HTTPS: The parameter(s) TABPY_KEY_FILE " + "must point to an existing file." + ) - @patch('tabpy.tabpy_server.app.app.os.path.isfile', return_value=True) - @patch('tabpy.tabpy_server.app.util.validate_cert') + @patch("tabpy.tabpy_server.app.app.os.path.isfile", return_value=True) + @patch("tabpy.tabpy_server.app.util.validate_cert") def test_https_success(self, mock_isfile, mock_validate_cert): - self.fp.write("[TabPy]\n" - "TABPY_TRANSFER_PROTOCOL = HtTpS\n" - "TABPY_CERTIFICATE_FILE = foo\n" - "TABPY_KEY_FILE = bar") + self.fp.write( + "[TabPy]\n" + "TABPY_TRANSFER_PROTOCOL = HtTpS\n" + "TABPY_CERTIFICATE_FILE = foo\n" + "TABPY_KEY_FILE = bar" + ) self.fp.close() app = TabPyApp(self.fp.name) - self.assertEqual(app.settings['transfer_protocol'], 'https') - self.assertEqual(app.settings['certificate_file'], 'foo') - self.assertEqual(app.settings['key_file'], 'bar') + self.assertEqual(app.settings["transfer_protocol"], "https") + self.assertEqual(app.settings["certificate_file"], "foo") + self.assertEqual(app.settings["key_file"], "bar") class TestCertificateValidation(unittest.TestCase): @@ -292,21 +329,24 @@ def assertValidateCertRaisesRuntimeError(self, expected_message, path): def __init__(self, *args, **kwargs): super(TestCertificateValidation, self).__init__(*args, **kwargs) - self.resources_path = os.path.join( - os.path.dirname(__file__), 'resources') + self.resources_path = os.path.join(os.path.dirname(__file__), "resources") def test_expired_cert(self): - path = os.path.join(self.resources_path, 'expired.crt') - message = ('Error using HTTPS: The certificate provided expired ' - 'on 2018-08-18 19:47:18.') + path = os.path.join(self.resources_path, "expired.crt") + message = ( + "Error using HTTPS: The certificate provided expired " + "on 2018-08-18 19:47:18." + ) self.assertValidateCertRaisesRuntimeError(message, path) def test_future_cert(self): - path = os.path.join(self.resources_path, 'future.crt') - message = ('Error using HTTPS: The certificate provided is not valid ' - 'until 3001-01-01 00:00:00.') + path = os.path.join(self.resources_path, "future.crt") + message = ( + "Error using HTTPS: The certificate provided is not valid " + "until 3001-01-01 00:00:00." + ) self.assertValidateCertRaisesRuntimeError(message, path) def test_valid_cert(self): - path = os.path.join(self.resources_path, 'valid.crt') + path = os.path.join(self.resources_path, "valid.crt") validate_cert(path) diff --git a/tests/unit/server_tests/test_endpoint_handler.py b/tests/unit/server_tests/test_endpoint_handler.py index f98135da..83f98628 100755 --- a/tests/unit/server_tests/test_endpoint_handler.py +++ b/tests/unit/server_tests/test_endpoint_handler.py @@ -13,46 +13,51 @@ class TestEndpointHandlerWithAuth(AsyncHTTPTestCase): @classmethod def setUpClass(cls): cls.patcher = patch( - 'tabpy.tabpy_server.app.app.TabPyApp._parse_cli_arguments', - return_value=Namespace( - config=None)) + "tabpy.tabpy_server.app.app.TabPyApp._parse_cli_arguments", + return_value=Namespace(config=None), + ) cls.patcher.start() - prefix = '__TestEndpointHandlerWithAuth_' + prefix = "__TestEndpointHandlerWithAuth_" # create password file cls.pwd_file = tempfile.NamedTemporaryFile( - mode='w+t', prefix=prefix, suffix='.txt', delete=False) - username = 'username' - password = 'password' - cls.pwd_file.write(f'{username} {hash_password(username, password)}') + mode="w+t", prefix=prefix, suffix=".txt", delete=False + ) + username = "username" + password = "password" + cls.pwd_file.write(f"{username} {hash_password(username, password)}") cls.pwd_file.close() # create state.ini dir and file cls.state_dir = tempfile.mkdtemp(prefix=prefix) - cls.state_file = open(os.path.join(cls.state_dir, 'state.ini'), 'w+') - cls.state_file.write('[Service Info]\n' - 'Name = TabPy Serve\n' - 'Description = \n' - 'Creation Time = 0\n' - 'Access-Control-Allow-Origin = \n' - 'Access-Control-Allow-Headers = \n' - 'Access-Control-Allow-Methods = \n' - '\n' - '[Query Objects Service Versions]\n' - '\n' - '[Query Objects Docstrings]\n' - '\n' - '[Meta]\n' - 'Revision Number = 1\n') + cls.state_file = open(os.path.join(cls.state_dir, "state.ini"), "w+") + cls.state_file.write( + "[Service Info]\n" + "Name = TabPy Serve\n" + "Description = \n" + "Creation Time = 0\n" + "Access-Control-Allow-Origin = \n" + "Access-Control-Allow-Headers = \n" + "Access-Control-Allow-Methods = \n" + "\n" + "[Query Objects Service Versions]\n" + "\n" + "[Query Objects Docstrings]\n" + "\n" + "[Meta]\n" + "Revision Number = 1\n" + ) cls.state_file.close() # create config file cls.config_file = tempfile.NamedTemporaryFile( - mode='w+t', prefix=prefix, suffix='.conf', delete=False) + mode="w+t", prefix=prefix, suffix=".conf", delete=False + ) cls.config_file.write( - '[TabPy]\n' - f'TABPY_PWD_FILE = {cls.pwd_file.name}\n' - f'TABPY_STATE_PATH = {cls.state_dir}') + "[TabPy]\n" + f"TABPY_PWD_FILE = {cls.pwd_file.name}\n" + f"TABPY_STATE_PATH = {cls.state_dir}" + ) cls.config_file.close() @classmethod @@ -68,41 +73,47 @@ def get_app(self): return self.app._create_tornado_web_app() def test_no_creds_required_auth_fails(self): - response = self.fetch('/endpoints/anything') + response = self.fetch("/endpoints/anything") self.assertEqual(401, response.code) def test_invalid_creds_fails(self): response = self.fetch( - '/endpoints/anything', - method='GET', + "/endpoints/anything", + method="GET", headers={ - 'Authorization': 'Basic {}'. - format( - base64.b64encode('user:wrong_password'.encode('utf-8')). - decode('utf-8')) - }) + "Authorization": "Basic {}".format( + base64.b64encode("user:wrong_password".encode("utf-8")).decode( + "utf-8" + ) + ) + }, + ) self.assertEqual(401, response.code) def test_valid_creds_pass(self): response = self.fetch( - '/endpoints/', - method='GET', + "/endpoints/", + method="GET", headers={ - 'Authorization': 'Basic {}'. - format( - base64.b64encode('username:password'.encode('utf-8')). - decode('utf-8')) - }) + "Authorization": "Basic {}".format( + base64.b64encode("username:password".encode("utf-8")).decode( + "utf-8" + ) + ) + }, + ) self.assertEqual(200, response.code) def test_valid_creds_unknown_endpoint_fails(self): response = self.fetch( - '/endpoints/unknown_endpoint', - method='GET', + "/endpoints/unknown_endpoint", + method="GET", headers={ - 'Authorization': 'Basic {}'. - format( - base64.b64encode('username:password'.encode('utf-8')). - decode('utf-8')) - }) + "Authorization": "Basic {}".format( + base64.b64encode("username:password".encode("utf-8")).decode( + "utf-8" + ) + ) + }, + ) self.assertEqual(404, response.code) diff --git a/tests/unit/server_tests/test_endpoints_handler.py b/tests/unit/server_tests/test_endpoints_handler.py index dcb422eb..61a78c92 100755 --- a/tests/unit/server_tests/test_endpoints_handler.py +++ b/tests/unit/server_tests/test_endpoints_handler.py @@ -13,46 +13,51 @@ class TestEndpointsHandlerWithAuth(AsyncHTTPTestCase): @classmethod def setUpClass(cls): cls.patcher = patch( - 'tabpy.tabpy_server.app.app.TabPyApp._parse_cli_arguments', - return_value=Namespace( - config=None)) + "tabpy.tabpy_server.app.app.TabPyApp._parse_cli_arguments", + return_value=Namespace(config=None), + ) cls.patcher.start() - prefix = '__TestEndpointsHandlerWithAuth_' + prefix = "__TestEndpointsHandlerWithAuth_" # create password file cls.pwd_file = tempfile.NamedTemporaryFile( - mode='w+t', prefix=prefix, suffix='.txt', delete=False) - username = 'username' - password = 'password' - cls.pwd_file.write(f'{username} {hash_password(username, password)}') + mode="w+t", prefix=prefix, suffix=".txt", delete=False + ) + username = "username" + password = "password" + cls.pwd_file.write(f"{username} {hash_password(username, password)}") cls.pwd_file.close() # create state.ini dir and file cls.state_dir = tempfile.mkdtemp(prefix=prefix) - cls.state_file = open(os.path.join(cls.state_dir, 'state.ini'), 'w+') - cls.state_file.write('[Service Info]\n' - 'Name = TabPy Serve\n' - 'Description = \n' - 'Creation Time = 0\n' - 'Access-Control-Allow-Origin = \n' - 'Access-Control-Allow-Headers = \n' - 'Access-Control-Allow-Methods = \n' - '\n' - '[Query Objects Service Versions]\n' - '\n' - '[Query Objects Docstrings]\n' - '\n' - '[Meta]\n' - 'Revision Number = 1\n') + cls.state_file = open(os.path.join(cls.state_dir, "state.ini"), "w+") + cls.state_file.write( + "[Service Info]\n" + "Name = TabPy Serve\n" + "Description = \n" + "Creation Time = 0\n" + "Access-Control-Allow-Origin = \n" + "Access-Control-Allow-Headers = \n" + "Access-Control-Allow-Methods = \n" + "\n" + "[Query Objects Service Versions]\n" + "\n" + "[Query Objects Docstrings]\n" + "\n" + "[Meta]\n" + "Revision Number = 1\n" + ) cls.state_file.close() # create config file cls.config_file = tempfile.NamedTemporaryFile( - mode='w+t', prefix=prefix, suffix='.conf', delete=False) + mode="w+t", prefix=prefix, suffix=".conf", delete=False + ) cls.config_file.write( - '[TabPy]\n' - f'TABPY_PWD_FILE = {cls.pwd_file.name}\n' - f'TABPY_STATE_PATH = {cls.state_dir}') + "[TabPy]\n" + f"TABPY_PWD_FILE = {cls.pwd_file.name}\n" + f"TABPY_STATE_PATH = {cls.state_dir}" + ) cls.config_file.close() @classmethod @@ -68,29 +73,33 @@ def get_app(self): return self.app._create_tornado_web_app() def test_no_creds_required_auth_fails(self): - response = self.fetch('/endpoints') + response = self.fetch("/endpoints") self.assertEqual(401, response.code) def test_invalid_creds_fails(self): response = self.fetch( - '/endpoints', - method='GET', + "/endpoints", + method="GET", headers={ - 'Authorization': 'Basic {}'. - format( - base64.b64encode('user:wrong_password'.encode('utf-8')). - decode('utf-8')) - }) + "Authorization": "Basic {}".format( + base64.b64encode("user:wrong_password".encode("utf-8")).decode( + "utf-8" + ) + ) + }, + ) self.assertEqual(401, response.code) def test_valid_creds_pass(self): response = self.fetch( - '/endpoints', - method='GET', + "/endpoints", + method="GET", headers={ - 'Authorization': 'Basic {}'. - format( - base64.b64encode('username:password'.encode('utf-8')). - decode('utf-8')) - }) + "Authorization": "Basic {}".format( + base64.b64encode("username:password".encode("utf-8")).decode( + "utf-8" + ) + ) + }, + ) self.assertEqual(200, response.code) diff --git a/tests/unit/server_tests/test_evaluation_plane_handler.py b/tests/unit/server_tests/test_evaluation_plane_handler.py index 45bf6962..16979764 100755 --- a/tests/unit/server_tests/test_evaluation_plane_handler.py +++ b/tests/unit/server_tests/test_evaluation_plane_handler.py @@ -13,66 +13,75 @@ class TestEvaluationPlainHandlerWithAuth(AsyncHTTPTestCase): @classmethod def setUpClass(cls): cls.patcher = patch( - 'tabpy.tabpy_server.app.app.TabPyApp._parse_cli_arguments', - return_value=Namespace( - config=None)) + "tabpy.tabpy_server.app.app.TabPyApp._parse_cli_arguments", + return_value=Namespace(config=None), + ) cls.patcher.start() - prefix = '__TestEvaluationPlainHandlerWithAuth_' + prefix = "__TestEvaluationPlainHandlerWithAuth_" # create password file cls.pwd_file = tempfile.NamedTemporaryFile( - mode='w+t', prefix=prefix, suffix='.txt', delete=False) - username = 'username' - password = 'password' - cls.pwd_file.write(f'{username} {hash_password(username, password)}\n') + mode="w+t", prefix=prefix, suffix=".txt", delete=False + ) + username = "username" + password = "password" + cls.pwd_file.write(f"{username} {hash_password(username, password)}\n") cls.pwd_file.close() # create state.ini dir and file cls.state_dir = tempfile.mkdtemp(prefix=prefix) - cls.state_file = open(os.path.join(cls.state_dir, 'state.ini'), 'w+') - cls.state_file.write('[Service Info]\n' - 'Name = TabPy Serve\n' - 'Description = \n' - 'Creation Time = 0\n' - 'Access-Control-Allow-Origin = \n' - 'Access-Control-Allow-Headers = \n' - 'Access-Control-Allow-Methods = \n' - '\n' - '[Query Objects Service Versions]\n' - '\n' - '[Query Objects Docstrings]\n' - '\n' - '[Meta]\n' - 'Revision Number = 1\n') + cls.state_file = open(os.path.join(cls.state_dir, "state.ini"), "w+") + cls.state_file.write( + "[Service Info]\n" + "Name = TabPy Serve\n" + "Description = \n" + "Creation Time = 0\n" + "Access-Control-Allow-Origin = \n" + "Access-Control-Allow-Headers = \n" + "Access-Control-Allow-Methods = \n" + "\n" + "[Query Objects Service Versions]\n" + "\n" + "[Query Objects Docstrings]\n" + "\n" + "[Meta]\n" + "Revision Number = 1\n" + ) cls.state_file.close() # create config file cls.config_file = tempfile.NamedTemporaryFile( - mode='w+t', prefix=prefix, suffix='.conf', delete=False) + mode="w+t", prefix=prefix, suffix=".conf", delete=False + ) cls.config_file.write( - '[TabPy]\n' - f'TABPY_PWD_FILE = {cls.pwd_file.name}\n' - f'TABPY_STATE_PATH = {cls.state_dir}') + "[TabPy]\n" + f"TABPY_PWD_FILE = {cls.pwd_file.name}\n" + f"TABPY_STATE_PATH = {cls.state_dir}" + ) cls.config_file.close() - cls.script =\ - '{"data":{"_arg1":[2,3],"_arg2":[3,-1]},'\ - '"script":"res=[]\\nfor i in range(len(_arg1)):\\n '\ + cls.script = ( + '{"data":{"_arg1":[2,3],"_arg2":[3,-1]},' + '"script":"res=[]\\nfor i in range(len(_arg1)):\\n ' 'res.append(_arg1[i] * _arg2[i])\\nreturn res"}' + ) - cls.script_not_present =\ - '{"data":{"_arg1":[2,3],"_arg2":[3,-1]},'\ - '"":"res=[]\\nfor i in range(len(_arg1)):\\n '\ + cls.script_not_present = ( + '{"data":{"_arg1":[2,3],"_arg2":[3,-1]},' + '"":"res=[]\\nfor i in range(len(_arg1)):\\n ' 'res.append(_arg1[i] * _arg2[i])\\nreturn res"}' + ) - cls.args_not_present =\ - '{"script":"res=[]\\nfor i in range(len(_arg1)):\\n '\ + cls.args_not_present = ( + '{"script":"res=[]\\nfor i in range(len(_arg1)):\\n ' 'res.append(_arg1[i] * _arg2[i])\\nreturn res"}' + ) - cls.args_not_sequential =\ - '{"data":{"_arg1":[2,3],"_arg3":[3,-1]},'\ - '"script":"res=[]\\nfor i in range(len(_arg1)):\\n '\ + cls.args_not_sequential = ( + '{"data":{"_arg1":[2,3],"_arg3":[3,-1]},' + '"script":"res=[]\\nfor i in range(len(_arg1)):\\n ' 'res.append(_arg1[i] * _arg3[i])\\nreturn res"}' + ) @classmethod def tearDownClass(cls): @@ -87,77 +96,84 @@ def get_app(self): return self.app._create_tornado_web_app() def test_no_creds_required_auth_fails(self): - response = self.fetch( - '/evaluate', - method='POST', - body=self.script) + response = self.fetch("/evaluate", method="POST", body=self.script) self.assertEqual(401, response.code) def test_invalid_creds_fails(self): response = self.fetch( - '/evaluate', - method='POST', + "/evaluate", + method="POST", body=self.script, headers={ - 'Authorization': 'Basic {}'. - format( - base64.b64encode('user:wrong_password'.encode('utf-8')). - decode('utf-8')) - }) + "Authorization": "Basic {}".format( + base64.b64encode("user:wrong_password".encode("utf-8")).decode( + "utf-8" + ) + ) + }, + ) self.assertEqual(401, response.code) def test_valid_creds_pass(self): response = self.fetch( - '/evaluate', - method='POST', + "/evaluate", + method="POST", body=self.script, headers={ - 'Authorization': 'Basic {}'. - format( - base64.b64encode('username:password'.encode('utf-8')). - decode('utf-8')) - }) + "Authorization": "Basic {}".format( + base64.b64encode("username:password".encode("utf-8")).decode( + "utf-8" + ) + ) + }, + ) self.assertEqual(200, response.code) def test_null_request(self): - response = self.fetch('') + response = self.fetch("") self.assertEqual(404, response.code) def test_script_not_present(self): response = self.fetch( - '/evaluate', - method='POST', + "/evaluate", + method="POST", body=self.script_not_present, headers={ - 'Authorization': 'Basic {}'. - format( - base64.b64encode('username:password'.encode('utf-8')). - decode('utf-8')) - }) + "Authorization": "Basic {}".format( + base64.b64encode("username:password".encode("utf-8")).decode( + "utf-8" + ) + ) + }, + ) self.assertEqual(400, response.code) def test_arguments_not_present(self): response = self.fetch( - '/evaluate', - method='POST', + "/evaluate", + method="POST", body=self.args_not_present, headers={ - 'Authorization': 'Basic {}'. - format( - base64.b64encode('username:password'.encode('utf-8')). - decode('utf-8')) - }) + "Authorization": "Basic {}".format( + base64.b64encode("username:password".encode("utf-8")).decode( + "utf-8" + ) + ) + }, + ) self.assertEqual(500, response.code) def test_arguments_not_sequential(self): response = self.fetch( - '/evaluate', - method='POST', + "/evaluate", + method="POST", body=self.args_not_sequential, headers={ - 'Authorization': 'Basic {}'. - format( - base64.b64encode('username:password'.encode('utf-8')). - decode('utf-8')) - }) + "Authorization": "Basic {}".format( + base64.b64encode("username:password".encode("utf-8")).decode( + "utf-8" + ) + ) + }, + ) self.assertEqual(400, response.code) diff --git a/tests/unit/server_tests/test_pwd_file.py b/tests/unit/server_tests/test_pwd_file.py index 596ab9e2..b7c64f96 100755 --- a/tests/unit/server_tests/test_pwd_file.py +++ b/tests/unit/server_tests/test_pwd_file.py @@ -7,9 +7,9 @@ class TestPasswordFile(unittest.TestCase): def setUp(self): - self.config_file = NamedTemporaryFile(mode='w', delete=False) + self.config_file = NamedTemporaryFile(mode="w", delete=False) self.config_file.close() - self.pwd_file = NamedTemporaryFile(mode='w', delete=False) + self.pwd_file = NamedTemporaryFile(mode="w", delete=False) self.pwd_file.close() def tearDown(self): @@ -19,23 +19,25 @@ def tearDown(self): self.pwd_file = None def _set_file(self, file_name, value): - with open(file_name, 'w') as f: + with open(file_name, "w") as f: f.write(value) def test_given_no_pwd_file_expect_empty_credentials_list(self): - self._set_file(self.config_file.name, - "[TabPy]\n" - "TABPY_TRANSFER_PROTOCOL = http") + self._set_file( + self.config_file.name, "[TabPy]\n" "TABPY_TRANSFER_PROTOCOL = http" + ) app = TabPyApp(self.config_file.name) self.assertDictEqual( - app.credentials, {}, - 'Expected no credentials with no password file provided') + app.credentials, + {}, + "Expected no credentials with no password file provided", + ) def test_given_empty_pwd_file_expect_app_fails(self): - self._set_file(self.config_file.name, - '[TabPy]\n' - f'TABPY_PWD_FILE = {self.pwd_file.name}') + self._set_file( + self.config_file.name, "[TabPy]\n" f"TABPY_PWD_FILE = {self.pwd_file.name}" + ) self._set_file(self.pwd_file.name, "# just a comment") @@ -43,32 +45,27 @@ def test_given_empty_pwd_file_expect_app_fails(self): TabPyApp(self.config_file.name) ex = cm.exception self.assertEqual( - f'Failed to read password file {self.pwd_file.name}', - ex.args[0]) + f"Failed to read password file {self.pwd_file.name}", ex.args[0] + ) def test_given_missing_pwd_file_expect_app_fails(self): - self._set_file(self.config_file.name, - "[TabPy]\n" - "TABPY_PWD_FILE = foo") + self._set_file(self.config_file.name, "[TabPy]\n" "TABPY_PWD_FILE = foo") with self.assertRaises(RuntimeError) as cm: TabPyApp(self.config_file.name) ex = cm.exception self.assertEqual( - f'Failed to read password file {self.pwd_file.name}', - ex.args[0]) + f"Failed to read password file {self.pwd_file.name}", ex.args[0] + ) def test_given_one_password_in_pwd_file_expect_one_credentials_entry(self): - self._set_file(self.config_file.name, - "[TabPy]\n" - f'TABPY_PWD_FILE = {self.pwd_file.name}') + self._set_file( + self.config_file.name, "[TabPy]\n" f"TABPY_PWD_FILE = {self.pwd_file.name}" + ) - login = 'user_name_123' - pwd = 'someting@something_else' - self._set_file(self.pwd_file.name, - "# passwords\n" - "\n" - f'{login} {pwd}') + login = "user_name_123" + pwd = "someting@something_else" + self._set_file(self.pwd_file.name, "# passwords\n" "\n" f"{login} {pwd}") app = TabPyApp(self.config_file.name) @@ -77,92 +74,82 @@ def test_given_one_password_in_pwd_file_expect_one_credentials_entry(self): self.assertEqual(app.credentials[login], pwd) def test_given_username_but_no_password_expect_parsing_fails(self): - self._set_file(self.config_file.name, - "[TabPy]\n" - f'TABPY_PWD_FILE = {self.pwd_file.name}') + self._set_file( + self.config_file.name, "[TabPy]\n" f"TABPY_PWD_FILE = {self.pwd_file.name}" + ) - login = 'user_name_123' - pwd = '' - self._set_file(self.pwd_file.name, - "# passwords\n" - "\n" - f'{login} {pwd}') + login = "user_name_123" + pwd = "" + self._set_file(self.pwd_file.name, "# passwords\n" "\n" f"{login} {pwd}") with self.assertRaises(RuntimeError) as cm: TabPyApp(self.config_file.name) ex = cm.exception self.assertEqual( - f'Failed to read password file {self.pwd_file.name}', - ex.args[0]) + f"Failed to read password file {self.pwd_file.name}", ex.args[0] + ) def test_given_duplicate_usernames_expect_parsing_fails(self): - self._set_file(self.config_file.name, - "[TabPy]\n" - f'TABPY_PWD_FILE = {self.pwd_file.name}') + self._set_file( + self.config_file.name, "[TabPy]\n" f"TABPY_PWD_FILE = {self.pwd_file.name}" + ) - login = 'user_name_123' - pwd = 'hashedpw' - self._set_file(self.pwd_file.name, - "# passwords\n" - "\n" - f'{login} {pwd}\n{login} {pwd}') + login = "user_name_123" + pwd = "hashedpw" + self._set_file( + self.pwd_file.name, "# passwords\n" "\n" f"{login} {pwd}\n{login} {pwd}" + ) with self.assertRaises(RuntimeError) as cm: TabPyApp(self.config_file.name) ex = cm.exception self.assertEqual( - f'Failed to read password file {self.pwd_file.name}', - ex.args[0]) + f"Failed to read password file {self.pwd_file.name}", ex.args[0] + ) def test_given_one_line_with_too_many_params_expect_app_fails(self): - self._set_file(self.config_file.name, - "[TabPy]\n" - f'TABPY_PWD_FILE = {self.pwd_file.name}') + self._set_file( + self.config_file.name, "[TabPy]\n" f"TABPY_PWD_FILE = {self.pwd_file.name}" + ) - self._set_file(self.pwd_file.name, - "# passwords\n" - "user1 pwd1\n" - "user_2 pwd#2" - "user1 pwd@3") + self._set_file( + self.pwd_file.name, + "# passwords\n" "user1 pwd1\n" "user_2 pwd#2" "user1 pwd@3", + ) with self.assertRaises(RuntimeError) as cm: TabPyApp(self.config_file.name) ex = cm.exception self.assertEqual( - f'Failed to read password file {self.pwd_file.name}', - ex.args[0]) + f"Failed to read password file {self.pwd_file.name}", ex.args[0] + ) def test_given_different_cases_in_pwd_file_expect_app_fails(self): - self._set_file(self.config_file.name, - "[TabPy]\n" - f'TABPY_PWD_FILE = {self.pwd_file.name}') + self._set_file( + self.config_file.name, "[TabPy]\n" f"TABPY_PWD_FILE = {self.pwd_file.name}" + ) - self._set_file(self.pwd_file.name, - "# passwords\n" - "user1 pwd1\n" - "user_2 pwd#2" - "UseR1 pwd@3") + self._set_file( + self.pwd_file.name, + "# passwords\n" "user1 pwd1\n" "user_2 pwd#2" "UseR1 pwd@3", + ) with self.assertRaises(RuntimeError) as cm: TabPyApp(self.config_file.name) ex = cm.exception self.assertEqual( - f'Failed to read password file {self.pwd_file.name}', - ex.args[0]) + f"Failed to read password file {self.pwd_file.name}", ex.args[0] + ) def test_given_multiple_credentials_expect_all_parsed(self): - self._set_file(self.config_file.name, - "[TabPy]\n" - f'TABPY_PWD_FILE = {self.pwd_file.name}') - creds = { - 'user_1': 'pwd_1', - 'user@2': 'pwd@2', - 'user#3': 'pwd#3' - } + self._set_file( + self.config_file.name, "[TabPy]\n" f"TABPY_PWD_FILE = {self.pwd_file.name}" + ) + creds = {"user_1": "pwd_1", "user@2": "pwd@2", "user#3": "pwd#3"} pwd_file_context = "" for login in creds: - pwd_file_context += f'{login} {creds[login]}\n' + pwd_file_context += f"{login} {creds[login]}\n" self._set_file(self.pwd_file.name, pwd_file_context) app = TabPyApp(self.config_file.name) diff --git a/tests/unit/server_tests/test_service_info_handler.py b/tests/unit/server_tests/test_service_info_handler.py index 9f61eeb5..585e47b4 100644 --- a/tests/unit/server_tests/test_service_info_handler.py +++ b/tests/unit/server_tests/test_service_info_handler.py @@ -10,12 +10,12 @@ def _create_expected_info_response(settings, tabpy_state): return { - 'description': tabpy_state.get_description(), - 'creation_time': tabpy_state.creation_time, - 'state_path': settings['state_file_path'], - 'server_version': settings[SettingsParameters.ServerVersion], - 'name': tabpy_state.name, - 'versions': settings['versions'] + "description": tabpy_state.get_description(), + "creation_time": tabpy_state.creation_time, + "state_path": settings["state_file_path"], + "server_version": settings[SettingsParameters.ServerVersion], + "name": tabpy_state.name, + "versions": settings["versions"], } @@ -23,9 +23,9 @@ class TestServiceInfoHandlerDefault(AsyncHTTPTestCase): @classmethod def setUpClass(cls): cls.patcher = patch( - 'tabpy.tabpy_server.app.app.TabPyApp._parse_cli_arguments', - return_value=Namespace( - config=None)) + "tabpy.tabpy_server.app.app.TabPyApp._parse_cli_arguments", + return_value=Namespace(config=None), + ) cls.patcher.start() @classmethod @@ -37,11 +37,12 @@ def get_app(self): return self.app._create_tornado_web_app() def test_given_vanilla_tabpy_server_expect_correct_info_response(self): - response = self.fetch('/info') + response = self.fetch("/info") self.assertEqual(response.code, 200) actual_response = json.loads(response.body) expected_response = _create_expected_info_response( - self.app.settings, self.app.tabpy_state) + self.app.settings, self.app.tabpy_state + ) self.assertDictEqual(actual_response, expected_response) @@ -49,41 +50,47 @@ def test_given_vanilla_tabpy_server_expect_correct_info_response(self): class TestServiceInfoHandlerWithAuth(AsyncHTTPTestCase): @classmethod def setUpClass(cls): - prefix = '__TestServiceInfoHandlerWithAuth_' + prefix = "__TestServiceInfoHandlerWithAuth_" # create password file cls.pwd_file = tempfile.NamedTemporaryFile( - prefix=prefix, suffix='.txt', delete=False) - cls.pwd_file.write(b'username password') + prefix=prefix, suffix=".txt", delete=False + ) + cls.pwd_file.write(b"username password") cls.pwd_file.close() # create state.ini dir and file cls.state_dir = tempfile.mkdtemp(prefix=prefix) - cls.state_file = open(os.path.join(cls.state_dir, 'state.ini'), 'w+') - cls.state_file.write('[Service Info]\n' - 'Name = TabPy Serve\n' - 'Description = \n' - 'Creation Time = 0\n' - 'Access-Control-Allow-Origin = \n' - 'Access-Control-Allow-Headers = \n' - 'Access-Control-Allow-Methods = \n' - '\n' - '[Query Objects Service Versions]\n' - '\n' - '[Query Objects Docstrings]\n' - '\n' - '[Meta]\n' - 'Revision Number = 1\n') + cls.state_file = open(os.path.join(cls.state_dir, "state.ini"), "w+") + cls.state_file.write( + "[Service Info]\n" + "Name = TabPy Serve\n" + "Description = \n" + "Creation Time = 0\n" + "Access-Control-Allow-Origin = \n" + "Access-Control-Allow-Headers = \n" + "Access-Control-Allow-Methods = \n" + "\n" + "[Query Objects Service Versions]\n" + "\n" + "[Query Objects Docstrings]\n" + "\n" + "[Meta]\n" + "Revision Number = 1\n" + ) cls.state_file.close() # create config file cls.config_file = tempfile.NamedTemporaryFile( - prefix=prefix, suffix='.conf', delete=False) + prefix=prefix, suffix=".conf", delete=False + ) cls.config_file.write( bytes( - '[TabPy]\n' - f'TABPY_PWD_FILE = {cls.pwd_file.name}\n' - f'TABPY_STATE_PATH = {cls.state_dir}', - 'utf-8')) + "[TabPy]\n" + f"TABPY_PWD_FILE = {cls.pwd_file.name}\n" + f"TABPY_STATE_PATH = {cls.state_dir}", + "utf-8", + ) + ) cls.config_file.close() @classmethod @@ -98,60 +105,57 @@ def get_app(self): return self.app._create_tornado_web_app() def test_given_tabpy_server_with_auth_expect_correct_info_response(self): - response = self.fetch('/info') + response = self.fetch("/info") self.assertEqual(response.code, 200) actual_response = json.loads(response.body) expected_response = _create_expected_info_response( - self.app.settings, self.app.tabpy_state) + self.app.settings, self.app.tabpy_state + ) self.assertDictEqual(actual_response, expected_response) - self.assertTrue('versions' in actual_response) - versions = actual_response['versions'] - self.assertTrue('v1' in versions) - v1 = versions['v1'] - self.assertTrue('features' in v1) - features = v1['features'] - self.assertDictEqual({ - 'authentication': { - 'methods': { - 'basic-auth': {} - }, - 'required': True, - } - }, features) + self.assertTrue("versions" in actual_response) + versions = actual_response["versions"] + self.assertTrue("v1" in versions) + v1 = versions["v1"] + self.assertTrue("features" in v1) + features = v1["features"] + self.assertDictEqual( + {"authentication": {"methods": {"basic-auth": {}}, "required": True,}}, + features, + ) class TestServiceInfoHandlerWithoutAuth(AsyncHTTPTestCase): @classmethod def setUpClass(cls): - prefix = '__TestServiceInfoHandlerWithoutAuth_' + prefix = "__TestServiceInfoHandlerWithoutAuth_" # create state.ini dir and file cls.state_dir = tempfile.mkdtemp(prefix=prefix) - with open(os.path.join(cls.state_dir, 'state.ini'), 'w+')\ - as cls.state_file: - cls.state_file.write('[Service Info]\n' - 'Name = TabPy Serve\n' - 'Description = \n' - 'Creation Time = 0\n' - 'Access-Control-Allow-Origin = \n' - 'Access-Control-Allow-Headers = \n' - 'Access-Control-Allow-Methods = \n' - '\n' - '[Query Objects Service Versions]\n' - '\n' - '[Query Objects Docstrings]\n' - '\n' - '[Meta]\n' - 'Revision Number = 1\n') + with open(os.path.join(cls.state_dir, "state.ini"), "w+") as cls.state_file: + cls.state_file.write( + "[Service Info]\n" + "Name = TabPy Serve\n" + "Description = \n" + "Creation Time = 0\n" + "Access-Control-Allow-Origin = \n" + "Access-Control-Allow-Headers = \n" + "Access-Control-Allow-Methods = \n" + "\n" + "[Query Objects Service Versions]\n" + "\n" + "[Query Objects Docstrings]\n" + "\n" + "[Meta]\n" + "Revision Number = 1\n" + ) cls.state_file.close() # create config file cls.config_file = tempfile.NamedTemporaryFile( - prefix=prefix, suffix='.conf', delete=False, mode='w+') - cls.config_file.write( - '[TabPy]\n' - f'TABPY_STATE_PATH = {cls.state_dir}') + prefix=prefix, suffix=".conf", delete=False, mode="w+" + ) + cls.config_file.write("[TabPy]\n" f"TABPY_STATE_PATH = {cls.state_dir}") cls.config_file.close() @classmethod @@ -165,17 +169,18 @@ def get_app(self): return self.app._create_tornado_web_app() def test_tabpy_server_with_no_auth_expect_correct_info_response(self): - response = self.fetch('/info') + response = self.fetch("/info") self.assertEqual(response.code, 200) actual_response = json.loads(response.body) expected_response = _create_expected_info_response( - self.app.settings, self.app.tabpy_state) + self.app.settings, self.app.tabpy_state + ) self.assertDictEqual(actual_response, expected_response) - self.assertTrue('versions' in actual_response) - versions = actual_response['versions'] - self.assertTrue('v1' in versions) - v1 = versions['v1'] - self.assertTrue('features' in v1) - features = v1['features'] + self.assertTrue("versions" in actual_response) + versions = actual_response["versions"] + self.assertTrue("v1" in versions) + v1 = versions["v1"] + self.assertTrue("features" in v1) + features = v1["features"] self.assertDictEqual({}, features) diff --git a/tests/unit/tools_tests/test_client.py b/tests/unit/tools_tests/test_client.py index 62ffccb0..f5df51ff 100644 --- a/tests/unit/tools_tests/test_client.py +++ b/tests/unit/tools_tests/test_client.py @@ -6,7 +6,6 @@ class TestClient(unittest.TestCase): - def setUp(self): self.client = Client("http://example.com/") self.client._service = Mock() # TODO: should spec this @@ -20,16 +19,14 @@ def test_init(self): self.assertEqual(client._endpoint, "http://example.com/") - client = Client( - endpoint="https://example.com/", - query_timeout=-10.0) + client = Client(endpoint="https://example.com/", query_timeout=-10.0) self.assertEqual(client._endpoint, "https://example.com/") self.assertEqual(client.query_timeout, 0.0) # valid name tests with self.assertRaises(ValueError): - Client('') + Client("") with self.assertRaises(TypeError): Client(1.0) with self.assertRaises(ValueError): @@ -61,8 +58,7 @@ def test_query(self): self.assertEqual(self.client.query("foo", a=1, b=2, c=3), "ok") - self.client._service.query.assert_called_once_with( - "foo", a=1, b=2, c=3) + self.client._service.query.assert_called_once_with("foo", a=1, b=2, c=3) def test_get_endpoints(self): self.client._service.get_endpoints.return_value = "foo" @@ -72,8 +68,9 @@ def test_get_endpoints(self): self.client._service.get_endpoints.assert_called_once_with("foo") def test_get_endpoint_upload_destination(self): - self.client._service.get_endpoint_upload_destination.return_value = \ - {"path": "foo"} + self.client._service.get_endpoint_upload_destination.return_value = { + "path": "foo" + } self.assertEqual(self.client._get_endpoint_upload_destination(), "foo") @@ -81,14 +78,15 @@ def test_set_credentials(self): username, password = "username", "password" self.client.set_credentials(username, password) - self.client._service.set_credentials.assert_called_once_with( - username, password) + self.client._service.set_credentials.assert_called_once_with(username, password) def test_check_invalid_endpoint_name(self): - endpoint_name = 'Invalid:model:@name' + endpoint_name = "Invalid:model:@name" with self.assertRaises(ValueError) as err: _check_endpoint_name(endpoint_name) - self.assertEqual(err.exception.args[0], - f'endpoint name {endpoint_name } can only contain: ' - 'a-z, A-Z, 0-9, underscore, hyphens and spaces.') + self.assertEqual( + err.exception.args[0], + f"endpoint name {endpoint_name } can only contain: " + "a-z, A-Z, 0-9, underscore, hyphens and spaces.", + ) diff --git a/tests/unit/tools_tests/test_rest.py b/tests/unit/tools_tests/test_rest.py index 530abfee..9543005c 100644 --- a/tests/unit/tools_tests/test_rest.py +++ b/tests/unit/tools_tests/test_rest.py @@ -2,13 +2,12 @@ import requests from requests.auth import HTTPBasicAuth import sys -from tabpy.tabpy_tools.rest import (RequestsNetworkWrapper, ServiceClient) +from tabpy.tabpy_tools.rest import RequestsNetworkWrapper, ServiceClient import unittest from unittest.mock import Mock class TestRequestsNetworkWrapper(unittest.TestCase): - def test_init(self): RequestsNetworkWrapper() @@ -21,7 +20,7 @@ def test_init_with_session(self): def mock_response(self, status_code): response = Mock(requests.Response()) - response.json.return_value = 'json' + response.json.return_value = "json" response.status_code = status_code return response @@ -36,94 +35,93 @@ def setUp(self): self.rnw = RequestsNetworkWrapper(session=session) def test_GET(self): - url = 'abc' - data = {'foo': 'bar'} - self.assertEqual(self.rnw.GET(url, data), 'json') + url = "abc" + data = {"foo": "bar"} + self.assertEqual(self.rnw.GET(url, data), "json") self.rnw.session.get.assert_called_once_with( - url, - params=data, - timeout=None, - auth=None) + url, params=data, timeout=None, auth=None + ) def test_GET_InvalidData(self): - url = 'abc' - data = {'cat'} + url = "abc" + data = {"cat"} with self.assertRaises(TypeError): self.rnw.session.get.return_value = self.mock_response(404) self.rnw.GET(url, data) def test_GET_InvalidURL(self): - url = '' - data = {'foo': 'bar'} + url = "" + data = {"foo": "bar"} with self.assertRaises(TypeError): self.rnw.session.get.return_value = self.mock_response(404) self.rnw.GET(url, data) def test_POST(self): - url = 'abc' - data = {'foo': 'bar'} - self.assertEqual(self.rnw.POST(url, data), 'json') + url = "abc" + data = {"foo": "bar"} + self.assertEqual(self.rnw.POST(url, data), "json") self.rnw.session.post.assert_called_once_with( - url, data=json.dumps(data), headers={ - 'content-type': 'application/json'}, + url, + data=json.dumps(data), + headers={"content-type": "application/json"}, timeout=None, - auth=None) + auth=None, + ) def test_POST_InvalidURL(self): - url = '' - data = {'foo': 'bar'} + url = "" + data = {"foo": "bar"} with self.assertRaises(TypeError): self.rnw.session.post.return_value = self.mock_response(404) self.rnw.POST(url, data) def test_POST_InvalidData(self): - url = 'url' - data = {'cat'} + url = "url" + data = {"cat"} with self.assertRaises(TypeError): self.rnw.POST(url, data) def test_PUT(self): - url = 'abc' - data = {'foo': 'bar'} - self.assertEqual(self.rnw.PUT(url, data), 'json') + url = "abc" + data = {"foo": "bar"} + self.assertEqual(self.rnw.PUT(url, data), "json") self.rnw.session.put.assert_called_once_with( url, data=json.dumps(data), - headers={'content-type': 'application/json'}, + headers={"content-type": "application/json"}, timeout=None, - auth=None) + auth=None, + ) def test_PUT_InvalidData(self): - url = 'url' - data = {'cat'} + url = "url" + data = {"cat"} with self.assertRaises(TypeError): self.rnw.PUT(url, data) def test_PUT_InvalidURL(self): - url = '' - data = {'foo:bar'} + url = "" + data = {"foo:bar"} with self.assertRaises(TypeError): self.rnw.PUT(url, data) def test_DELETE(self): - url = 'abc' - data = {'foo': 'bar'} + url = "abc" + data = {"foo": "bar"} self.assertIs(self.rnw.DELETE(url, data), None) self.rnw.session.delete.assert_called_once_with( - url, - data=json.dumps(data), - timeout=None, - auth=None) + url, data=json.dumps(data), timeout=None, auth=None + ) def test_DELETE_InvalidData(self): - url = 'abc' - data = {'cat'} + url = "abc" + data = {"cat"} with self.assertRaises(TypeError): self.rnw.DELETE(url, data) def test_DELETE_InvalidURL(self): - url = '' - data = {'foo:bar'} + url = "" + data = {"foo:bar"} with self.assertRaises(TypeError): self.rnw.DELETE(url, data) @@ -131,106 +129,107 @@ def test_set_credentials(self): expected_auth = None self.assertEqual(self.rnw.auth, expected_auth) - username, password = 'username', 'password' + username, password = "username", "password" expected_auth = HTTPBasicAuth(username, password) self.rnw.set_credentials(username, password) self.assertEqual(self.rnw.auth, expected_auth) def _test_METHOD_with_credentials( - self, - http_method_function, - http_session_method_function, - headers=None, - params=False, - data=False, - response=None): - username, password = 'username', 'password' + self, + http_method_function, + http_session_method_function, + headers=None, + params=False, + data=False, + response=None, + ): + username, password = "username", "password" self.rnw.set_credentials(username, password) - url = 'url' - _data = {'foo': 'bar'} + url = "url" + _data = {"foo": "bar"} self.assertEqual(http_method_function(url, _data), response) pargs = {url} - kwargs = {'timeout': None, 'auth': self.rnw.auth} + kwargs = {"timeout": None, "auth": self.rnw.auth} if data: - kwargs['data'] = json.dumps(_data) + kwargs["data"] = json.dumps(_data) if headers: - kwargs['headers'] = headers + kwargs["headers"] = headers if params: - kwargs['params'] = _data + kwargs["params"] = _data http_session_method_function.assert_called_once_with(*pargs, **kwargs) self.assertEqual(self.rnw.auth, HTTPBasicAuth(username, password)) def test_GET_with_credentials(self): self._test_METHOD_with_credentials( - self.rnw.GET, - self.rnw.session.get, - params=True, - response='json') + self.rnw.GET, self.rnw.session.get, params=True, response="json" + ) def test_POST_with_credentials(self): self._test_METHOD_with_credentials( self.rnw.POST, self.rnw.session.post, - headers={ - 'content-type': 'application/json' - }, + headers={"content-type": "application/json"}, data=True, - response='json') + response="json", + ) def test_PUT_with_credentials(self): self._test_METHOD_with_credentials( - self.rnw.PUT, self.rnw.session.put, data=True, headers={ - 'content-type': 'application/json'}, response='json') + self.rnw.PUT, + self.rnw.session.put, + data=True, + headers={"content-type": "application/json"}, + response="json", + ) def test_DELETE_with_credentials(self): self._test_METHOD_with_credentials( - self.rnw.DELETE, self.rnw.session.delete, data=True) + self.rnw.DELETE, self.rnw.session.delete, data=True + ) class TestServiceClient(unittest.TestCase): - def setUp(self): nw = Mock(RequestsNetworkWrapper()) - nw.GET.return_value = 'GET' - nw.POST.return_value = 'POST' - nw.PUT.return_value = 'PUT' - nw.DELETE.return_value = 'DELETE' + nw.GET.return_value = "GET" + nw.POST.return_value = "POST" + nw.PUT.return_value = "PUT" + nw.DELETE.return_value = "DELETE" - self.sc = ServiceClient('endpoint/', network_wrapper=nw) - self.scClientDoesNotEndWithSlash =\ - ServiceClient('endpoint', network_wrapper=nw) + self.sc = ServiceClient("endpoint/", network_wrapper=nw) + self.scClientDoesNotEndWithSlash = ServiceClient("endpoint", network_wrapper=nw) def test_GET(self): - self.assertEqual(self.sc.GET('test'), 'GET') - self.sc.network_wrapper.GET.assert_called_once_with('endpoint/test', - None, None) + self.assertEqual(self.sc.GET("test"), "GET") + self.sc.network_wrapper.GET.assert_called_once_with("endpoint/test", None, None) def test_POST(self): - self.assertEqual(self.sc.POST('test'), 'POST') - self.sc.network_wrapper.POST.assert_called_once_with('endpoint/test', - None, None) + self.assertEqual(self.sc.POST("test"), "POST") + self.sc.network_wrapper.POST.assert_called_once_with( + "endpoint/test", None, None + ) def test_PUT(self): - self.assertEqual(self.sc.PUT('test'), 'PUT') - self.sc.network_wrapper.PUT.assert_called_once_with('endpoint/test', - None, None) + self.assertEqual(self.sc.PUT("test"), "PUT") + self.sc.network_wrapper.PUT.assert_called_once_with("endpoint/test", None, None) def test_DELETE(self): - self.assertEqual(self.sc.DELETE('test'), None) - self.sc.network_wrapper.DELETE.assert_called_once_with('endpoint/test', - None, None) + self.assertEqual(self.sc.DELETE("test"), None) + self.sc.network_wrapper.DELETE.assert_called_once_with( + "endpoint/test", None, None + ) def test_FixEndpoint(self): - self.assertEqual(self.scClientDoesNotEndWithSlash.GET('test'), 'GET') - self.sc.network_wrapper.GET.assert_called_once_with('endpoint/test', - None, None) + self.assertEqual(self.scClientDoesNotEndWithSlash.GET("test"), "GET") + self.sc.network_wrapper.GET.assert_called_once_with("endpoint/test", None, None) def test_set_credentials(self): - username, password = 'username', 'password' + username, password = "username", "password" self.sc.set_credentials(username, password) self.sc.network_wrapper.set_credentials.assert_called_once_with( - username, password) + username, password + ) diff --git a/tests/unit/tools_tests/test_rest_object.py b/tests/unit/tools_tests/test_rest_object.py index ec96511e..c5fafa17 100644 --- a/tests/unit/tools_tests/test_rest_object.py +++ b/tests/unit/tools_tests/test_rest_object.py @@ -5,9 +5,7 @@ class TestRESTObject(unittest.TestCase): - def test_new_class(self): - class FooObject(RESTObject): f = RESTProperty(float) i = RESTProperty(int) @@ -22,32 +20,31 @@ class FooObject(RESTObject): with self.assertRaises(AttributeError): f.e - self.assertEqual(f['f'], 6.0) - self.assertEqual(f['i'], 3) - self.assertEqual(f['s'], "hello!") + self.assertEqual(f["f"], 6.0) + self.assertEqual(f["i"], 3) + self.assertEqual(f["s"], "hello!") with self.assertRaises(KeyError): - f['e'] + f["e"] with self.assertRaises(KeyError): - f['cat'] + f["cat"] with self.assertRaises(KeyError): - f['cat'] = 5 + f["cat"] = 5 self.assertEqual(len(f), 3) - self.assertEqual(set(f), set(['f', 'i', 's'])) - self.assertEqual(set(f.keys()), set(['f', 'i', 's'])) + self.assertEqual(set(f), set(["f", "i", "s"])) + self.assertEqual(set(f.keys()), set(["f", "i", "s"])) self.assertEqual(set(f.values()), set([6.0, 3, "hello!"])) - self.assertEqual(set(f.items()), set( - [('f', 6.0), ('i', 3), ('s', "hello!")])) + self.assertEqual(set(f.items()), set([("f", 6.0), ("i", 3), ("s", "hello!")])) f.e = "a" self.assertEqual(f.e, "a") - self.assertEqual(f['e'], "a") - f['e'] = 'b' + self.assertEqual(f["e"], "a") + f["e"] = "b" self.assertEqual(f.e, "b") with self.assertRaises(ValueError): - f.e = 'fubar' + f.e = "fubar" f.f = sys.float_info.max self.assertEqual(f.f, sys.float_info.max) diff --git a/tests/unit/tools_tests/test_schema.py b/tests/unit/tools_tests/test_schema.py index 0b7802d2..4101b79a 100755 --- a/tests/unit/tools_tests/test_schema.py +++ b/tests/unit/tools_tests/test_schema.py @@ -6,36 +6,30 @@ class TestSchema(unittest.TestCase): - def test_schema(self): schema = generate_schema( - input={'x': ['happy', 'sad', 'neutral']}, - input_description={'x': 'text to analyze'}, - output=[.98, -0.99, 0], - output_description='scores for input texts') + input={"x": ["happy", "sad", "neutral"]}, + input_description={"x": "text to analyze"}, + output=[0.98, -0.99, 0], + output_description="scores for input texts", + ) expected = { - 'input': { - 'type': 'object', - 'properties': { - 'x': { - 'type': 'array', - 'items': { - 'type': 'string' - }, - 'description': 'text to analyze' + "input": { + "type": "object", + "properties": { + "x": { + "type": "array", + "items": {"type": "string"}, + "description": "text to analyze", } }, - 'required': ['x'] + "required": ["x"], }, - 'sample': { - 'x': ['happy', 'sad', 'neutral'] + "sample": {"x": ["happy", "sad", "neutral"]}, + "output": { + "type": "array", + "items": {"type": "number"}, + "description": "scores for input texts", }, - 'output': { - 'type': 'array', - 'items': { - 'type': 'number' - }, - 'description': 'scores for input texts' - } } self.assertEqual(schema, expected)