diff --git a/databricks_cli/cluster_policies/api.py b/databricks_cli/cluster_policies/api.py index ad87c77e..34638a11 100644 --- a/databricks_cli/cluster_policies/api.py +++ b/databricks_cli/cluster_policies/api.py @@ -20,6 +20,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from json import dumps as json_dumps + from databricks_cli.sdk import PolicyService @@ -27,11 +29,25 @@ class ClusterPolicyApi(object): def __init__(self, api_client): self.client = PolicyService(api_client) + @staticmethod + def format_policy_for_api(policy): + if isinstance(policy.get("definition"), dict): + policy["definition"] = json_dumps(policy["definition"]) + return policy + def create_cluster_policy(self, json): - return self.client.client.perform_query('POST', '/policies/clusters/create', data=json) + return self.client.client.perform_query( + "POST", + "/policies/clusters/create", + data=ClusterPolicyApi.format_policy_for_api(json), + ) def edit_cluster_policy(self, json): - return self.client.client.perform_query('POST', '/policies/clusters/edit', data=json) + return self.client.client.perform_query( + "POST", + "/policies/clusters/edit", + data=ClusterPolicyApi.format_policy_for_api(json), + ) def delete_cluster_policy(self, policy_id): return self.client.delete_policy(policy_id) diff --git a/tests/cluster_policies/test_api.py b/tests/cluster_policies/test_api.py new file mode 100644 index 00000000..ed5082fa --- /dev/null +++ b/tests/cluster_policies/test_api.py @@ -0,0 +1,62 @@ +# Databricks CLI +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"), except +# that the use of services to which certain application programming +# interfaces (each, an "API") connect requires that the user first obtain +# a license for the use of the APIs from Databricks, Inc. ("Databricks"), +# by creating an account at www.databricks.com and agreeing to either (a) +# the Community Edition Terms of Service, (b) the Databricks Terms of +# Service, or (c) another written agreement between Licensee and Databricks +# for the use of the APIs. +# +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +import pytest + +from databricks_cli.cluster_policies.api import ClusterPolicyApi + + +@pytest.mark.parametrize( + "policy, expected", + [ + ({"definition": "foo"}, {"definition": "foo"}), + ({"definition": {"foo": "bar"}}, {"definition": '{"foo": "bar"}'}), + ], +) +def test_format_policy_for_api(policy, expected): + result = ClusterPolicyApi.format_policy_for_api(policy) + assert result == expected + + +@pytest.mark.parametrize( + "fct_name, method, action", + [ + ("create_cluster_policy", "POST", "create"), + ("edit_cluster_policy", "POST", "edit"), + ], +) +@mock.patch( + "databricks_cli.cluster_policies.api.ClusterPolicyApi.format_policy_for_api" +) +def test_create_and_edit_cluster_policy( + mock_format_policy_for_api, fct_name, method, action, fixture_cluster_policies_api +): + mock_policy = mock.Mock() + getattr(fixture_cluster_policies_api, fct_name)(mock_policy) + mock_format_policy_for_api.assert_called_once_with(mock_policy) + fixture_cluster_policies_api.client.client.perform_query.assert_called_once_with( + method, + "/policies/clusters/{}".format(action), + data=mock_format_policy_for_api.return_value, + ) diff --git a/tests/conftest.py b/tests/conftest.py index 6d7fe245..523bc2e3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,9 +22,12 @@ # limitations under the License. import shutil import tempfile + +import mock import pytest import databricks_cli.configure.provider as provider +from databricks_cli.cluster_policies.api import ClusterPolicyApi @pytest.fixture(autouse=True) @@ -33,3 +36,12 @@ def mock_conf_dir(): provider._home = path yield shutil.rmtree(path) + + +@pytest.fixture() +def fixture_cluster_policies_api(): + with mock.patch( + "databricks_cli.cluster_policies.api.PolicyService" + ) as service_mock: + service_mock.return_value = mock.MagicMock() + yield ClusterPolicyApi(None)