Skip to content

Commit

Permalink
[Issue#22846] allow option to encode or not encode UUID when uploadin…
Browse files Browse the repository at this point in the history
…g from Cassandra to GCS (#23766)
  • Loading branch information
fuxiao224 authored May 20, 2022
1 parent 8494fc7 commit 5bfacf8
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 29 deletions.
47 changes: 24 additions & 23 deletions airflow/providers/google/cloud/transfers/cassandra_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class CassandraToGCSOperator(BaseOperator):
:param query_timeout: (Optional) The amount of time, in seconds, used to execute the Cassandra query.
If not set, the timeout value will be set in Session.execute() by Cassandra driver.
If set to None, there is no timeout.
:param encode_uuid: (Optional) Option to encode UUID or not when upload from Cassandra to GCS.
Default is to encode UUID.
"""

template_fields: Sequence[str] = (
Expand All @@ -105,6 +107,7 @@ def __init__(
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
query_timeout: Union[float, None, NotSetType] = NOT_SET,
encode_uuid: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -120,6 +123,7 @@ def __init__(
self.gzip = gzip
self.impersonation_chain = impersonation_chain
self.query_timeout = query_timeout
self.encode_uuid = encode_uuid

# Default Cassandra to BigQuery type mapping
CQL_TYPE_MAP = {
Expand Down Expand Up @@ -256,13 +260,11 @@ def _upload_to_gcs(self, file_to_upload):
gzip=self.gzip,
)

@classmethod
def generate_data_dict(cls, names: Iterable[str], values: Any) -> Dict[str, Any]:
def generate_data_dict(self, names: Iterable[str], values: Any) -> Dict[str, Any]:
"""Generates data structure that will be stored as file in GCS."""
return {n: cls.convert_value(v) for n, v in zip(names, values)}
return {n: self.convert_value(v) for n, v in zip(names, values)}

@classmethod
def convert_value(cls, value: Optional[Any]) -> Optional[Any]:
def convert_value(self, value: Optional[Any]) -> Optional[Any]:
"""Convert value to BQ type."""
if not value:
return value
Expand All @@ -271,59 +273,58 @@ def convert_value(cls, value: Optional[Any]) -> Optional[Any]:
elif isinstance(value, bytes):
return b64encode(value).decode('ascii')
elif isinstance(value, UUID):
return b64encode(value.bytes).decode('ascii')
if self.encode_uuid:
return b64encode(value.bytes).decode('ascii')
else:
return str(value)
elif isinstance(value, (datetime, Date)):
return str(value)
elif isinstance(value, Decimal):
return float(value)
elif isinstance(value, Time):
return str(value).split('.')[0]
elif isinstance(value, (list, SortedSet)):
return cls.convert_array_types(value)
return self.convert_array_types(value)
elif hasattr(value, '_fields'):
return cls.convert_user_type(value)
return self.convert_user_type(value)
elif isinstance(value, tuple):
return cls.convert_tuple_type(value)
return self.convert_tuple_type(value)
elif isinstance(value, OrderedMapSerializedKey):
return cls.convert_map_type(value)
return self.convert_map_type(value)
else:
raise AirflowException('Unexpected value: ' + str(value))

@classmethod
def convert_array_types(cls, value: Union[List[Any], SortedSet]) -> List[Any]:
def convert_array_types(self, value: Union[List[Any], SortedSet]) -> List[Any]:
"""Maps convert_value over array."""
return [cls.convert_value(nested_value) for nested_value in value]
return [self.convert_value(nested_value) for nested_value in value]

@classmethod
def convert_user_type(cls, value: Any) -> Dict[str, Any]:
def convert_user_type(self, value: Any) -> Dict[str, Any]:
"""
Converts a user type to RECORD that contains n fields, where n is the
number of attributes. Each element in the user type class will be converted to its
corresponding data type in BQ.
"""
names = value._fields
values = [cls.convert_value(getattr(value, name)) for name in names]
return cls.generate_data_dict(names, values)
values = [self.convert_value(getattr(value, name)) for name in names]
return self.generate_data_dict(names, values)

@classmethod
def convert_tuple_type(cls, values: Tuple[Any]) -> Dict[str, Any]:
def convert_tuple_type(self, values: Tuple[Any]) -> Dict[str, Any]:
"""
Converts a tuple to RECORD that contains n fields, each will be converted
to its corresponding data type in bq and will be named 'field_<index>', where
index is determined by the order of the tuple elements defined in cassandra.
"""
names = ['field_' + str(i) for i in range(len(values))]
return cls.generate_data_dict(names, values)
return self.generate_data_dict(names, values)

@classmethod
def convert_map_type(cls, value: OrderedMapSerializedKey) -> List[Dict[str, Any]]:
def convert_map_type(self, value: OrderedMapSerializedKey) -> List[Dict[str, Any]]:
"""
Converts a map to a repeated RECORD that contains two fields: 'key' and 'value',
each will be converted to its corresponding data type in BQ.
"""
converted_map = []
for k, v in zip(value.keys(), value.values()):
converted_map.append({'key': cls.convert_value(k), 'value': cls.convert_value(v)})
converted_map.append({'key': self.convert_value(k), 'value': self.convert_value(v)})
return converted_map

@classmethod
Expand Down
22 changes: 16 additions & 6 deletions tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,28 @@
from airflow.providers.google.cloud.transfers.cassandra_to_gcs import CassandraToGCSOperator

TMP_FILE_NAME = "temp-file"
TEST_BUCKET = "test-bucket"
SCHEMA = "schema.json"
FILENAME = "data.json"
CQL = "select * from keyspace1.table1"
TASK_ID = "test-cas-to-gcs"


class TestCassandraToGCS(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.transfers.cassandra_to_gcs.NamedTemporaryFile")
@mock.patch("airflow.providers.google.cloud.transfers.cassandra_to_gcs.GCSHook.upload")
@mock.patch("airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraHook")
def test_execute(self, mock_hook, mock_upload, mock_tempfile):
test_bucket = "test-bucket"
schema = "schema.json"
filename = "data.json"
test_bucket = TEST_BUCKET
schema = SCHEMA
filename = FILENAME
gzip = True
query_timeout = 20
mock_tempfile.return_value.name = TMP_FILE_NAME

operator = CassandraToGCSOperator(
task_id="test-cas-to-gcs",
cql="select * from keyspace1.table1",
task_id=TASK_ID,
cql=CQL,
bucket=test_bucket,
filename=filename,
schema_filename=schema,
Expand Down Expand Up @@ -70,7 +75,10 @@ def test_execute(self, mock_hook, mock_upload, mock_tempfile):
mock_upload.assert_has_calls([call_schema, call_data], any_order=True)

def test_convert_value(self):
op = CassandraToGCSOperator
op = CassandraToGCSOperator(task_id=TASK_ID, bucket=TEST_BUCKET, cql=CQL, filename=FILENAME)
unencoded_uuid_op = CassandraToGCSOperator(
task_id=TASK_ID, bucket=TEST_BUCKET, cql=CQL, filename=FILENAME, encode_uuid=False
)
assert op.convert_value(None) is None
assert op.convert_value(1) == 1
assert op.convert_value(1.0) == 1.0
Expand All @@ -95,6 +103,8 @@ def test_convert_value(self):
test_uuid = uuid.uuid4()
encoded_uuid = b64encode(test_uuid.bytes).decode("ascii")
assert op.convert_value(test_uuid) == encoded_uuid
unencoded_uuid = str(test_uuid)
assert unencoded_uuid_op.convert_value(test_uuid) == unencoded_uuid

byte_str = b"abc"
encoded_b = b64encode(byte_str).decode("ascii")
Expand Down

0 comments on commit 5bfacf8

Please sign in to comment.