Skip to content

Commit

Permalink
Support EDGEDB_WAIT_UNTIL_AVAILABLE environment variable
Browse files Browse the repository at this point in the history
fixes #302
  • Loading branch information
fmoor committed Jun 27, 2022
1 parent 09bfcc8 commit c27ab36
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 34 deletions.
137 changes: 136 additions & 1 deletion edgedb/con_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,28 @@
errno.ENOENT,
})

ISO_DURATION_RE = re.compile(
r'^PT' +
r'(?:(?P<hours>\d*\.?\d*)H)?' +
r'(?:(?P<minutes>\d*\.?\d*)M)?' +
r'(?:(?P<seconds>\d*\.?\d*)S)?$'
)
HUMAN_DURATION_HOUR_RE = re.compile(
r'(?P<number>\d+|\d+\.\d+|\.\d+)\s*(?:h|hours?)(?P<tail>\s|$|\d)'
)
HUMAN_DURATION_MINTUE_RE = re.compile(
r'(?P<number>\d+|\d+\.\d+|\.\d+)\s*(?:m|minutes?)(?P<tail>\s|$|\d)'
)
HUMAN_DURATION_SECOND_RE = re.compile(
r'(?P<number>\d+|\d+\.\d+|\.\d+)\s*(?:s|seconds?)(?P<tail>\s|$|\d)'
)
HUMAN_DURATION_MS_RE = re.compile(
r'(?P<number>\d+|\d+\.\d+|\.\d+)\s*(?:ms|milliseconds?)(?P<tail>\s|$|\d)'
)
HUMAN_DURATION_US_RE = re.compile(
r'(?P<number>\d+|\d+\.\d+|\.\d+)\s*(?:us|microseconds?)(?P<tail>\s|$|\d)'
)


class ClientConfiguration(typing.NamedTuple):

Expand Down Expand Up @@ -156,6 +178,8 @@ class ResolvedConnectConfig:
_tls_security = None
_tls_security_source = None

_wait_until_available = None

server_settings = {}

def _set_param(self, param, value, source, validator=None):
Expand Down Expand Up @@ -198,7 +222,22 @@ def set_tls_security(self, security, source):
self._set_param('tls_security', security, source,
_validate_tls_security)

def set_wait_until_available(self, wait_until_available, source):
self._set_param(
'wait_until_available',
wait_until_available,
source,
_validate_wait_until_available,
)

def add_server_settings(self, server_settings):
if 'wait_until_available' in server_settings:
import traceback
import sys
print(server_settings)
traceback.print_stack(file=sys.stdout)
print()
print()
_validate_server_settings(server_settings)
self.server_settings = {**server_settings, **self.server_settings}

Expand Down Expand Up @@ -284,6 +323,10 @@ def ssl_ctx(self):

return self._ssl_ctx

@property
def wait_until_available(self):
return self._wait_until_available or 30


def _validate_host(host):
if '/' in host:
Expand Down Expand Up @@ -326,6 +369,76 @@ def _validate_user(user):
return user


def _parse_iso_duration(string: str):
match = ISO_DURATION_RE.match(string)
if match is None:
raise ValueError(f"invalid duration {string!r}")

return (
3600 * float(match.group('hours') or '0') +
60 * float(match.group('minutes') or '0') +
float(match.group('seconds') or '0')
)


def _parse_human_duration_unit(re, string):
number = None
match = re.search(string)
if match:
number = match.group('number')
tail = match.group('tail')
string = re.sub(tail, string, count=1)

return number, string


def _parse_human_duration(string: str):
hour, string = _parse_human_duration_unit(HUMAN_DURATION_HOUR_RE, string)
minute, string = _parse_human_duration_unit(
HUMAN_DURATION_MINTUE_RE, string)
second, string = _parse_human_duration_unit(
HUMAN_DURATION_SECOND_RE, string)
ms, string = _parse_human_duration_unit(HUMAN_DURATION_MS_RE, string)
us, string = _parse_human_duration_unit(HUMAN_DURATION_US_RE, string)

