Skip to content

Commit

Permalink
Add handling of branch option. (#484)
Browse files Browse the repository at this point in the history
The new `branch` option can be used instead of the old `database`
option. These are mutually exclusive and will generally produce an error
if used at the same time (except when one appears in settings that
completely override other sources of options).

The branch option can be passed in the following manner:
- as a kwarg `branch` to the client
- as `EDGEDB_BRANCH` environment variable
- as a `branch` field of the credentials
- as a `branch` query parameter in DSN

This binding will accept either `branch` or `database` option and will
normalize the result as `database` when connecting to the EdgeDB server
(for compatibility with older servers during the period of deprecation
of `database` term).
  • Loading branch information
vpetrovykh authored Mar 13, 2024
1 parent 701447d commit c666a6f
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 10 deletions.
6 changes: 6 additions & 0 deletions edgedb/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ def remove_log_listener(
def dbname(self) -> str:
return self._params.database

@property
def branch(self) -> str:
return self._params.branch

@abc.abstractmethod
def is_closed(self) -> bool:
...
Expand Down Expand Up @@ -679,6 +683,7 @@ def __init__(
password: str = None,
secret_key: str = None,
database: str = None,
branch: str = None,
tls_ca: str = None,
tls_ca_file: str = None,
tls_security: str = None,
Expand All @@ -697,6 +702,7 @@ def __init__(
"password": password,
"secret_key": secret_key,
"database": database,
"branch": branch,
"timeout": timeout,
"tls_ca": tls_ca,
"tls_ca_file": tls_ca_file,
Expand Down
133 changes: 125 additions & 8 deletions edgedb/con_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,15 @@ class ResolvedConnectConfig:
_port = None
_port_source = None

# We keep track of database and branch separately, because we want to make
# sure that all the configuration is consistent and uses one or the other
# exclusively.
_database = None
_database_source = None

_branch = None
_branch_source = None

_user = None
_user_source = None

Expand Down Expand Up @@ -226,6 +232,9 @@ def set_port(self, port, source):
def set_database(self, database, source):
self._set_param('database', database, source, _validate_database)

def set_branch(self, branch, source):
self._set_param('branch', branch, source, _validate_branch)

def set_user(self, user, source):
self._set_param('user', user, source, _validate_user)

Expand Down Expand Up @@ -268,9 +277,24 @@ def address(self):
self._port if self._port else 5656
)

# The properties actually merge database and branch, but "default" is
# different. If you need to know the underlying config use the _database
# and _branch.
@property
def database(self):
return self._database if self._database else 'edgedb'
return (
self._database if self._database else
self._branch if self._branch else
'edgedb'
)

@property
def branch(self):
return (
self._database if self._database else
self._branch if self._branch else
'__default__'
)

@property
def user(self):
Expand Down Expand Up @@ -391,6 +415,12 @@ def _validate_database(database):
return database


def _validate_branch(branch):
if branch == '':
raise ValueError(f'invalid branch name: {branch}')
return branch


def _validate_user(user):
if user == '':
raise ValueError(f'invalid user name: {user}')
Expand Down Expand Up @@ -521,6 +551,7 @@ def _parse_connect_dsn_and_args(
password,
secret_key,
database,
branch,
tls_ca,
tls_ca_file,
tls_security,
Expand Down Expand Up @@ -557,6 +588,10 @@ def _parse_connect_dsn_and_args(
(database, '"database" option')
if database is not None else None
),
branch=(
(branch, '"branch" option')
if branch is not None else None
),
user=(user, '"user" option') if user is not None else None,
password=(
(password, '"password" option')
Expand Down Expand Up @@ -604,6 +639,7 @@ def _parse_connect_dsn_and_args(
env_credentials_file = os.getenv('EDGEDB_CREDENTIALS_FILE')
env_host = os.getenv('EDGEDB_HOST')
env_database = os.getenv('EDGEDB_DATABASE')
env_branch = os.getenv('EDGEDB_BRANCH')
env_user = os.getenv('EDGEDB_USER')
env_password = os.getenv('EDGEDB_PASSWORD')
env_secret_key = os.getenv('EDGEDB_SECRET_KEY')
Expand Down Expand Up @@ -643,6 +679,10 @@ def _parse_connect_dsn_and_args(
(env_database, '"EDGEDB_DATABASE" environment variable')
if env_database is not None else None
),
branch=(
(env_branch, '"EDGEDB_BRANCH" environment variable')
if env_branch is not None else None
),
user=(
(env_user, '"EDGEDB_USER" environment variable')
if env_user is not None else None
Expand Down Expand Up @@ -818,11 +858,52 @@ def handle_dsn_part(
def strip_leading_slash(str):
return str[1:] if str.startswith('/') else str

handle_dsn_part(
'database', strip_leading_slash(database),
resolved_config._database, resolved_config.set_database,
strip_leading_slash
)
if (
'branch' in query or
'branch_env' in query or
'branch_file' in query
):
if (
'database' in query or
'database_env' in query or
'database_file' in query
):
raise ValueError(
f"invalid DSN: `database` and `branch` cannot be present "
f"at the same time"
)
if resolved_config._database is not None:
raise errors.ClientConnectionError(
f"`branch` in DSN and {resolved_config._database_source} "
f"are mutually exclusive"
)
handle_dsn_part(
'branch', strip_leading_slash(database),
resolved_config._branch, resolved_config.set_branch,
strip_leading_slash
)
else:
if resolved_config._branch is not None:
if (
'database' in query or
'database_env' in query or
'database_file' in query
):
raise errors.ClientConnectionError(
f"`database` in DSN and {resolved_config._branch_source} "
f"are mutually exclusive"
)
handle_dsn_part(
'branch', strip_leading_slash(database),
resolved_config._branch, resolved_config.set_branch,
strip_leading_slash
)
else:
handle_dsn_part(
'database', strip_leading_slash(database),
resolved_config._database, resolved_config.set_database,
strip_leading_slash
)

handle_dsn_part(
'user', user, resolved_config._user, resolved_config.set_user
Expand Down Expand Up @@ -929,6 +1010,7 @@ def _resolve_config_options(
host=None,
port=None,
database=None,
branch=None,
user=None,
password=None,
secret_key=None,
Expand All @@ -940,7 +1022,23 @@ def _resolve_config_options(
cloud_profile=None,
):
if database is not None:
if branch is not None:
raise errors.ClientConnectionError(
f"{database[1]} and {branch[1]} are mutually exclusive"
)
if resolved_config._branch is not None:
raise errors.ClientConnectionError(
f"{database[1]} and {resolved_config._branch_source} are "
f"mutually exclusive"
)
resolved_config.set_database(*database)
if branch is not None:
if resolved_config._database is not None:
raise errors.ClientConnectionError(
f"{resolved_config._database_source} and {branch[1]} are "
f"mutually exclusive"
)
resolved_config.set_branch(*branch)
if user is not None:
resolved_config.set_user(*user)
if password is not None:
Expand All @@ -950,7 +1048,8 @@ def _resolve_config_options(
if tls_ca_file is not None:
if tls_ca is not None:
raise errors.ClientConnectionError(
f"{tls_ca[1]} and {tls_ca_file[1]} are mutually exclusive")
f"{tls_ca[1]} and {tls_ca_file[1]} are mutually exclusive"
)
resolved_config.set_tls_ca_file(*tls_ca_file)
if tls_ca is not None:
resolved_config.set_tls_ca_data(*tls_ca)
Expand Down Expand Up @@ -1018,7 +1117,23 @@ def _resolve_config_options(

resolved_config.set_host(creds.get('host'), source)
resolved_config.set_port(creds.get('port'), source)
resolved_config.set_database(creds.get('database'), source)
# We know that credentials have been validated, but they might be
# inconsistent with other resolved config settings.
if 'database' in creds:
if resolved_config._branch is not None:
raise errors.ClientConnectionError(
f"`branch` in configuration and `database` "
f"in credentials are mutually exclusive"
)
resolved_config.set_database(creds.get('database'), source)

elif 'branch' in creds:
if resolved_config._database is not None:
raise errors.ClientConnectionError(
f"`database` in configuration and `branch` "
f"in credentials are mutually exclusive"
)
resolved_config.set_branch(creds.get('branch'), source)
resolved_config.set_user(creds.get('user'), source)
resolved_config.set_password(creds.get('password'), source)
resolved_config.set_tls_ca_data(creds.get('tls_ca'), source)
Expand Down Expand Up @@ -1068,6 +1183,7 @@ def parse_connect_arguments(
credentials,
credentials_file,
database,
branch,
user,
password,
secret_key,
Expand Down Expand Up @@ -1100,6 +1216,7 @@ def parse_connect_arguments(
credentials=credentials,
credentials_file=credentials_file,
database=database,
branch=branch,
user=user,
password=password,
secret_key=secret_key,
Expand Down
11 changes: 11 additions & 0 deletions edgedb/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ class RequiredCredentials(typing.TypedDict, total=True):
class Credentials(RequiredCredentials, total=False):
host: typing.Optional[str]
password: typing.Optional[str]
# Either database or branch may appear in credentials, but not both.
database: typing.Optional[str]
branch: typing.Optional[str]
tls_ca: typing.Optional[str]
tls_security: typing.Optional[str]

Expand Down Expand Up @@ -64,6 +66,15 @@ def validate_credentials(data: dict) -> Credentials:
raise ValueError("`database` must be a string")
result['database'] = database

branch = data.get('branch')
if branch is not None:
if not isinstance(branch, str):
raise ValueError("`branch` must be a string")
if database is not None:
raise ValueError(
f"`database` and `branch` cannot both be set")
result['branch'] = branch

password = data.get('password')
if password is not None:
if not isinstance(password, str):
Expand Down
2 changes: 1 addition & 1 deletion tests/shared-client-testcases
31 changes: 30 additions & 1 deletion tests/test_con_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def run_testcase(self, testcase):
host = opts.get('host')
port = opts.get('port')
database = opts.get('database')
branch = opts.get('branch')
user = opts.get('user')
password = opts.get('password')
secret_key = opts.get('secretKey')
Expand Down Expand Up @@ -233,6 +234,7 @@ def mocked_open(filepath, *args, **kwargs):
credentials=credentials,
credentials_file=credentials_file,
database=database,
branch=branch,
user=user,
password=password,
secret_key=secret_key,
Expand All @@ -250,6 +252,7 @@ def mocked_open(filepath, *args, **kwargs):
connect_config.address[0], connect_config.address[1]
],
'database': connect_config.database,
'branch': connect_config.branch,
'user': connect_config.user,
'password': connect_config.password,
'secretKey': connect_config.secret_key,
Expand Down Expand Up @@ -289,7 +292,7 @@ def test_test_connect_params_environ(self):
if key in os.environ:
del os.environ[key]

def test_test_connect_params_run_testcase(self):
def test_test_connect_params_run_testcase_01(self):
with self.environ(EDGEDB_PORT='777'):
self.run_testcase({
'env': {
Expand All @@ -301,6 +304,31 @@ def test_test_connect_params_run_testcase(self):
'result': {
'address': ['abc', 5656],
'database': 'edgedb',
'branch': '__default__',
'user': '__test__',
'password': None,
'secretKey': None,
'tlsCAData': None,
'tlsSecurity': 'strict',
'serverSettings': {},
'waitUntilAvailable': 30,
},
})

def test_test_connect_params_run_testcase_02(self):
with self.environ(EDGEDB_PORT='777'):
self.run_testcase({
'env': {
'EDGEDB_HOST': 'abc'
},
'opts': {
'user': '__test__',
'branch': 'new_branch',
},
'result': {
'address': ['abc', 5656],
'database': 'new_branch',
'branch': 'new_branch',
'user': '__test__',
'password': None,
'secretKey': None,
Expand Down Expand Up @@ -399,6 +427,7 @@ def test_project_config(self):
password=None,
secret_key=None,
database=None,
branch=None,
tls_ca=None,
tls_ca_file=None,
tls_security=None,
Expand Down

0 comments on commit c666a6f

Please sign in to comment.