Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use gunicorn by default if it's available #234

Merged
merged 3 commits into from
May 12, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 37 additions & 18 deletions optuna_dashboard/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from socketserver import ThreadingMixIn
import sys
from typing import Literal
from wsgiref.simple_server import make_server
from wsgiref.simple_server import WSGIServer

Expand All @@ -17,28 +18,18 @@


DEBUG = os.environ.get("OPTUNA_DASHBOARD_DEBUG") == "1"
SERVER_CHOICES = ["wsgiref", "gunicorn"]
SERVER_CHOICES = ["auto", "wsgiref", "gunicorn"]


class ThreadedWSGIServer(ThreadingMixIn, WSGIServer):
pass


def run_wsgiref(app: Bottle, host: str, port: int, quiet: bool) -> None:
if DEBUG:
run(
app,
host=host,
port=port,
server="wsgiref",
quiet=quiet,
reloader=DEBUG,
)
else:
print(f"Listening on http://{host}:{port}/", file=sys.stderr)
print("Hit Ctrl-C to quit.\n", file=sys.stderr)
httpd = make_server(host, port, app, server_class=ThreadedWSGIServer)
httpd.serve_forever()
print(f"Listening on http://{host}:{port}/", file=sys.stderr)
print("Hit Ctrl-C to quit.\n", file=sys.stderr)
httpd = make_server(host, port, app, server_class=ThreadedWSGIServer)
httpd.serve_forever()


def run_gunicorn(app: Bottle, host: str, port: int, quiet: bool) -> None:
Expand All @@ -59,6 +50,31 @@ def load(self) -> Bottle:
Application().run()


def run_debug_server(app: Bottle, host: str, port: int, quiet: bool) -> None:
run(
app,
host=host,
port=port,
server="wsgiref",
quiet=quiet,
reloader=DEBUG,
)


def auto_select_server(
server_arg: Literal["auto", "gunicorn", "wsgiref"]
) -> Literal["gunicorn", "wsgiref"]:
if server_arg != "auto":
return server_arg

try:
import gunicorn # NOQA

return "gunicorn"
except ImportError:
return "wsgiref"


def main() -> None:
parser = argparse.ArgumentParser(description="Real-time dashboard for Optuna.")
parser.add_argument("storage", help="DB URL (e.g. sqlite:///example.db)", type=str)
Expand All @@ -71,7 +87,7 @@ def main() -> None:
parser.add_argument(
"--server",
help="server (default: %(default)s)",
default="wsgiref",
default="auto",
choices=SERVER_CHOICES,
)
parser.add_argument("--version", "-v", action="version", version=__version__)
Expand All @@ -89,9 +105,12 @@ def main() -> None:
if DEBUG and isinstance(storage, RDBStorage):
app = register_profiler_view(app, storage)

if args.server == "wsgiref":
server = auto_select_server(args.server)
if DEBUG:
run_debug_server(app, args.host, args.port, args.quiet)
elif server == "wsgiref":
run_wsgiref(app, args.host, args.port, args.quiet)
elif args.server == "gunicorn":
elif server == "gunicorn":
run_gunicorn(app, args.host, args.port, args.quiet)
else:
raise Exception("must not reach here")
Expand Down