if string.strip() != '':
raise ValueError(f'invalid duration {string!r}')

no_value = (
hour is None and
minute is None and
second is None and
ms is None and
us is None
)
if no_value:
raise ValueError(f"invalid duration {string!r}")

return (
3600 * float(hour or '0') +
60 * float(minute or '0') +
float(second or '0') +
0.001 * float(ms or '0') +
0.000001 * float(us or '0')
)


def _parse_duration_str(string: str):
if string.startswith('PT'):
return _parse_iso_duration(string)
return _parse_human_duration(string)


def _validate_wait_until_available(wait_until_available):
if isinstance(wait_until_available, str):
return _parse_duration_str(wait_until_available)

if isinstance(wait_until_available, (int, float)):
return wait_until_available

raise ValueError(f"invalid duration {wait_until_available!r}")


def _validate_server_settings(server_settings):
if (
not isinstance(server_settings, dict) or
Expand All @@ -351,6 +464,7 @@ def _parse_connect_dsn_and_args(
tls_ca_file,
tls_security,
server_settings,
wait_until_available,
):
resolved_config = ResolvedConnectConfig()

Expand Down Expand Up @@ -404,6 +518,10 @@ def _parse_connect_dsn_and_args(
(server_settings, '"server_settings" option')
if server_settings is not None else None
),
wait_until_available=(
(wait_until_available, '"wait_until_available" option')
if wait_until_available is not None else None
)
)

if has_compound_options is False:
Expand All @@ -427,6 +545,7 @@ def _parse_connect_dsn_and_args(
env_tls_ca = os.getenv('EDGEDB_TLS_CA')
env_tls_ca_file = os.getenv('EDGEDB_TLS_CA_FILE')
env_tls_security = os.getenv('EDGEDB_CLIENT_TLS_SECURITY')
env_wait_until_available = os.getenv('EDGEDB_WAIT_UNTIL_AVAILABLE')

has_compound_options = _resolve_config_options(
resolved_config,
Expand Down Expand Up @@ -479,6 +598,12 @@ def _parse_connect_dsn_and_args(
'"EDGEDB_CLIENT_TLS_SECURITY" environment variable')
if env_tls_security is not None else None
),
wait_until_available=(
(
env_wait_until_available,
'"EDGEDB_WAIT_UNTIL_AVAILABLE" environment variable'
) if env_wait_until_available is not None else None
)
)

if not has_compound_options:
Expand Down Expand Up @@ -628,6 +753,12 @@ def strip_leading_slash(str):
resolved_config.set_tls_security
)

handle_dsn_part(
'wait_until_available', None,
resolved_config._wait_until_available,
resolved_config.set_wait_until_available
)

resolved_config.add_server_settings(query)


Expand All @@ -648,6 +779,7 @@ def _resolve_config_options(
tls_ca_file=None,
tls_security=None,
server_settings=None,
wait_until_available=None,
):
if database is not None:
resolved_config.set_database(*database)
Expand All @@ -666,6 +798,8 @@ def _resolve_config_options(
resolved_config.set_tls_security(*tls_security)
if server_settings is not None:
resolved_config.add_server_settings(server_settings[0])
if wait_until_available is not None:
resolved_config.set_wait_until_available(*wait_until_available)

compound_params = [
dsn,
Expand Down Expand Up @@ -808,12 +942,13 @@ def parse_connect_arguments(
tls_ca_file=tls_ca_file,
tls_security=tls_security,
server_settings=server_settings,
wait_until_available=wait_until_available,
)

client_config = ClientConfiguration(
connect_timeout=timeout,
command_timeout=command_timeout,
wait_until_available=wait_until_available or 0,
wait_until_available=connect_config.wait_until_available,
)

return connect_config, client_config
Expand Down
2 changes: 1 addition & 1 deletion tests/shared-client-testcases
Loading

0 comments on commit c27ab36

Please sign in to comment.