Skip to content

Commit

Permalink
database: drop dbt implementation (#10222)
Browse files Browse the repository at this point in the history
* remove dbt implementation in the database

* lazy load client

* fixup

* fixup

* fix server side cursors

* use exec_driver_sql for testing connection
  • Loading branch information
skshetry authored Jan 5, 2024
1 parent 0100c14 commit 426a714
Show file tree
Hide file tree
Showing 17 changed files with 284 additions and 1,159 deletions.
137 changes: 16 additions & 121 deletions dvc/commands/imp_db.py
Original file line number Diff line number Diff line change
@@ -1,104 +1,43 @@
import argparse

from funcy import compact, merge

from dvc.cli import completion
from dvc.cli.command import CmdBase
from dvc.cli.command import CmdBase, CmdBaseNoRepo
from dvc.cli.utils import append_doc_link
from dvc.log import logger
from dvc.ui import ui

logger = logger.getChild(__name__)


class CmdTestDb(CmdBase):
class CmdTestDb(CmdBaseNoRepo):
def run(self):
from dvc.database import get_client
from dvc.database.dbt_utils import DBT_PROJECT_FILE, is_dbt_project
from dvc.dependency.db import _get_dbt_config
from dvc.config import Config
from dvc.database import client
from dvc.exceptions import DvcException

connection = self.args.conn

db_config = self.repo.config.get("db", {})
config = db_config.get(connection, {})
if connection and not config:
db_config = Config.from_cwd().get("db", {})
if connection not in db_config:
raise DvcException(f"connection {connection} not found in config")

cli_config = compact(
{
"url": self.args.url,
"username": self.args.username,
"password": self.args.password,
}
)
conn_config = merge(config, cli_config)

cli_dbt_config = compact(
{"profile": self.args.profile, "target": self.args.target}
)
dbt_config = merge(_get_dbt_config(self.repo.config), cli_dbt_config)

project_dir = self.repo.root_dir
if not (conn_config or dbt_config):
if not self.args.dbt_conn:
raise DvcException(
"no config set; provide arguments or set a configuration"
)

if is_dbt_project(project_dir):
ui.write("Using", DBT_PROJECT_FILE, "for testing", styled=True)
else:
raise DvcException(
f"no config set and {DBT_PROJECT_FILE} is missing; "
"provide arguments or set a configuration"
)

adapter = get_client(conn_config, project_dir=project_dir, **dbt_config)
with adapter as db:
config = db_config.get(connection, {})
if self.args.url:
config["url"] = self.args.url
if self.args.username:
config["username"] = self.args.username
if self.args.password:
config["password"] = self.args.password
with client(config) as db:
ui.write(f"Testing with {db}", styled=True)

creds = getattr(db, "creds", {})
for k, v in creds.items():
ui.write("\t", f"{k}:", v, styled=True)

if creds:
ui.write()

db.test_connection()

ui.write("Connection successful", styled=True)


class CmdImportDb(CmdBase):
def run(self):
from dvc.exceptions import InvalidArgumentError

if self.args.table or self.args.sql:
arg = "--table" if self.args.table else "--sql"
options = {
"url": self.args.url,
"rev": self.args.rev,
"project_dir": self.args.project_dir,
}
opt = next((o for o, v in options.items() if v), None)
if opt:
raise InvalidArgumentError(f"argument {opt}: not allowed with {arg}")

if not self.args.conn and not self.args.dbt_conn:
raise InvalidArgumentError(f"{arg} requires --conn")
if self.args.model and self.args.conn:
raise InvalidArgumentError("argument --model: not allowed with --conn")

self.repo.imp_db(
url=self.args.url,
rev=self.args.rev,
project_dir=self.args.project_dir,
sql=self.args.sql,
table=self.args.table,
model=self.args.model,
profile=self.args.profile,
target=self.args.target,
output_format=self.args.output_format,
out=self.args.out,
force=self.args.force,
Expand All @@ -116,43 +55,9 @@ def add_parser(subparsers, parent_parser):
help=IMPORT_HELP,
formatter_class=argparse.RawTextHelpFormatter,
)
import_parser.add_argument(
"--url",
help=argparse.SUPPRESS,
# help="Location of dbt repository",
)
import_parser.add_argument(
"--rev",
nargs="?",
help=argparse.SUPPRESS,
# help="Git revision (e.g. SHA, branch, tag)",
metavar="<commit>",
)
import_parser.add_argument(
"--project-dir",
nargs="?",
help=argparse.SUPPRESS,
# help="Subdirectory to the dbt project location",
)

group = import_parser.add_mutually_exclusive_group(required=True)
group.add_argument("--sql", help="SQL query to snapshot")
group.add_argument("--table", help="Table to snapshot")
group.add_argument(
"--model",
help=argparse.SUPPRESS,
# help="Model name to download",
)
import_parser.add_argument(
"--profile",
help=argparse.SUPPRESS,
# help="Profile to use",
)
import_parser.add_argument(
"--target",
help=argparse.SUPPRESS,
# help="Target to use",
)
import_parser.add_argument(
"--output-format",
default="csv",
Expand All @@ -177,16 +82,9 @@ def add_parser(subparsers, parent_parser):
)
import_parser.add_argument(
"--conn",
nargs="?",
required=True,
help="Database connection to use, needs to be set in config",
)
import_parser.add_argument(
"--dbt-conn",
action="store_true",
default=False,
help=argparse.SUPPRESS,
# help="Use dbt connection",
)

import_parser.set_defaults(func=CmdImportDb)

Expand All @@ -197,11 +95,8 @@ def add_parser(subparsers, parent_parser):
description=append_doc_link(TEST_DB_HELP, "test-db"),
add_help=False,
)
test_db_parser.add_argument("--conn")
test_db_parser.add_argument("--dbt-conn", action="store_true", default=False)
test_db_parser.add_argument("--conn", required=True)
test_db_parser.add_argument("--url")
test_db_parser.add_argument("--password")
test_db_parser.add_argument("--username")
test_db_parser.add_argument("--profile")
test_db_parser.add_argument("--target")
test_db_parser.set_defaults(func=CmdTestDb)
2 changes: 0 additions & 2 deletions dvc/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,6 @@ def __call__(self, data):
"feature": FeatureSchema(
{
Optional("machine", default=False): Bool,
"dbt_profile": str,
"dbt_target": str,
},
),
"plots": {
Expand Down
124 changes: 124 additions & 0 deletions dvc/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional, Union

from sqlalchemy import create_engine
from sqlalchemy.engine import make_url as _make_url
from sqlalchemy.exc import NoSuchModuleError

from dvc import env
from dvc.exceptions import DvcException
from dvc.log import logger
from dvc.types import StrOrBytesPath
from dvc.utils import env2bool

if TYPE_CHECKING:
from sqlalchemy.engine import URL, Connectable, Engine
from sqlalchemy.sql.expression import Selectable


logger = logger.getChild(__name__)


def noop(_):
pass


def make_url(url: Union["URL", str], **kwargs: Any) -> "URL":
return _make_url(url).set(**kwargs)


def url_from_config(config: Union[str, "URL", Dict[str, str]]) -> "URL":
if isinstance(config, dict):
return make_url(**config)
return make_url(config)


@dataclass
class Serializer:
sql: "Union[str, Selectable]"
con: "Union[str, Connectable]"
chunksize: int = 10_000

def to_csv(self, file: StrOrBytesPath, progress=noop):
import pandas as pd

with open(file, mode="wb") as f:
idfs = pd.read_sql(self.sql, self.con, chunksize=self.chunksize)
for i, df in enumerate(idfs):
df.to_csv(f, header=i == 0, index=False)
progress(len(df))

def to_json(self, file: StrOrBytesPath, progress=noop): # noqa: ARG002
import pandas as pd

path = os.fsdecode(file)
df = pd.read_sql(self.sql, self.con)
df.to_json(path, orient="records")

def export(self, file: StrOrBytesPath, format: str = "csv", progress=noop): # noqa: A002
if format == "json":
return self.to_json(file, progress=progress)
return self.to_csv(file, progress=progress)


@dataclass
class Client:
engine: "Engine"

def test_connection(self, onerror: Optional[Callable[[], Any]] = None) -> None:
try:
with self.engine.connect() as conn:
conn.exec_driver_sql("select 1")
except Exception as exc:
if callable(onerror):
onerror()
logger.exception(
"Could not connect to the database. "
"Check your database credentials and try again.",
exc_info=False,
)
raise DvcException("The database returned the following error") from exc

def export(
self,
sql: "Union[str, Selectable]",
file: StrOrBytesPath,
format: str = "csv", # noqa: A002
progress=noop,
) -> None:
con = self.engine.connect()
if format == "csv":
con = con.execution_options(stream_results=True) # use server-side cursors

with con:
serializer = Serializer(sql, con)
return serializer.export(file, format=format, progress=progress)


@contextmanager
def handle_error(url: "URL"):
try:
yield
except (ModuleNotFoundError, NoSuchModuleError) as e:
# TODO: write installation instructions
driver = url.drivername
raise DvcException(f"Could not load database driver for {driver!r}") from e


@contextmanager
def client(
url_or_config: Union[str, "URL", Dict[str, str]], **engine_kwargs: Any
) -> Iterator[Client]:
url = url_from_config(url_or_config)
echo = env2bool(env.DVC_SQLALCHEMY_ECHO, False)
engine_kwargs.setdefault("echo", echo)

with handle_error(url):
engine = create_engine(url, **engine_kwargs)

try:
yield Client(engine)
finally:
engine.dispose()
22 changes: 0 additions & 22 deletions dvc/database/__init__.py

This file was deleted.

40 changes: 0 additions & 40 deletions dvc/database/dbt_models.py

This file was deleted.

Loading

0 comments on commit 426a714

Please sign in to comment.