Skip to content

Commit

Permalink
Enable json serialization for secrets backend
Browse files Browse the repository at this point in the history
Previously, connections could in general only be stored in the Airflow URI format. With this change, they can also be serialized as JSON. The Airflow URI format can be very tricky to work with, and although a convenience method, Connection.get_uri, has been available for some time, using JSON is simply easier.
  • Loading branch information
dstandish committed Mar 6, 2022
1 parent 80c52a1 commit 85beb36
Show file tree
Hide file tree
Showing 12 changed files with 663 additions and 334 deletions.
41 changes: 34 additions & 7 deletions airflow/cli/cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,13 @@ def string_list_type(val):
return [x.strip() for x in val.split(',')]


def string_lower_type(val):
    """Normalize a CLI argument by trimming surrounding whitespace and lower-casing it.

    Falsy input (``None`` or an empty string) yields ``None``.
    """
    return val.strip().lower() if val else None


# Shared
ARG_DAG_ID = Arg(("dag_id",), help="The id of the dag")
ARG_TASK_ID = Arg(("task_id",), help="The id of the task")
Expand Down Expand Up @@ -701,6 +708,9 @@ def string_list_type(val):
ARG_CONN_URI = Arg(
('--conn-uri',), help='Connection URI, required to add a connection without conn_type', type=str
)
# `--conn-json`: the connection given as a JSON string, as an alternative to `--conn-uri`.
ARG_CONN_JSON = Arg(
('--conn-json',), help='Connection JSON, required to add a connection using JSON representation', type=str
)
ARG_CONN_TYPE = Arg(
('--conn-type',), help='Connection type, required to add a connection without conn_uri', type=str
)
Expand All @@ -725,7 +735,19 @@ def string_list_type(val):
type=argparse.FileType('w', encoding='UTF-8'),
)
ARG_CONN_EXPORT_FORMAT = Arg(
('--format',), help='Format of the connections data in file', type=str, choices=['json', 'yaml', 'env']
('--format',),
help='Deprecated -- use `--file-format` instead. File format to use for the export.',
type=str,
choices=['json', 'yaml', 'env'],
)
# `--file-format`: file format of the export (json, yaml or env).
ARG_CONN_EXPORT_FILE_FORMAT = Arg(
('--file-format',), help='File format for the export', type=str, choices=['json', 'yaml', 'env']
)
# `--serialization-format`: how each connection is serialized in an `env` export.
# The raw value goes through `string_lower_type` (strip + lower-case).
# NOTE(review): presumably `Arg` forwards `type`/`choices` to argparse, in which case the
# `choices` check sees the normalized value, so e.g. `JSON` is accepted -- confirm.
ARG_CONN_SERIALIZATION_FORMAT = Arg(
('--serialization-format',),
help='When exporting as `.env` format, defines how connections should be serialized. Default is `uri`.',
type=string_lower_type,
choices=['json', 'uri'],
)
ARG_CONN_IMPORT = Arg(("file",), help="Import connections from a file")

