Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[KED-2865] Make sql datasets use a singleton pattern for connection #1163

Merged
merged 19 commits into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Major features and improvements
* `pipeline` now accepts `tags` and a collection of `Node`s and/or `Pipeline`s rather than just a single `Pipeline` object. `pipeline` should be used in preference to `Pipeline` when creating a Kedro pipeline.
* `pandas.SQLTableDataSet` and `pandas.SQLQueryDataSet` now only open one connection per database, at instantiation time (therefore at catalog creation time), rather than one per load/save operation.

## Bug fixes and other changes
* Added tutorial documentation for experiment tracking (`03_tutorial/07_set_up_experiment_tracking.md`).
Expand Down
98 changes: 63 additions & 35 deletions kedro/extras/datasets/pandas/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,11 @@ class SQLTableDataSet(AbstractDataSet):

"""

DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any]
DEFAULT_SAVE_ARGS = {"index": False} # type: Dict[str, Any]
DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
DEFAULT_SAVE_ARGS: Dict[str, Any] = {"index": False}
# using Any because of Sphinx but it should be
# sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
engines: Dict[str, Any] = {}

def __init__(
self,
Expand Down Expand Up @@ -207,42 +210,50 @@ def __init__(
self._load_args["table_name"] = table_name
self._save_args["name"] = table_name

self._load_args["con"] = self._save_args["con"] = credentials["con"]
self._connection_str = credentials["con"]
self.create_connection(self._connection_str)

@classmethod
def create_connection(cls, connection_str: str) -> None:
"""Given a connection string, create singleton connection
to be used across all instances of `SQLTableDataSet` that
need to connect to the same source.
"""
if connection_str in cls.engines:
return

try:
engine = create_engine(connection_str)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc

cls.engines[connection_str] = engine

def _describe(self) -> Dict[str, Any]:
load_args = self._load_args.copy()
save_args = self._save_args.copy()
load_args = copy.deepcopy(self._load_args)
save_args = copy.deepcopy(self._save_args)
del load_args["table_name"]
del load_args["con"]
del save_args["name"]
del save_args["con"]
return dict(
table_name=self._load_args["table_name"],
load_args=load_args,
save_args=save_args,
)

def _load(self) -> pd.DataFrame:
try:
return pd.read_sql_table(**self._load_args)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
engine = self.engines[self._connection_str] # type:ignore
return pd.read_sql_table(con=engine, **self._load_args)

def _save(self, data: pd.DataFrame) -> None:
try:
data.to_sql(**self._save_args)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
engine = self.engines[self._connection_str] # type: ignore
data.to_sql(con=engine, **self._save_args)

def _exists(self) -> bool:
eng = create_engine(self._load_args["con"])
eng = self.engines[self._connection_str] # type: ignore
schema = self._load_args.get("schema", None)
exists = self._load_args["table_name"] in eng.table_names(schema)
eng.dispose()
return exists


Expand Down Expand Up @@ -299,6 +310,10 @@ class SQLQueryDataSet(AbstractDataSet):

"""

# using Any because of Sphinx but it should be
# sqlalchemy.engine.Engine or sqlalchemy.engine.base.Engine
engines: Dict[str, Any] = {}

def __init__( # pylint: disable=too-many-arguments
self,
sql: str = None,
Expand Down Expand Up @@ -374,32 +389,45 @@ def __init__( # pylint: disable=too-many-arguments
self._protocol = protocol
self._fs = fsspec.filesystem(self._protocol, **_fs_credentials, **_fs_args)
self._filepath = path
self._load_args["con"] = credentials["con"]
self._connection_str = credentials["con"]
self.create_connection(self._connection_str)

@classmethod
def create_connection(cls, connection_str: str) -> None:
"""Given a connection string, create singleton connection
to be used across all instances of `SQLQueryDataSet` that
need to connect to the same source.
"""
if connection_str in cls.engines:
return

try:
engine = create_engine(connection_str)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc

cls.engines[connection_str] = engine

def _describe(self) -> Dict[str, Any]:
load_args = copy.deepcopy(self._load_args)
desc = {}
desc["sql"] = str(load_args.pop("sql", None))
desc["filepath"] = str(self._filepath)
del load_args["con"]
desc["load_args"] = str(load_args)

return desc
return dict(
sql=str(load_args.pop("sql", None)),
filepath=str(self._filepath),
load_args=str(load_args),
)
lorenabalan marked this conversation as resolved.
Show resolved Hide resolved

def _load(self) -> pd.DataFrame:
load_args = copy.deepcopy(self._load_args)
engine = self.engines[self._connection_str] # type: ignore

if self._filepath:
load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol)
with self._fs.open(load_path, mode="r") as fs_file:
load_args["sql"] = fs_file.read()

try:
return pd.read_sql_query(**load_args)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
return pd.read_sql_query(con=engine, **load_args)
lorenabalan marked this conversation as resolved.
Show resolved Hide resolved

def _save(self, data: pd.DataFrame) -> None:
raise DataSetError("`save` is not supported on SQLQueryDataSet")
Loading