Skip to content

Commit

Permalink
Merge pull request #107 from phillyfan1138/master
Browse files Browse the repository at this point in the history
allow kerberos
  • Loading branch information
jtcohen6 authored Sep 25, 2020
2 parents 7b3ac5b + 9d03250 commit 5f19129
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
6 changes: 5 additions & 1 deletion dbt/adapters/spark/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class SparkCredentials(Credentials):
token: Optional[str] = None
user: Optional[str] = None
port: int = 443
auth: Optional[str]=None
kerberos_service_name: Optional[str]=None
organization: str = '0'
connect_retries: int = 0
connect_timeout: int = 10
Expand Down Expand Up @@ -269,7 +271,9 @@ def open(cls, connection):

conn = hive.connect(host=creds.host,
port=creds.port,
username=creds.user)
username=creds.user,
auth=creds.auth,
kerberos_service_name=creds.kerberos_service_name)
else:
raise dbt.exceptions.DbtProfileError(
f"invalid credential method: {creds.method}"
Expand Down
40 changes: 39 additions & 1 deletion test/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,22 @@ def _get_target_thrift(self, project):
'target': 'test'
})

def _get_target_thrift_kerberos(self, project):
return config_from_parts_or_dicts(project, {
'outputs': {
'test': {
'type': 'spark',
'method': 'thrift',
'schema': 'analytics',
'host': 'myorg.sparkhost.com',
'port': 10001,
'user': 'dbt',
'auth': 'KERBEROS',
'kerberos_service_name': 'hive'
}
},
'target': 'test'
})
def test_http_connection(self):
config = self._get_target_http(self.project_cfg)
adapter = SparkAdapter(config)
Expand All @@ -83,10 +99,32 @@ def test_thrift_connection(self):
config = self._get_target_thrift(self.project_cfg)
adapter = SparkAdapter(config)

def hive_thrift_connect(host, port, username):
def hive_thrift_connect(host, port, username, auth, kerberos_service_name):
self.assertEqual(host, 'myorg.sparkhost.com')
self.assertEqual(port, 10001)
self.assertEqual(username, 'dbt')
self.assertIsNone(auth)
self.assertIsNone(kerberos_service_name)

with mock.patch.object(hive, 'connect', new=hive_thrift_connect):
connection = adapter.acquire_connection('dummy')
connection.handle # trigger lazy-load

self.assertEqual(connection.state, 'open')
self.assertIsNotNone(connection.handle)
self.assertEqual(connection.credentials.schema, 'analytics')
self.assertIsNone(connection.credentials.database)

def test_thrift_connection_kerberos(self):
config = self._get_target_thrift_kerberos(self.project_cfg)
adapter = SparkAdapter(config)

def hive_thrift_connect(host, port, username, auth, kerberos_service_name):
self.assertEqual(host, 'myorg.sparkhost.com')
self.assertEqual(port, 10001)
self.assertEqual(username, 'dbt')
self.assertEqual(auth, 'KERBEROS')
self.assertEqual(kerberos_service_name, 'hive')

with mock.patch.object(hive, 'connect', new=hive_thrift_connect):
connection = adapter.acquire_connection('dummy')
Expand Down

0 comments on commit 5f19129

Please sign in to comment.