Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Issue#22846] allow option to encode or not encode UUID when uploading from Cassandra to GCS #23766

Merged
merged 5 commits into from
May 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 24 additions & 23 deletions airflow/providers/google/cloud/transfers/cassandra_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class CassandraToGCSOperator(BaseOperator):
:param query_timeout: (Optional) The amount of time, in seconds, used to execute the Cassandra query.
If not set, the timeout value will be set in Session.execute() by Cassandra driver.
If set to None, there is no timeout.
:param encode_uuid: (Optional) Option to encode UUID or not when uploading from Cassandra to GCS.
Default is to encode UUID.
"""

template_fields: Sequence[str] = (
Expand All @@ -105,6 +107,7 @@ def __init__(
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
query_timeout: Union[float, None, NotSetType] = NOT_SET,
encode_uuid: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -120,6 +123,7 @@ def __init__(
self.gzip = gzip
self.impersonation_chain = impersonation_chain
self.query_timeout = query_timeout
self.encode_uuid = encode_uuid

# Default Cassandra to BigQuery type mapping
CQL_TYPE_MAP = {
Expand Down Expand Up @@ -256,13 +260,11 @@ def _upload_to_gcs(self, file_to_upload):
gzip=self.gzip,
)

@classmethod
def generate_data_dict(cls, names: Iterable[str], values: Any) -> Dict[str, Any]:
def generate_data_dict(self, names: Iterable[str], values: Any) -> Dict[str, Any]:
"""Generates data structure that will be stored as file in GCS."""
return {n: cls.convert_value(v) for n, v in zip(names, values)}
return {n: self.convert_value(v) for n, v in zip(names, values)}

@classmethod
def convert_value(cls, value: Optional[Any]) -> Optional[Any]:
def convert_value(self, value: Optional[Any]) -> Optional[Any]:
"""Convert value to BQ type."""
if not value:
return value
Expand All @@ -271,59 +273,58 @@ def convert_value(cls, value: Optional[Any]) -> Optional[Any]:
elif isinstance(value, bytes):
return b64encode(value).decode('ascii')
elif isinstance(value, UUID):
return b64encode(value.bytes).decode('ascii')
if self.encode_uuid:
return b64encode(value.bytes).decode('ascii')
else:
return str(value)
elif isinstance(value, (datetime, Date)):
return str(value)
elif isinstance(value, Decimal):
return float(value)
elif isinstance(value, Time):
return str(value).split('.')[0]
elif isinstance(value, (list, SortedSet)):
return cls.convert_array_types(value)
return self.convert_array_types(value)
elif hasattr(value, '_fields'):
return cls.convert_user_type(value)
return self.convert_user_type(value)
elif isinstance(value, tuple):
return cls.convert_tuple_type(value)
return self.convert_tuple_type(value)
elif isinstance(value, OrderedMapSerializedKey):
return cls.convert_map_type(value)
return self.convert_map_type(value)
else:
raise AirflowException('Unexpected value: ' + str(value))

@classmethod
def convert_array_types(cls, value: Union[List[Any], SortedSet]) -> List[Any]:
def convert_array_types(self, value: Union[List[Any], SortedSet]) -> List[Any]:
"""Maps convert_value over array."""
return [cls.convert_value(nested_value) for nested_value in value]
return [self.convert_value(nested_value) for nested_value in value]

@classmethod
def convert_user_type(cls, value: Any) -> Dict[str, Any]:
def convert_user_type(self, value: Any) -> Dict[str, Any]:
"""
Converts a user type to RECORD that contains n fields, where n is the
number of attributes. Each element in the user type class will be converted to its
corresponding data type in BQ.
"""
names = value._fields
values = [cls.convert_value(getattr(value, name)) for name in names]
return cls.generate_data_dict(names, values)
values = [self.convert_value(getattr(value, name)) for name in names]
return self.generate_data_dict(names, values)

@classmethod
def convert_tuple_type(cls, values: Tuple[Any]) -> Dict[str, Any]:
def convert_tuple_type(self, values: Tuple[Any]) -> Dict[str, Any]:
"""
Converts a tuple to RECORD that contains n fields, each will be converted
to its corresponding data type in BQ and will be named 'field_<index>', where
index is determined by the order of the tuple elements defined in cassandra.
"""
names = ['field_' + str(i) for i in range(len(values))]
return cls.generate_data_dict(names, values)
return self.generate_data_dict(names, values)

@classmethod
def convert_map_type(cls, value: OrderedMapSerializedKey) -> List[Dict[str, Any]]:
def convert_map_type(self, value: OrderedMapSerializedKey) -> List[Dict[str, Any]]:
"""
Converts a map to a repeated RECORD that contains two fields: 'key' and 'value',
each will be converted to its corresponding data type in BQ.
"""
converted_map = []
for k, v in zip(value.keys(), value.values()):
converted_map.append({'key': cls.convert_value(k), 'value': cls.convert_value(v)})
converted_map.append({'key': self.convert_value(k), 'value': self.convert_value(v)})
return converted_map

@classmethod
Expand Down
22 changes: 16 additions & 6 deletions tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,28 @@
from airflow.providers.google.cloud.transfers.cassandra_to_gcs import CassandraToGCSOperator

TMP_FILE_NAME = "temp-file"
TEST_BUCKET = "test-bucket"
SCHEMA = "schema.json"
FILENAME = "data.json"
CQL = "select * from keyspace1.table1"
TASK_ID = "test-cas-to-gcs"


class TestCassandraToGCS(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.transfers.cassandra_to_gcs.NamedTemporaryFile")
@mock.patch("airflow.providers.google.cloud.transfers.cassandra_to_gcs.GCSHook.upload")
@mock.patch("airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraHook")
def test_execute(self, mock_hook, mock_upload, mock_tempfile):
test_bucket = "test-bucket"
schema = "schema.json"
filename = "data.json"
test_bucket = TEST_BUCKET
schema = SCHEMA
filename = FILENAME
gzip = True
query_timeout = 20
mock_tempfile.return_value.name = TMP_FILE_NAME

operator = CassandraToGCSOperator(
task_id="test-cas-to-gcs",
cql="select * from keyspace1.table1",
task_id=TASK_ID,
cql=CQL,
bucket=test_bucket,
filename=filename,
schema_filename=schema,
Expand Down Expand Up @@ -70,7 +75,10 @@ def test_execute(self, mock_hook, mock_upload, mock_tempfile):
mock_upload.assert_has_calls([call_schema, call_data], any_order=True)

def test_convert_value(self):
op = CassandraToGCSOperator
op = CassandraToGCSOperator(task_id=TASK_ID, bucket=TEST_BUCKET, cql=CQL, filename=FILENAME)
unencoded_uuid_op = CassandraToGCSOperator(
task_id=TASK_ID, bucket=TEST_BUCKET, cql=CQL, filename=FILENAME, encode_uuid=False
)
assert op.convert_value(None) is None
assert op.convert_value(1) == 1
assert op.convert_value(1.0) == 1.0
Expand All @@ -95,6 +103,8 @@ def test_convert_value(self):
test_uuid = uuid.uuid4()
encoded_uuid = b64encode(test_uuid.bytes).decode("ascii")
assert op.convert_value(test_uuid) == encoded_uuid
unencoded_uuid = str(test_uuid)
assert unencoded_uuid_op.convert_value(test_uuid) == unencoded_uuid

byte_str = b"abc"
encoded_b = b64encode(byte_str).decode("ascii")
Expand Down