Skip to content

Commit

Permalink
Improve the credential check. (#183)
Browse files Browse the repository at this point in the history
### Description

Improves the credential check.

It will show a consistent error message when missing the required fields, `host`, `http_path`, and `token`.

For example:

```
The config 'host' is required to connect to Databricks
```
  • Loading branch information
ueshin authored Sep 23, 2022
1 parent 54b962e commit 202611f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
24 changes: 10 additions & 14 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Optional,
Sequence,
Tuple,
cast,
)

from agate import Table
Expand Down Expand Up @@ -57,14 +58,15 @@

@dataclass
class DatabricksCredentials(Credentials):
host: str
database: Optional[str] # type: ignore[assignment]
host: Optional[str] = None
http_path: Optional[str] = None
token: Optional[str] = None
connect_retries: int = 0
connect_timeout: int = 10
session_properties: Optional[Dict[str, Any]] = None
connection_parameters: Optional[Dict[str, Any]] = None

connect_retries: int = 0
connect_timeout: int = 10
retry_all: bool = False

_ALIASES = {
Expand Down Expand Up @@ -131,7 +133,7 @@ def type(self) -> str:

@property
def unique_field(self) -> str:
return self.host
return cast(str, self.host)

def connection_info(self, *, with_aliases: bool = False) -> Iterable[Tuple[str, Any]]:
as_dict = self.to_dict(omit_none=False)
Expand Down Expand Up @@ -403,7 +405,7 @@ def list_schemas(self, database: str, schema: Optional[str] = None) -> Table:
@classmethod
def validate_creds(cls, creds: DatabricksCredentials, required: List[str]) -> None:
for key in required:
if not hasattr(creds, key):
if not getattr(creds, key):
raise dbt.exceptions.DbtProfileError(
"The config '{}' is required to connect to Databricks".format(key)
)
Expand Down Expand Up @@ -450,7 +452,7 @@ def open(cls, connection: Connection) -> Connection:
return connection

creds: DatabricksCredentials = connection.credentials
exc: Optional[Exception] = None
cls.validate_creds(creds, ["host", "http_path", "token"])

user_agent_entry = f"dbt-databricks/{__version__}"

Expand All @@ -459,20 +461,14 @@ def open(cls, connection: Connection) -> Connection:
cls.validate_invocation_env(invocation_env)
user_agent_entry = f"{user_agent_entry}; {invocation_env}"

if creds.http_path is None:
raise dbt.exceptions.DbtProfileError(
"`http_path` must set when" " using the dbsql method to connect to Databricks"
)
required_fields = ["host", "http_path", "token"]

cls.validate_creds(creds, required_fields)

connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr]

http_headers: List[Tuple[str, str]] = cls.get_all_http_headers(
connection_parameters.pop("http_headers", {})
)

exc: Optional[Exception] = None

for i in range(1 + creds.connect_retries):
try:
# TODO: what is the error when a user specifies a catalog they don't have access to
Expand Down
4 changes: 3 additions & 1 deletion dbt/adapters/databricks/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,9 @@ def submit_python_job(
command_name += "-" + str(uuid.uuid1())

api_client = Api12Client(
host=credentials.host, token=cast(str, credentials.token), command_name=command_name
host=cast(str, credentials.host),
token=cast(str, credentials.token),
command_name=command_name,
)

try:
Expand Down

0 comments on commit 202611f

Please sign in to comment.