diff --git a/airflow/providers/tableau/hooks/tableau.py b/airflow/providers/tableau/hooks/tableau.py index 09cb607ec5456..e0f69595ebc1d 100644 --- a/airflow/providers/tableau/hooks/tableau.py +++ b/airflow/providers/tableau/hooks/tableau.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from distutils.util import strtobool from enum import Enum from typing import Any, Optional @@ -26,9 +27,7 @@ class TableauJobFinishCode(Enum): """ The finish code indicates the status of the job. - .. seealso:: https://help.tableau.com/current/api/rest_api/en-us/REST/rest_api_ref.htm#query_job - """ PENDING = -1 @@ -40,8 +39,7 @@ class TableauJobFinishCode(Enum): class TableauHook(BaseHook): """ Connects to the Tableau Server Instance and allows to communicate with it. - - .. seealso:: https://tableau.github.io/server-client-python/docs/ + .. see also:: https://tableau.github.io/server-client-python/docs/ :param site_id: The id of the site where the workbook belongs to. It will connect to the default site if you don't provide an id. @@ -61,7 +59,16 @@ def __init__(self, site_id: Optional[str] = None, tableau_conn_id: str = default self.tableau_conn_id = tableau_conn_id self.conn = self.get_connection(self.tableau_conn_id) self.site_id = site_id or self.conn.extra_dejson.get('site_id', '') - self.server = Server(self.conn.host, use_server_version=True) + self.server = Server(self.conn.host) + verify = self.conn.extra_dejson.get('verify', 'True') + try: + verify = bool(strtobool(verify)) + except ValueError: + pass + self.server.add_http_options( + options_dict={'verify': verify, 'cert': self.conn.extra_dejson.get('cert', None)} + ) + self.server.use_server_version() self.tableau_conn = None def __enter__(self): @@ -75,7 +82,6 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def get_conn(self) -> Auth.contextmgr: """ Signs in to the Tableau Server and automatically signs out if used as ContextManager. - :return: an authorized Tableau Server Context Manager object. :rtype: tableauserverclient.server.Auth.contextmgr """ @@ -102,8 +108,7 @@ def _auth_via_token(self) -> Auth.contextmgr: def get_all(self, resource_name: str) -> Pager: """ Get all items of the given resource. - - .. seealso:: https://tableau.github.io/server-client-python/docs/page-through-results + .. see also:: https://tableau.github.io/server-client-python/docs/page-through-results :param resource_name: The name of the resource to paginate. For example: jobs or workbooks diff --git a/docs/apache-airflow-providers-tableau/connections/tableau.rst b/docs/apache-airflow-providers-tableau/connections/tableau.rst index 226854538e833..7cd9adb4ce646 100644 --- a/docs/apache-airflow-providers-tableau/connections/tableau.rst +++ b/docs/apache-airflow-providers-tableau/connections/tableau.rst @@ -71,6 +71,9 @@ Extra (optional) This is used with token authentication. * ``personal_access_token``: The personal access token value. This is used with token authentication. + * ``verify``: Either a boolean, in which case it controls whether we verify the server’s TLS certificate, or a string, in which case it must be a path to a CA bundle to use. Defaults to True. + * ``cert``: if String, path to ssl client cert file (.pem). If Tuple, (‘cert’, ‘key’) pair. + When specifying the connection in environment variable you should specify it using URI syntax. diff --git a/tests/providers/tableau/hooks/test_tableau.py b/tests/providers/tableau/hooks/test_tableau.py index 4c8a6a643c024..f49d7707c5a6e 100644 --- a/tests/providers/tableau/hooks/test_tableau.py +++ b/tests/providers/tableau/hooks/test_tableau.py @@ -52,6 +52,34 @@ def setUp(self): extra='{"token_name": "my_token", "personal_access_token": "my_personal_access_token"}', ) ) + db.merge_conn( + models.Connection( + conn_id='tableau_test_ssl_connection_certificates_path', + conn_type='tableau', + host='tableau', + login='user', + password='password', + extra='{"verify": "my_cert_path", "cert": "my_client_cert_path"}', + ) + ) + db.merge_conn( + models.Connection( + conn_id='tableau_test_ssl_false_connection', + conn_type='tableau', + host='tableau', + login='user', + password='password', + extra='{"verify": "False"}', + ) + ) + db.merge_conn( + models.Connection( + conn_id='tableau_test_ssl_connection_default', + conn_type='tableau', + host='tableau', + extra='{"token_name": "my_token", "personal_access_token": "my_personal_access_token"}', + ) + ) @patch('airflow.providers.tableau.hooks.tableau.TableauAuth') @patch('airflow.providers.tableau.hooks.tableau.Server') @@ -60,7 +88,7 @@ def test_get_conn_auth_via_password_and_site_in_connection(self, mock_server, mo Test get conn auth via password """ with TableauHook(tableau_conn_id='tableau_test_password') as tableau_hook: - mock_server.assert_called_once_with(tableau_hook.conn.host, use_server_version=True) + mock_server.assert_called_once_with(tableau_hook.conn.host) mock_tableau_auth.assert_called_once_with( username=tableau_hook.conn.login, password=tableau_hook.conn.password, @@ -76,7 +104,7 @@ def test_get_conn_auth_via_token_and_site_in_init(self, mock_server, mock_tablea Test get conn auth via token """ with TableauHook(site_id='test', tableau_conn_id='tableau_test_token') as tableau_hook: - mock_server.assert_called_once_with(tableau_hook.conn.host, use_server_version=True) + mock_server.assert_called_once_with(tableau_hook.conn.host) mock_tableau_auth.assert_called_once_with( token_name=tableau_hook.conn.extra_dejson['token_name'], personal_access_token=tableau_hook.conn.extra_dejson['personal_access_token'], @@ -87,6 +115,68 @@ def test_get_conn_auth_via_token_and_site_in_init(self, mock_server, mock_tablea ) mock_server.return_value.auth.sign_out.assert_called_once_with() + @patch('airflow.providers.tableau.hooks.tableau.TableauAuth') + @patch('airflow.providers.tableau.hooks.tableau.Server') + def test_get_conn_ssl_cert_path(self, mock_server, mock_tableau_auth): + """ + Test get conn with SSL parameters, verify as path + """ + with TableauHook(tableau_conn_id='tableau_test_ssl_connection_certificates_path') as tableau_hook: + mock_server.assert_called_once_with(tableau_hook.conn.host) + mock_server.return_value.add_http_options.assert_called_once_with( + options_dict={ + 'verify': tableau_hook.conn.extra_dejson['verify'], + 'cert': tableau_hook.conn.extra_dejson['cert'], + } + ) + mock_tableau_auth.assert_called_once_with( + username=tableau_hook.conn.login, + password=tableau_hook.conn.password, + site_id='', + ) + mock_server.return_value.auth.sign_in.assert_called_once_with(mock_tableau_auth.return_value) + mock_server.return_value.auth.sign_out.assert_called_once_with() + + @patch('airflow.providers.tableau.hooks.tableau.PersonalAccessTokenAuth') + @patch('airflow.providers.tableau.hooks.tableau.Server') + def test_get_conn_ssl_default(self, mock_server, mock_tableau_auth): + """ + Test get conn with default SSL parameters + """ + with TableauHook(tableau_conn_id='tableau_test_ssl_connection_default') as tableau_hook: + mock_server.assert_called_once_with(tableau_hook.conn.host) + mock_server.return_value.add_http_options.assert_called_once_with( + options_dict={'verify': True, 'cert': None} + ) + mock_tableau_auth.assert_called_once_with( + token_name=tableau_hook.conn.extra_dejson['token_name'], + personal_access_token=tableau_hook.conn.extra_dejson['personal_access_token'], + site_id='', + ) + mock_server.return_value.auth.sign_in_with_personal_access_token.assert_called_once_with( + mock_tableau_auth.return_value + ) + mock_server.return_value.auth.sign_out.assert_called_once_with() + + @patch('airflow.providers.tableau.hooks.tableau.TableauAuth') + @patch('airflow.providers.tableau.hooks.tableau.Server') + def test_get_conn_ssl_disabled(self, mock_server, mock_tableau_auth): + """ + Test get conn with default SSL disabled parameters + """ + with TableauHook(tableau_conn_id='tableau_test_ssl_false_connection') as tableau_hook: + mock_server.assert_called_once_with(tableau_hook.conn.host) + mock_server.return_value.add_http_options.assert_called_once_with( + options_dict={'verify': False, 'cert': None} + ) + mock_tableau_auth.assert_called_once_with( + username=tableau_hook.conn.login, + password=tableau_hook.conn.password, + site_id='', + ) + mock_server.return_value.auth.sign_in.assert_called_once_with(mock_tableau_auth.return_value) + mock_server.return_value.auth.sign_out.assert_called_once_with() + @patch('airflow.providers.tableau.hooks.tableau.TableauAuth') @patch('airflow.providers.tableau.hooks.tableau.Server') @patch('airflow.providers.tableau.hooks.tableau.Pager', return_value=[1, 2, 3])