Skip to content

Commit

Permalink
feat/databricks#554/accept-dict-cluster-policy-definition
Browse files Browse the repository at this point in the history
  • Loading branch information
copdips committed Oct 6, 2022
1 parent 81c26a8 commit 8b6c913
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 2 deletions.
20 changes: 18 additions & 2 deletions databricks_cli/cluster_policies/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json

from databricks_cli.sdk import PolicyService


class ClusterPolicyApi(object):
def __init__(self, api_client):
self.client = PolicyService(api_client)

@staticmethod
def format_policy_for_api(policy):
if isinstance(policy.get("definition"), dict):
policy["definition"] = json.dumps(policy["definition"])
return policy

def create_cluster_policy(self, json):
return self.client.client.perform_query('POST', '/policies/clusters/create', data=json)
return self.client.client.perform_query(
"POST",
"/policies/clusters/create",
data=ClusterPolicyApi.format_policy_for_api(json),
)

def edit_cluster_policy(self, json):
return self.client.client.perform_query('POST', '/policies/clusters/edit', data=json)
return self.client.client.perform_query(
"POST",
"/policies/clusters/edit",
data=ClusterPolicyApi.format_policy_for_api(json),
)

def delete_cluster_policy(self, policy_id):
return self.client.delete_policy(policy_id)
Expand Down
62 changes: 62 additions & 0 deletions tests/cluster_policies/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Databricks CLI
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"), except
# that the use of services to which certain application programming
# interfaces (each, an "API") connect requires that the user first obtain
# a license for the use of the APIs from Databricks, Inc. ("Databricks"),
# by creating an account at www.databricks.com and agreeing to either (a)
# the Community Edition Terms of Service, (b) the Databricks Terms of
# Service, or (c) another written agreement between Licensee and Databricks
# for the use of the APIs.
#
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mock
import pytest

from databricks_cli.cluster_policies.api import ClusterPolicyApi


@pytest.mark.parametrize(
"policy, expected",
[
({"definition": "foo"}, {"definition": "foo"}),
({"definition": {"foo": "bar"}}, {"definition": '{"foo": "bar"}'}),
],
)
def test_format_policy_for_api(policy, expected):
result = ClusterPolicyApi.format_policy_for_api(policy)
assert result == expected


@pytest.mark.parametrize(
"fct_name, method, action",
[
("create_cluster_policy", "POST", "create"),
("edit_cluster_policy", "POST", "edit"),
],
)
@mock.patch(
"databricks_cli.cluster_policies.api.ClusterPolicyApi.format_policy_for_api"
)
def test_create_and_edit_cluster_policy(
mock_format_policy_for_api, fct_name, method, action, fixture_cluster_policies_api
):
mock_policy = mock.Mock()
getattr(fixture_cluster_policies_api, fct_name)(mock_policy)
mock_format_policy_for_api.assert_called_once_with(mock_policy)
fixture_cluster_policies_api.client.client.perform_query.assert_called_once_with(
method,
"/policies/clusters/{}".format(action),
data=mock_format_policy_for_api.return_value,
)
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
# limitations under the License.
import shutil
import tempfile

import mock
import pytest

import databricks_cli.configure.provider as provider
from databricks_cli.cluster_policies.api import ClusterPolicyApi


@pytest.fixture(autouse=True)
Expand All @@ -33,3 +36,12 @@ def mock_conf_dir():
provider._home = path
yield
shutil.rmtree(path)


@pytest.fixture()
def fixture_cluster_policies_api():
with mock.patch(
"databricks_cli.cluster_policies.api.PolicyService"
) as service_mock:
service_mock.return_value = mock.MagicMock()
yield ClusterPolicyApi(None)

0 comments on commit 8b6c913

Please sign in to comment.