From 3b9724a0ba7f45c9661831ecd4d9151ae23ce5d9 Mon Sep 17 00:00:00 2001 From: A <5249513+Dumeng@users.noreply.github.com> Date: Wed, 31 Jan 2024 02:04:56 +0800 Subject: [PATCH] feat: add support for proto3 optional tag (#727) * feat: add support for proto3 optional tag * format writer.py * Add the same changes to v1beta2 * Add system test for proto3 support * Remove v1beta2 modifications * Fix issue in the test script and reformat * fix lint * Fix typo in the test * Remove unneeded offset --------- Co-authored-by: Lingqing Gan --- google/cloud/bigquery_storage_v1/writer.py | 11 +++ tests/system/test_writer.py | 83 ++++++++++++++++++++++ 2 files changed, 94 insertions(+) diff --git a/google/cloud/bigquery_storage_v1/writer.py b/google/cloud/bigquery_storage_v1/writer.py index 1431e483..9c9a4927 100644 --- a/google/cloud/bigquery_storage_v1/writer.py +++ b/google/cloud/bigquery_storage_v1/writer.py @@ -100,6 +100,17 @@ def __init__( # The threads created in ``._open()``. self._consumer = None + # The protobuf payload will be decoded as proto2 on the server side. The schema is also + # specified as proto2. Hence we must clear proto3-only features. This works since proto2 and + # proto3 are binary-compatible. + proto_descriptor = ( + self._inital_request_template.proto_rows.writer_schema.proto_descriptor + ) + for field in proto_descriptor.field: + field.ClearField("oneof_index") + field.ClearField("proto3_optional") + proto_descriptor.ClearField("oneof_decl") + @property def is_active(self) -> bool: """bool: True if this manager is actively streaming. 
diff --git a/tests/system/test_writer.py b/tests/system/test_writer.py index 91cb11e5..a5b19b3e 100644 --- a/tests/system/test_writer.py +++ b/tests/system/test_writer.py @@ -16,6 +16,29 @@ import pytest from google.cloud.bigquery_storage_v1 import types as gapic_types +from google.cloud.bigquery_storage_v1.writer import AppendRowsStream +import uuid + + +@pytest.fixture +def table(project_id, dataset, bq_client): + from google.cloud import bigquery + + schema = [ + bigquery.SchemaField("first_name", "STRING", mode="NULLABLE"), + bigquery.SchemaField("last_name", "STRING", mode="NULLABLE"), + bigquery.SchemaField("age", "INTEGER", mode="NULLABLE"), + ] + + unique_suffix = str(uuid.uuid4()).replace("-", "_") + table_id = "users_" + unique_suffix + table_id_full = f"{project_id}.{dataset.dataset_id}.{table_id}" + bq_table = bigquery.Table(table_id_full, schema=schema) + created_table = bq_client.create_table(bq_table) + + yield created_table + + bq_client.delete_table(created_table) @pytest.fixture(scope="session") @@ -31,3 +54,63 @@ def test_append_rows_with_invalid_stream_name_fails_fast(bqstorage_write_client) with pytest.raises(exceptions.GoogleAPICallError): bqstorage_write_client.append_rows(bad_request) + + +def test_append_rows_with_proto3(bqstorage_write_client, table): + import proto + from google.protobuf import descriptor_pb2 + + # Using Proto Plus to build proto3 + # Declare proto3 field `optional` for presence + class PersonProto(proto.Message): + first_name = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + last_name = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + age = proto.Field( + proto.INT64, + number=3, + optional=True, + ) + + person_pb = PersonProto.pb() + + stream_name = f"projects/{table.project}/datasets/{table.dataset_id}/tables/{table.table_id}/_default" + request_template = gapic_types.AppendRowsRequest() + request_template.write_stream = stream_name + + proto_schema = gapic_types.ProtoSchema() + 
proto_descriptor = descriptor_pb2.DescriptorProto() + person_pb.DESCRIPTOR.CopyToProto( + proto_descriptor, + ) + proto_schema.proto_descriptor = proto_descriptor + proto_data = gapic_types.AppendRowsRequest.ProtoData() + proto_data.writer_schema = proto_schema + request_template.proto_rows = proto_data + + append_rows_stream = AppendRowsStream( + bqstorage_write_client, + request_template, + ) + + request = gapic_types.AppendRowsRequest() + proto_data = gapic_types.AppendRowsRequest.ProtoData() + proto_rows = gapic_types.ProtoRows() + row = person_pb() + row.first_name = "fn" + row.last_name = "ln" + row.age = 20 + proto_rows.serialized_rows.append(row.SerializeToString()) + proto_data.rows = proto_rows + request.proto_rows = proto_data + response_future = append_rows_stream.send(request) + + assert response_future.result() + # The request should succeed