Expand Down Expand Up @@ -1418,7 +1440,7 @@ class GroupCommand(NamedTuple):
name='add',
help='Add a connection',
func=lazy_load_command('airflow.cli.commands.connection_command.connections_add'),
args=(ARG_CONN_ID, ARG_CONN_URI, ARG_CONN_EXTRA) + tuple(ALTERNATIVE_CONN_SPECS_ARGS),
args=(ARG_CONN_ID, ARG_CONN_URI, ARG_CONN_JSON, ARG_CONN_EXTRA) + tuple(ALTERNATIVE_CONN_SPECS_ARGS),
),
ActionCommand(
name='delete',
Expand All @@ -1432,19 +1454,24 @@ class GroupCommand(NamedTuple):
description=(
"All connections can be exported in STDOUT using the following command:\n"
"airflow connections export -\n"
"The file format can be determined by the provided file extension. eg, The following "
"The file format can be determined by the provided file extension. E.g., The following "
"command will export the connections in JSON format:\n"
"airflow connections export /tmp/connections.json\n"
"The --format parameter can be used to mention the connections format. eg, "
"The --file-format parameter can be used to control the file format. E.g., "
"the default format is JSON in STDOUT mode, which can be overridden using: \n"
"airflow connections export - --format yaml\n"
"The --format parameter can also be used for the files, for example:\n"
"airflow connections export /tmp/connections --format json\n"
"airflow connections export - --file-format yaml\n"
"The --file-format parameter can also be used for the files, for example:\n"
"airflow connections export /tmp/connections --file-format json.\n"
"When exporting in `env` file format, you control whether URI format or JSON format "
"is used to serialize the connection by passing `uri` or `json` with option "
"`--serialization-format`.\n"
),
func=lazy_load_command('airflow.cli.commands.connection_command.connections_export'),
args=(
ARG_CONN_EXPORT,
ARG_CONN_EXPORT_FORMAT,
ARG_CONN_EXPORT_FILE_FORMAT,
ARG_CONN_SERIALIZATION_FORMAT,
),
),
ActionCommand(
Expand Down
164 changes: 103 additions & 61 deletions airflow/cli/commands/connection_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import json
import os
import sys
import warnings
from pathlib import Path
from typing import Any, Dict, List
from urllib.parse import urlparse, urlunparse

Expand Down Expand Up @@ -82,30 +84,40 @@ def connections_list(args):
)


def _format_connections(conns: List[Connection], fmt: str) -> str:
if fmt == '.env':
def _connection_to_dict(conn: Connection) -> dict:
    """Return the exportable attributes of ``conn`` as a plain dict (``conn_id`` is not included)."""
    # The tuple order is the key order of the resulting dict; keep it stable, since
    # ``json.dumps`` emits keys in insertion order.
    exported_fields = (
        'conn_type',
        'description',
        'login',
        'password',
        'host',
        'port',
        'schema',
        'extra',
    )
    return {field: getattr(conn, field) for field in exported_fields}


def _format_connections(conns: List[Connection], file_format: str, serialization_format: str) -> str:
if serialization_format == 'json':
serializer_func = lambda x: json.dumps(_connection_to_dict(x))
elif serialization_format == 'uri':
serializer_func = Connection.get_uri
else:
raise SystemExit(f"Received unexpected value for `--serialization-format`: {serialization_format!r}")
if file_format == '.env':
connections_env = ""
for conn in conns:
connections_env += f"{conn.conn_id}={conn.get_uri()}\n"
connections_env += f"{conn.conn_id}={serializer_func(conn)}\n"
return connections_env

connections_dict = {}
for conn in conns:
connections_dict[conn.conn_id] = {
'conn_type': conn.conn_type,
'description': conn.description,
'host': conn.host,
'login': conn.login,
'password': conn.password,
'schema': conn.schema,
'port': conn.port,
'extra': conn.extra,
}

if fmt == '.yaml':
connections_dict[conn.conn_id] = _connection_to_dict(conn)

if file_format == '.yaml':
return yaml.dump(connections_dict)

if fmt == '.json':
if file_format == '.json':
return json.dumps(connections_dict, indent=2)

return json.dumps(connections_dict)
Expand All @@ -123,33 +135,48 @@ def _valid_uri(uri: str) -> bool:

def connections_export(args):
"""Exports all connections to a file"""
allowed_formats = ['.yaml', '.json', '.env']
provided_format = None if args.format is None else f".{args.format.lower()}"
default_format = provided_format or '.json'
file_formats = ['.yaml', '.json', '.env']
if args.format:
warnings.warn("Option `--format` is deprecated. Use `--file-format` instead.", DeprecationWarning)
if args.format and args.file_format:
raise SystemExit('Option `--format` is deprecated. Use `--file-format` instead.')
default_format = '.json'
provided_file_format = None
if args.format or args.file_format:
provided_file_format = f".{(args.format or args.file_format).lower()}"

