diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index fa0a999a3..a691f2ec8 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -16,6 +16,7 @@ Optional, Sequence, Tuple, + cast, ) from agate import Table @@ -57,14 +58,15 @@ @dataclass class DatabricksCredentials(Credentials): - host: str database: Optional[str] # type: ignore[assignment] + host: Optional[str] = None http_path: Optional[str] = None token: Optional[str] = None - connect_retries: int = 0 - connect_timeout: int = 10 session_properties: Optional[Dict[str, Any]] = None connection_parameters: Optional[Dict[str, Any]] = None + + connect_retries: int = 0 + connect_timeout: int = 10 retry_all: bool = False _ALIASES = { @@ -131,7 +133,7 @@ def type(self) -> str: @property def unique_field(self) -> str: - return self.host + return cast(str, self.host) def connection_info(self, *, with_aliases: bool = False) -> Iterable[Tuple[str, Any]]: as_dict = self.to_dict(omit_none=False) @@ -403,7 +405,7 @@ def list_schemas(self, database: str, schema: Optional[str] = None) -> Table: @classmethod def validate_creds(cls, creds: DatabricksCredentials, required: List[str]) -> None: for key in required: - if not hasattr(creds, key): + if not getattr(creds, key): raise dbt.exceptions.DbtProfileError( "The config '{}' is required to connect to Databricks".format(key) ) @@ -450,7 +452,7 @@ def open(cls, connection: Connection) -> Connection: return connection creds: DatabricksCredentials = connection.credentials - exc: Optional[Exception] = None + cls.validate_creds(creds, ["host", "http_path", "token"]) user_agent_entry = f"dbt-databricks/{__version__}" @@ -459,20 +461,14 @@ def open(cls, connection: Connection) -> Connection: cls.validate_invocation_env(invocation_env) user_agent_entry = f"{user_agent_entry}; {invocation_env}" - if creds.http_path is None: - raise dbt.exceptions.DbtProfileError( - "`http_path` must set when" " using the dbsql method to connect to Databricks" - ) - required_fields = ["host", "http_path", "token"] - - cls.validate_creds(creds, required_fields) - connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] http_headers: List[Tuple[str, str]] = cls.get_all_http_headers( connection_parameters.pop("http_headers", {}) ) + exc: Optional[Exception] = None + for i in range(1 + creds.connect_retries): try: # TODO: what is the error when a user specifies a catalog they don't have access to diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index ae522dd72..259ecc04d 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -288,7 +288,9 @@ def submit_python_job( command_name += "-" + str(uuid.uuid1()) api_client = Api12Client( - host=credentials.host, token=cast(str, credentials.token), command_name=command_name + host=cast(str, credentials.host), + token=cast(str, credentials.token), + command_name=command_name, ) try: