Skip to content

Commit

Permalink
feat/#554/accept-dict-cluster-policy-definition
Browse files Browse the repository at this point in the history
  • Loading branch information
copdips committed Sep 2, 2022
1 parent 64c370f commit 81f0783
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 4 deletions.
32 changes: 28 additions & 4 deletions databricks_cli/cluster_policies/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,42 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json

from databricks_cli.sdk import PolicyService


class ClusterPolicyApi(object):
def __init__(self, api_client):
self.client = PolicyService(api_client)

def create_cluster_policy(self, json):
return self.client.client.perform_query('POST', '/policies/clusters/create', data=json)
@staticmethod
def convert_to_json_string(data):
if isinstance(data, str):
return data
if isinstance(data, dict):
return json.dumps(data)
raise TypeError(
"Only str or dict type are accepted, but got: {}.".format(type(data))
)

@staticmethod
def format_policy_for_api(policy):
policy["definition"] = ClusterPolicyApi.convert_to_json_string(
policy["definition"]
)

def create_cluster_policy(self, policy):
ClusterPolicyApi.format_policy_for_api(policy)
return self.client.client.perform_query(
"POST", "/policies/clusters/create", data=policy
)

def edit_cluster_policy(self, json):
return self.client.client.perform_query('POST', '/policies/clusters/edit', data=json)
def edit_cluster_policy(self, policy):
ClusterPolicyApi.format_policy_for_api(policy)
return self.client.client.perform_query(
"POST", "/policies/clusters/edit", data=policy
)

def delete_cluster_policy(self, policy_id):
return self.client.delete_policy(policy_id)
Expand Down
75 changes: 75 additions & 0 deletions tests/cluster_policies/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Databricks CLI
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"), except
# that the use of services to which certain application programming
# interfaces (each, an "API") connect requires that the user first obtain
# a license for the use of the APIs from Databricks, Inc. ("Databricks"),
# by creating an account at www.databricks.com and agreeing to either (a)
# the Community Edition Terms of Service, (b) the Databricks Terms of
# Service, or (c) another written agreement between Licensee and Databricks
# for the use of the APIs.
#
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mock
import pytest

from databricks_cli.cluster_policies.api import ClusterPolicyApi


@pytest.mark.parametrize(
"data, expected",
[
("foo", "foo"),
({"foo": "bar"}, '{"foo": "bar"}'),
(1, None),
],
)
def test_convert_to_json_string(data, expected):
if expected:
assert expected == ClusterPolicyApi.convert_to_json_string(data)
else:
with pytest.raises(TypeError):
ClusterPolicyApi.convert_to_json_string(data)


@mock.patch(
"databricks_cli.cluster_policies.api.ClusterPolicyApi.convert_to_json_string"
)
def test_format_policy_for_api(mock_convert_to_json_string):
policy_definition = {"1": "2"}
policy = {"foo": "bar", "definition": policy_definition}
ClusterPolicyApi.format_policy_for_api(policy)
assert policy["definition"] == mock_convert_to_json_string(policy_definition)
assert policy["foo"] == "bar"


@pytest.mark.parametrize(
"fct_name, method, action",
[
("create_cluster_policy", "POST", "create"),
("edit_cluster_policy", "POST", "edit"),
],
)
@mock.patch(
"databricks_cli.cluster_policies.api.ClusterPolicyApi.format_policy_for_api"
)
def test_create_and_edit_cluster_policy(
mock_format_policy_for_api, fct_name, method, action, fixture_cluster_policies_api
):
mock_policy = mock.Mock()
getattr(fixture_cluster_policies_api, fct_name)(mock_policy)
mock_format_policy_for_api.assert_called_once_with(mock_policy)
fixture_cluster_policies_api.client.client.perform_query.assert_called_once_with(
method, "/policies/clusters/{}".format(action), data=mock_policy
)
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
# limitations under the License.
import shutil
import tempfile

import mock
import pytest

import databricks_cli.configure.provider as provider
from databricks_cli.cluster_policies.api import ClusterPolicyApi


@pytest.fixture(autouse=True)
Expand All @@ -33,3 +36,12 @@ def mock_conf_dir():
provider._home = path
yield
shutil.rmtree(path)


@pytest.fixture()
def fixture_cluster_policies_api():
with mock.patch(
"databricks_cli.cluster_policies.api.PolicyService"
) as service_mock:
service_mock.return_value = mock.MagicMock()
yield ClusterPolicyApi(None)

0 comments on commit 81f0783

Please sign in to comment.