file_is_stdout = _is_stdout(args.file)
if file_is_stdout:
filetype = provided_file_format or default_format
elif provided_file_format:
filetype = provided_file_format
else:
filetype = Path(args.file.name).suffix
filetype = filetype.lower()
if filetype not in file_formats:
raise SystemExit(
f"Unsupported file format. The file must have the extension {', '.join(file_formats)}."
)

with create_session() as session:
if _is_stdout(args.file):
filetype = default_format
elif provided_format is not None:
filetype = provided_format
else:
_, filetype = os.path.splitext(args.file.name)
filetype = filetype.lower()
if filetype not in allowed_formats:
raise SystemExit(
f"Unsupported file format. The file must have "
f"the extension {', '.join(allowed_formats)}."
)
if args.serialization_format and not filetype == '.env':
raise SystemExit("Option `--serialization-format` may only be used with file type `env`.")

with create_session() as session:
connections = session.query(Connection).order_by(Connection.conn_id).all()
msg = _format_connections(connections, filetype)
args.file.write(msg)
args.file.close()

if _is_stdout(args.file):
print("Connections successfully exported.", file=sys.stderr)
else:
print(f"Connections successfully exported to {args.file.name}.")
msg = _format_connections(
conns=connections,
file_format=filetype,
serialization_format=args.serialization_format or 'uri',
)

with args.file as f:
f.write(msg)

if file_is_stdout:
print("\nConnections successfully exported.", file=sys.stderr)
else:
print(f"Connections successfully exported to {args.file.name}.")


alternative_conn_specs = ['conn_type', 'conn_host', 'conn_login', 'conn_password', 'conn_schema', 'conn_port']
Expand All @@ -158,27 +185,42 @@ def connections_export(args):
@cli_utils.action_cli
def connections_add(args):
"""Adds new connection"""
# Check that the conn_id and conn_uri args were passed to the command:
missing_args = []
invalid_args = []
if args.conn_uri:
if not _valid_uri(args.conn_uri):
has_uri = bool(args.conn_uri)
has_json = bool(args.conn_json)
has_type = bool(args.conn_type)

if not has_type and not (has_json or has_uri):
raise SystemExit('Must supply either conn-uri or conn-json if not supplying conn-type')

if has_json and has_uri:
raise SystemExit('Cannot supply both conn-uri and conn-json')

if has_uri or has_json:
invalid_args = []
if has_uri and not _valid_uri(args.conn_uri):
raise SystemExit(f'The URI provided to --conn-uri is invalid: {args.conn_uri}')

for arg in alternative_conn_specs:
if getattr(args, arg) is not None:
invalid_args.append(arg)
elif not args.conn_type:
missing_args.append('conn-uri or conn-type')
if missing_args:
raise SystemExit(f'The following args are required to add a connection: {missing_args!r}')
if invalid_args:
raise SystemExit(
f'The following args are not compatible with the '
f'add flag and --conn-uri flag: {invalid_args!r}'
)

if has_json and args.conn_extra:
invalid_args.append("--conn-extra")

if invalid_args:
raise SystemExit(
"The following args are not compatible with "
f"the --conn-{'uri' if has_uri else 'json'} flag: {invalid_args!r}"
)

if args.conn_uri:
new_conn = Connection(conn_id=args.conn_id, description=args.conn_description, uri=args.conn_uri)
if args.conn_extra is not None:
new_conn.set_extra(args.conn_extra)
elif args.conn_json:
new_conn = Connection.from_json(conn_id=args.conn_id, value=args.conn_json)
if not new_conn.conn_type:
raise SystemExit('conn-json is invalid; must supply conn-type')
else:
new_conn = Connection(
conn_id=args.conn_id,
Expand All @@ -190,8 +232,8 @@ def connections_add(args):
schema=args.conn_schema,
port=args.conn_port,
)
if args.conn_extra is not None:
new_conn.set_extra(args.conn_extra)
if args.conn_extra is not None:
new_conn.set_extra(args.conn_extra)

