Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Google Memcached hooks - improve protobuf messages handling #11743

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions airflow/providers/google/cloud/hooks/cloud_memorystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
"""Hooks for Cloud Memorystore service"""
from typing import Dict, Optional, Sequence, Tuple, Union
import json

from google.api_core.exceptions import NotFound
from google.api_core import path_template
Expand All @@ -28,7 +27,6 @@
from google.cloud.redis_v1.gapic.enums import FailoverInstanceRequest
from google.cloud.redis_v1.types import FieldMask, InputConfig, Instance, OutputConfig
from google.protobuf.json_format import ParseDict
import proto

from airflow import version
from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -569,11 +567,6 @@ def _append_label(instance: cloud_memcache.Instance, key: str, val: str) -> clou
instance.labels.update({key: val})
return instance

@staticmethod
def proto_message_to_dict(message: proto.Message) -> dict:
"""Helper method to parse protobuf message to dictionary."""
return json.loads(message.__class__.to_json(message))

@GoogleBaseHook.fallback_to_default_project_id
def apply_parameters(
self,
Expand Down
6 changes: 3 additions & 3 deletions airflow/providers/google/cloud/operators/cloud_memorystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,7 +1290,7 @@ def execute(self, context: Dict):
timeout=self.timeout,
metadata=self.metadata,
)
return hook.proto_message_to_dict(result)
return cloud_memcache.Instance.to_dict(result)


class CloudMemorystoreMemcachedDeleteInstanceOperator(BaseOperator):
Expand Down Expand Up @@ -1438,7 +1438,7 @@ def execute(self, context: Dict):
timeout=self.timeout,
metadata=self.metadata,
)
return hook.proto_message_to_dict(result)
return cloud_memcache.Instance.to_dict(result)


class CloudMemorystoreMemcachedListInstancesOperator(BaseOperator):
Expand Down Expand Up @@ -1520,7 +1520,7 @@ def execute(self, context: Dict):
timeout=self.timeout,
metadata=self.metadata,
)
instances = [hook.proto_message_to_dict(a) for a in result]
instances = [cloud_memcache.Instance.to_dict(a) for a in result]
return instances


Expand Down
16 changes: 0 additions & 16 deletions tests/providers/google/cloud/hooks/test_cloud_memorystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,19 +599,3 @@ def test_update_instance(self, mock_get_conn, mock_project_id):
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)

def test_proto_functions(self):
instance_dict = {
'name': 'test_name',
'node_count': 1,
'node_config': {'cpu_count': 1, 'memory_size_mb': 1024},
}
instance = cloud_memcache.Instance(instance_dict)
instance_dict_result = self.hook.proto_message_to_dict(instance)
self.assertEqual(instance_dict_result["name"], instance_dict["name"])
self.assertEqual(
instance_dict_result["nodeConfig"]["cpuCount"], instance_dict["node_config"]["cpu_count"]
)
self.assertEqual(
instance_dict_result["nodeConfig"]["memorySizeMb"], instance_dict["node_config"]["memory_size_mb"]
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from google.api_core.retry import Retry
from google.cloud.redis_v1.gapic.enums import FailoverInstanceRequest
from google.cloud.redis_v1.types import Instance
from google.cloud.memcache_v1beta2.types import cloud_memcache

from airflow.providers.google.cloud.operators.cloud_memorystore import (
CloudMemorystoreCreateInstanceAndImportOperator,
Expand Down Expand Up @@ -386,6 +387,7 @@ def test_assert_valid_hook_call(self, mock_hook):
class TestCloudMemorystoreMemcachedCreateInstanceOperator(TestCase):
@mock.patch("airflow.providers.google.cloud.operators.cloud_memorystore.CloudMemorystoreMemcachedHook")
def test_assert_valid_hook_call(self, mock_hook):
mock_hook.return_value.create_instance.return_value = cloud_memcache.Instance()
task = CloudMemorystoreMemcachedCreateInstanceOperator(
task_id=TEST_TASK_ID,
location=TEST_LOCATION,
Expand Down Expand Up @@ -438,6 +440,7 @@ def test_assert_valid_hook_call(self, mock_hook):
class TestCloudMemorystoreMemcachedGetInstanceOperator(TestCase):
@mock.patch("airflow.providers.google.cloud.operators.cloud_memorystore.CloudMemorystoreMemcachedHook")
def test_assert_valid_hook_call(self, mock_hook):
mock_hook.return_value.get_instance.return_value = cloud_memcache.Instance()
task = CloudMemorystoreMemcachedGetInstanceOperator(
task_id=TEST_TASK_ID,
location=TEST_LOCATION,
Expand Down