From 76af90b1890a2db7947c00d37eabdc2ecfc88226 Mon Sep 17 00:00:00 2001 From: Xiang ZHU Date: Fri, 2 Sep 2022 23:37:02 +0200 Subject: [PATCH] feat/#554/accept-dict-cluster-policy-definition --- databricks_cli/cluster_policies/api.py | 32 +++++++++-- tests/cluster_policies/test_api.py | 75 ++++++++++++++++++++++++++ tests/conftest.py | 12 +++++ 3 files changed, 115 insertions(+), 4 deletions(-) create mode 100644 tests/cluster_policies/test_api.py diff --git a/databricks_cli/cluster_policies/api.py b/databricks_cli/cluster_policies/api.py index ad87c77e..bdb06971 100644 --- a/databricks_cli/cluster_policies/api.py +++ b/databricks_cli/cluster_policies/api.py @@ -20,6 +20,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json + from databricks_cli.sdk import PolicyService @@ -27,11 +29,33 @@ class ClusterPolicyApi(object): def __init__(self, api_client): self.client = PolicyService(api_client) - def create_cluster_policy(self, json): - return self.client.client.perform_query('POST', '/policies/clusters/create', data=json) + @staticmethod + def convert_to_json_string(data): + if isinstance(data, str): + return data + if isinstance(data, dict): + return json.dumps(data) + raise TypeError( + "Only str or dict type are accepted, but got: {}.".format(type(data)) + ) + + @staticmethod + def format_policy_for_api(policy): + policy["definition"] = ClusterPolicyApi.convert_to_json_string( + policy["definition"] + ) + + def create_cluster_policy(self, policy): + ClusterPolicyApi.format_policy_for_api(policy) + return self.client.client.perform_query( + "POST", "/policies/clusters/create", data=policy + ) - def edit_cluster_policy(self, json): - return self.client.client.perform_query('POST', '/policies/clusters/edit', data=json) + def edit_cluster_policy(self, policy): + ClusterPolicyApi.format_policy_for_api(policy) + return self.client.client.perform_query( + "POST", "/policies/clusters/edit", data=policy + ) def delete_cluster_policy(self, policy_id): return self.client.delete_policy(policy_id) diff --git a/tests/cluster_policies/test_api.py b/tests/cluster_policies/test_api.py new file mode 100644 index 00000000..cb4bf5f0 --- /dev/null +++ b/tests/cluster_policies/test_api.py @@ -0,0 +1,75 @@ +# Databricks CLI +# Copyright 2017 Databricks, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"), except +# that the use of services to which certain application programming +# interfaces (each, an "API") connect requires that the user first obtain +# a license for the use of the APIs from Databricks, Inc. ("Databricks"), +# by creating an account at www.databricks.com and agreeing to either (a) +# the Community Edition Terms of Service, (b) the Databricks Terms of +# Service, or (c) another written agreement between Licensee and Databricks +# for the use of the APIs. +# +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +import pytest + +from databricks_cli.cluster_policies.api import ClusterPolicyApi + + +@pytest.mark.parametrize( + "data, expected", + [ + ("foo", "foo"), + ({"foo": "bar"}, '{"foo": "bar"}'), + (1, None), + ], +) +def test_convert_to_json_string(data, expected): + if expected: + assert expected == ClusterPolicyApi.convert_to_json_string(data) + else: + with pytest.raises(TypeError): + ClusterPolicyApi.convert_to_json_string(data) + + +@mock.patch( + "databricks_cli.cluster_policies.api.ClusterPolicyApi.convert_to_json_string" +) +def test_format_policy_for_api(mock_convert_to_json_string): + policy_definition = {"1": "2"} + policy = {"foo": "bar", "definition": policy_definition} + ClusterPolicyApi.format_policy_for_api(policy) + assert policy["definition"] == mock_convert_to_json_string(policy_definition) + assert policy["foo"] == "bar" + + +@pytest.mark.parametrize( + "fct_name, method, action", + [ + ("create_cluster_policy", "POST", "create"), + ("edit_cluster_policy", "POST", "edit"), + ], +) +@mock.patch( + "databricks_cli.cluster_policies.api.ClusterPolicyApi.format_policy_for_api" +) +def test_create_cluster_policy( + mock_format_policy_for_api, fct_name, method, action, fixture_cluster_policies_api +): + mock_policy = mock.Mock() + getattr(fixture_cluster_policies_api, fct_name)(mock_policy) + mock_format_policy_for_api.assert_called_once_with(mock_policy) + fixture_cluster_policies_api.client.client.perform_query.assert_called_once_with( + method, "/policies/clusters/{}".format(action), data=mock_policy + ) diff --git a/tests/conftest.py b/tests/conftest.py index 6d7fe245..523bc2e3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,9 +22,12 @@ # limitations under the License. import shutil import tempfile + +import mock import pytest import databricks_cli.configure.provider as provider +from databricks_cli.cluster_policies.api import ClusterPolicyApi @pytest.fixture(autouse=True) @@ -33,3 +36,12 @@ def mock_conf_dir(): provider._home = path yield shutil.rmtree(path) + + +@pytest.fixture() +def fixture_cluster_policies_api(): + with mock.patch( + "databricks_cli.cluster_policies.api.PolicyService" + ) as service_mock: + service_mock.return_value = mock.MagicMock() + yield ClusterPolicyApi(None)