with create_session() as session:
if not session.query(Connection).filter(Connection.conn_id == new_conn.conn_id).first():
Expand All @@ -202,14 +244,14 @@ def connections_add(args):
uri=args.conn_uri
or urlunparse(
(
args.conn_type,
new_conn.conn_type,
'{login}:{password}@{host}:{port}'.format(
login=args.conn_login or '',
password='******' if args.conn_password else '',
host=args.conn_host or '',
port=args.conn_port or '',
login=new_conn.login or '',
password='******' if new_conn.password else '',
host=new_conn.host or '',
port=new_conn.port or '',
),
args.conn_schema or '',
new_conn.schema or '',
'',
'',
'',
Expand Down
31 changes: 26 additions & 5 deletions airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,18 @@ def parse_from_uri(self, **uri):
)
self._parse_from_uri(**uri)

def _parse_from_uri(self, uri: str):
uri_parts = urlparse(uri)
conn_type = uri_parts.scheme
@staticmethod
def _normalize_conn_type(conn_type):
if conn_type == 'postgresql':
conn_type = 'postgres'
elif '-' in conn_type:
conn_type = conn_type.replace('-', '_')
self.conn_type = conn_type
return conn_type

def _parse_from_uri(self, uri: str):
uri_parts = urlparse(uri)
conn_type = uri_parts.scheme
self.conn_type = self._normalize_conn_type(conn_type)
self.host = _parse_netloc_to_hostname(uri_parts)
quoted_schema = uri_parts.path[1:]
self.schema = unquote(quoted_schema) if quoted_schema else quoted_schema
Expand Down Expand Up @@ -335,7 +339,7 @@ def get_hook(self, *, hook_params=None):
return hook_class(**{hook.connection_id_attribute_name: self.conn_id}, **hook_params)

def __repr__(self):
return self.conn_id
return self.conn_id or ''

def log_info(self):
"""
Expand Down Expand Up @@ -424,3 +428,20 @@ def get_connection_from_secrets(cls, conn_id: str) -> 'Connection':
)

raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined")

@classmethod
def from_json(cls, value, conn_id=None) -> 'Connection':
    """Build a connection from its JSON string representation.

    :param value: JSON document containing an object whose keys are passed to the
        constructor as keyword arguments. Three keys get special treatment:
        ``extra`` is kept as-is when it is a string and re-serialized with
        ``json.dumps`` otherwise; ``conn_type`` is passed through
        ``_normalize_conn_type``; ``port`` is converted to ``int``. A falsy value
        for any of these three keys is dropped.
    :param conn_id: identifier given to the new connection.
    :return: a new instance of ``cls``.
    :raises ValueError: if ``value`` is not valid JSON, does not decode to a JSON
        object, or ``port`` cannot be converted to an integer.
    """
    kwargs = json.loads(value)
    if not isinstance(kwargs, dict):
        # Fail with a clear message instead of an AttributeError on ``.pop`` below.
        raise ValueError(f"Expected a JSON object, but got {type(kwargs).__name__} instead.")
    extra = kwargs.pop('extra', None)
    if extra:
        kwargs['extra'] = extra if isinstance(extra, str) else json.dumps(extra)
    conn_type = kwargs.pop('conn_type', None)
    if conn_type:
        kwargs['conn_type'] = cls._normalize_conn_type(conn_type)
    port = kwargs.pop('port', None)
    if port:
        try:
            kwargs['port'] = int(port)
        except ValueError as err:
            # Keep the original error attached as the cause for easier debugging.
            raise ValueError(f"Expected integer value for `port`, but got {port!r} instead.") from err
    # Instantiate through ``cls`` so that subclasses get instances of their own type.
    return cls(conn_id=conn_id, **kwargs)
Loading

0 comments on commit 85beb36

Please sign in to comment.