diff --git a/build.gradle.kts b/build.gradle.kts index b70d8591bf83..6ebbb44ba949 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -450,16 +450,7 @@ tasks.register("pythonFormatterPreCommit") { tasks.register("python37PostCommit") { dependsOn(":sdks:python:test-suites:dataflow:py37:postCommitIT") - dependsOn(":sdks:python:test-suites:direct:py37:postCommitIT") - dependsOn(":sdks:python:test-suites:direct:py37:directRunnerIT") - dependsOn(":sdks:python:test-suites:direct:py37:hdfsIntegrationTest") - dependsOn(":sdks:python:test-suites:direct:py37:azureIntegrationTest") - dependsOn(":sdks:python:test-suites:direct:py37:mongodbioIT") dependsOn(":sdks:python:test-suites:portable:py37:postCommitPy37") - dependsOn(":sdks:python:test-suites:dataflow:py37:spannerioIT") - dependsOn(":sdks:python:test-suites:direct:py37:spannerioIT") - dependsOn(":sdks:python:test-suites:portable:py37:xlangSpannerIOIT") - dependsOn(":sdks:python:test-suites:direct:py37:inferencePostCommitIT") } tasks.register("python38PostCommit") { @@ -483,8 +474,6 @@ tasks.register("python39PostCommit") { tasks.register("python310PostCommit") { dependsOn(":sdks:python:test-suites:dataflow:py310:postCommitIT") - dependsOn(":sdks:python:test-suites:direct:py310:postCommitIT") - dependsOn(":sdks:python:test-suites:direct:py310:hdfsIntegrationTest") dependsOn(":sdks:python:test-suites:portable:py310:postCommitPy310") } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java index 6028d8b9016e..21f368c78343 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProto.java @@ -118,7 +118,7 @@ public class BeamRowToStorageApiProto { .put( SqlTypes.DATETIME.getIdentifier(), (logicalType, value) -> - CivilTimeEncoder.encodePacked64DatetimeSeconds((LocalDateTime) value)) + CivilTimeEncoder.encodePacked64DatetimeMicros((LocalDateTime) value)) .put( SqlTypes.TIMESTAMP.getIdentifier(), (logicalType, value) -> (ChronoUnit.MICROS.between(Instant.EPOCH, (Instant) value))) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java index f0caa958df94..e76fedff328a 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryUtils.java @@ -66,6 +66,7 @@ import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; @@ -385,7 +386,7 @@ private static List toTableFieldSchema(Schema schema) { FieldType type = schemaField.getType(); TableFieldSchema field = new TableFieldSchema().setName(schemaField.getName()); - if (schemaField.getDescription() != null && !"".equals(schemaField.getDescription())) { + if (!Strings.isNullOrEmpty(schemaField.getDescription())) { field.setDescription(schemaField.getDescription()); } @@ -512,7 +513,7 @@ public static TableRow convertGenericRecordToTableRow( return BigQueryAvroUtils.convertGenericRecordToTableRow(record, tableSchema); } - /** Convert a BigQuery TableRow to a Beam Row. */ + /** Convert a Beam Row to a BigQuery TableRow. */ public static TableRow toTableRow(Row row) { TableRow output = new TableRow(); for (int i = 0; i < row.getFieldCount(); i++) { @@ -686,6 +687,14 @@ public static Row toBeamRow(Schema rowSchema, TableSchema bqSchema, TableRow jso if (JSON_VALUE_PARSERS.containsKey(fieldType.getTypeName())) { return JSON_VALUE_PARSERS.get(fieldType.getTypeName()).apply(jsonBQString); } else if (fieldType.isLogicalType(SqlTypes.DATETIME.getIdentifier())) { + // Handle if datetime value is in micros + try { + Long value = Long.parseLong(jsonBQString); + return CivilTimeEncoder.decodePacked64DatetimeMicrosAsJavaTime(value); + } catch (NumberFormatException e) { + // This means value is not represented by a number, so we swallow and handle it as a + // String + } return LocalDateTime.parse(jsonBQString, BIGQUERY_DATETIME_FORMATTER); } else if (fieldType.isLogicalType(SqlTypes.DATE.getIdentifier())) { return LocalDate.parse(jsonBQString); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java index 5f7851bba519..205080fba895 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProvider.java @@ -22,6 +22,7 @@ import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -51,7 +52,7 @@ import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; -import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; @@ -74,7 +75,8 @@ public class BigQueryStorageWriteApiSchemaTransformProvider extends TypedSchemaTransformProvider { private static final Duration DEFAULT_TRIGGERING_FREQUENCY = Duration.standardSeconds(5); private static final String INPUT_ROWS_TAG = "input"; - private static final String OUTPUT_ERRORS_TAG = "errors"; + private static final String FAILED_ROWS_TAG = "failed_rows"; + private static final String FAILED_ROWS_WITH_ERRORS_TAG = "failed_rows_with_errors"; @Override protected Class configurationClass() { @@ -99,7 +101,7 @@ public List inputCollectionNames() { @Override public List outputCollectionNames() { - return Collections.singletonList(OUTPUT_ERRORS_TAG); + return Arrays.asList(FAILED_ROWS_TAG, FAILED_ROWS_WITH_ERRORS_TAG); } /** Configuration for writing to BigQuery with Storage Write API. */ @@ -130,17 +132,19 @@ public void validate() { // validate create and write dispositions if (!Strings.isNullOrEmpty(this.getCreateDisposition())) { - checkArgument( - CREATE_DISPOSITIONS.get(this.getCreateDisposition().toUpperCase()) != null, + checkNotNull( + CREATE_DISPOSITIONS.get(this.getCreateDisposition().toUpperCase()), invalidConfigMessage - + "Invalid create disposition was specified. Available dispositions are: ", + + "Invalid create disposition (%s) was specified. Available dispositions are: %s", + this.getCreateDisposition(), CREATE_DISPOSITIONS.keySet()); } if (!Strings.isNullOrEmpty(this.getWriteDisposition())) { checkNotNull( WRITE_DISPOSITIONS.get(this.getWriteDisposition().toUpperCase()), invalidConfigMessage - + "Invalid write disposition was specified. Available dispositions are: ", + + "Invalid write disposition (%s) was specified. Available dispositions are: %s", + this.getWriteDisposition(), WRITE_DISPOSITIONS.keySet()); } } @@ -229,7 +233,6 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { PCollection inputRows = input.get(INPUT_ROWS_TAG); BigQueryIO.Write write = createStorageWriteApiTransform(); - if (inputRows.isBounded() == IsBounded.UNBOUNDED) { Long triggeringFrequency = configuration.getTriggeringFrequencySeconds(); write = @@ -240,30 +243,45 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { ? DEFAULT_TRIGGERING_FREQUENCY : Duration.standardSeconds(triggeringFrequency)); } - WriteResult result = inputRows.apply(write); + Schema rowSchema = inputRows.getSchema(); Schema errorSchema = Schema.of( - Field.of("failed_row", FieldType.STRING), + Field.of("failed_row", FieldType.row(rowSchema)), Field.of("error_message", FieldType.STRING)); - // Errors consisting of failed rows along with their error message - PCollection errorRows = + // Failed rows + PCollection failedRows = + result + .getFailedStorageApiInserts() + .apply( + "Construct Failed Rows", + MapElements.into(TypeDescriptors.rows()) + .via( + (storageError) -> + BigQueryUtils.toBeamRow(rowSchema, storageError.getRow()))) + .setRowSchema(rowSchema); + + // Failed rows along with their corresponding error messages + PCollection failedRowsWithErrors = result .getFailedStorageApiInserts() .apply( - "Extract Errors", - MapElements.into(TypeDescriptor.of(Row.class)) + "Construct Failed Rows and Errors", + MapElements.into(TypeDescriptors.rows()) .via( (storageError) -> Row.withSchema(errorSchema) .withFieldValue("error_message", storageError.getErrorMessage()) - .withFieldValue("failed_row", storageError.getRow().toString()) + .withFieldValue( + "failed_row", + BigQueryUtils.toBeamRow(rowSchema, storageError.getRow())) .build())) .setRowSchema(errorSchema); - return PCollectionRowTuple.of(OUTPUT_ERRORS_TAG, errorRows); + return PCollectionRowTuple.of(FAILED_ROWS_TAG, failedRows) + .and(FAILED_ROWS_WITH_ERRORS_TAG, failedRowsWithErrors); } BigQueryIO.Write createStorageWriteApiTransform() { @@ -283,13 +301,13 @@ BigQueryIO.Write createStorageWriteApiTransform() { if (!Strings.isNullOrEmpty(configuration.getCreateDisposition())) { CreateDisposition createDisposition = BigQueryStorageWriteApiSchemaTransformConfiguration.CREATE_DISPOSITIONS.get( - configuration.getCreateDisposition()); + configuration.getCreateDisposition().toUpperCase()); write = write.withCreateDisposition(createDisposition); } if (!Strings.isNullOrEmpty(configuration.getWriteDisposition())) { WriteDisposition writeDisposition = BigQueryStorageWriteApiSchemaTransformConfiguration.WRITE_DISPOSITIONS.get( - configuration.getWriteDisposition()); + configuration.getWriteDisposition().toUpperCase()); write = write.withWriteDisposition(writeDisposition); } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java index ca82dc9dae6b..c8b8a3cb6cb1 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BeamRowToStorageApiProtoTest.java @@ -258,7 +258,7 @@ public class BeamRowToStorageApiProtoTest { BASE_ROW.getLogicalTypeValue("sqlTimeValue", LocalTime.class))) .put( "sqldatetimevalue", - CivilTimeEncoder.encodePacked64DatetimeSeconds( + CivilTimeEncoder.encodePacked64DatetimeMicros( BASE_ROW.getLogicalTypeValue("sqlDatetimeValue", LocalDateTime.class))) .put( "sqltimestampvalue", diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java index c8e733c8458f..af0e3c243eab 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/providers/BigQueryStorageWriteApiSchemaTransformProviderTest.java @@ -17,13 +17,17 @@ */ package org.apache.beam.sdk.io.gcp.bigquery.providers; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import com.google.api.services.bigquery.model.TableRow; +import java.io.Serializable; import java.time.LocalDateTime; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.function.Function; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers; import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiPCollectionRowTupleTransform; import org.apache.beam.sdk.io.gcp.bigquery.providers.BigQueryStorageWriteApiSchemaTransformProvider.BigQueryStorageWriteApiSchemaTransformConfiguration; @@ -34,6 +38,7 @@ import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; +import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.PCollection; @@ -70,12 +75,12 @@ public class BigQueryStorageWriteApiSchemaTransformProviderTest { Row.withSchema(SCHEMA) .withFieldValue("name", "b") .withFieldValue("number", 2L) - .withFieldValue("dt", LocalDateTime.parse("2000-01-02T00:00:00")) + .withFieldValue("dt", LocalDateTime.parse("2000-01-02T00:00:00.123")) .build(), Row.withSchema(SCHEMA) .withFieldValue("name", "c") .withFieldValue("number", 3L) - .withFieldValue("dt", LocalDateTime.parse("2000-01-03T00:00:00")) + .withFieldValue("dt", LocalDateTime.parse("2000-01-03T00:00:00.123456")) .build()); @Rule public final transient TestPipeline p = TestPipeline.create(); @@ -107,7 +112,7 @@ public void testInvalidConfig() { } public PCollectionRowTuple runWithConfig( - BigQueryStorageWriteApiSchemaTransformConfiguration config) { + BigQueryStorageWriteApiSchemaTransformConfiguration config, List rows) { BigQueryStorageWriteApiSchemaTransformProvider provider = new BigQueryStorageWriteApiSchemaTransformProvider(); @@ -118,25 +123,82 @@ public PCollectionRowTuple runWithConfig( writeRowTupleTransform.setBigQueryServices(fakeBigQueryServices); String tag = provider.inputCollectionNames().get(0); - PCollection rows = p.apply(Create.of(ROWS).withRowSchema(SCHEMA)); + PCollection rowPc = p.apply(Create.of(rows).withRowSchema(SCHEMA)); - PCollectionRowTuple input = PCollectionRowTuple.of(tag, rows); + PCollectionRowTuple input = PCollectionRowTuple.of(tag, rowPc); PCollectionRowTuple result = input.apply(writeRowTupleTransform); return result; } + public Boolean rowsEquals(List expectedRows, List actualRows) { + if (expectedRows.size() != actualRows.size()) { + return false; + } + for (int i = 0; i < expectedRows.size(); i++) { + TableRow actualRow = actualRows.get(i); + Row expectedRow = expectedRows.get(Integer.parseInt(actualRow.get("number").toString()) - 1); + + if (!expectedRow.getValue("name").equals(actualRow.get("name")) + || !expectedRow + .getValue("number") + .equals(Long.parseLong(actualRow.get("number").toString()))) { + return false; + } + } + return true; + } + @Test public void testSimpleWrite() throws Exception { String tableSpec = "project:dataset.simple_write"; BigQueryStorageWriteApiSchemaTransformConfiguration config = BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build(); - runWithConfig(config); + runWithConfig(config, ROWS); + p.run().waitUntilFinish(); + + assertNotNull(fakeDatasetService.getTable(BigQueryHelpers.parseTableSpec(tableSpec))); + assertTrue( + rowsEquals(ROWS, fakeDatasetService.getAllRows("project", "dataset", "simple_write"))); + } + + @Test + public void testFailedRows() throws Exception { + String tableSpec = "project:dataset.write_with_fail"; + BigQueryStorageWriteApiSchemaTransformConfiguration config = + BigQueryStorageWriteApiSchemaTransformConfiguration.builder().setTable(tableSpec).build(); + + String failValue = "fail_me"; + + List expectedSuccessfulRows = new ArrayList<>(ROWS); + List expectedFailedRows = new ArrayList<>(); + for (long l = 1L; l <= 3L; l++) { + expectedFailedRows.add( + Row.withSchema(SCHEMA) + .withFieldValue("name", failValue) + .withFieldValue("number", l) + .withFieldValue("dt", LocalDateTime.parse("2020-01-01T00:00:00.09")) + .build()); + } + + List totalRows = new ArrayList<>(expectedSuccessfulRows); + totalRows.addAll(expectedFailedRows); + + Function shouldFailRow = + (Function & Serializable) tr -> tr.get("name").equals(failValue); + fakeDatasetService.setShouldFailRow(shouldFailRow); + + PCollectionRowTuple result = runWithConfig(config, totalRows); + PCollection failedRows = result.get("failed_rows"); + + PAssert.that(failedRows).containsInAnyOrder(expectedFailedRows); p.run().waitUntilFinish(); assertNotNull(fakeDatasetService.getTable(BigQueryHelpers.parseTableSpec(tableSpec))); - assertEquals( - ROWS.size(), fakeDatasetService.getAllRows("project", "dataset", "simple_write").size()); + assertTrue( + rowsEquals( + expectedSuccessfulRows, + fakeDatasetService.getAllRows("project", "dataset", "write_with_fail"))); } } diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py index 7233326ce0c2..8cd563a403f5 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery.py +++ b/sdks/python/apache_beam/io/gcp/bigquery.py @@ -397,6 +397,8 @@ def chain_after(result): from apache_beam.transforms import ParDo from apache_beam.transforms import PTransform from apache_beam.transforms.display import DisplayDataItem +from apache_beam.transforms.external import BeamJarExpansionService +from apache_beam.transforms.external import SchemaAwareExternalTransform from apache_beam.transforms.sideinputs import SIDE_INPUT_PREFIX from apache_beam.transforms.sideinputs import get_sideinput_index from apache_beam.transforms.util import ReshufflePerKey @@ -432,6 +434,7 @@ def chain_after(result): 'BigQueryQueryPriority', 'WriteToBigQuery', 'WriteResult', + 'StorageWriteToBigQuery', 'ReadFromBigQuery', 'ReadFromBigQueryRequest', 'ReadAllFromBigQuery', @@ -2300,6 +2303,107 @@ def __getitem__(self, key): return self.attributes[key].__get__(self, WriteResult) +def _default_io_expansion_service(append_args=None): + return BeamJarExpansionService( + 'sdks:java:io:google-cloud-platform:expansion-service:build', + append_args=append_args) + + +class StorageWriteToBigQuery(PTransform): + """Writes data to BigQuery using Storage API. + + Receives a PCollection of beam.Row() elements and writes the elements + to BigQuery. Returns a dead-letter queue of errors and failed rows + represented as a dictionary of PCollections. + + Example:: + with beam.Pipeline() as p: + items = [] + for i in range(10): + items.append(beam.Row(id=i)) + + result = (p + | 'Create items' >> beam.Create(items) + | 'Write data' >> StorageWriteToBigQuery( + table="project:dataset.table")) + _ = (result['failed_rows_with_errors'] + | 'Format errors' >> beam.Map( + lambda e: "failed row id: %s, error: %s" % + (e.failed_row.id, e.error_message)) + | 'Write errors' >> beam.io.WriteToText('./output'))) + """ + URN = "beam:schematransform:org.apache.beam:bigquery_storage_write:v1" + + def __init__( + self, + table, + create_disposition="", + write_disposition="", + triggering_frequency=0, + use_at_least_once=False, + expansion_service=None): + """Initialize a StorageWriteToBigQuery transform. + + :param table: a fully-qualified table ID specified as + ``'PROJECT:DATASET.TABLE'`` + :param create_disposition: a string specifying the strategy to + take when the table doesn't exist. Possible values are: + + * ``'CREATE_IF_NEEDED'``: create if does not exist. + * ``'CREATE_NEVER'``: fail the write if does not exist. + + :param write_disposition: a string specifying the strategy to take + when the table already contains data. Possible values are: + + * ``'WRITE_TRUNCATE'``: delete existing rows. + * ``'WRITE_APPEND'``: add to existing rows. + * ``'WRITE_EMPTY'``: fail the write if table not empty. + + :param triggering_frequency: the time in seconds between write + commits. Should only be specified for streaming pipelines. Defaults + to 5 seconds. + :param use_at_least_once: use at-least-once semantics. Is cheaper + and provides lower latency, but will potentially duplicate records. + :param expansion_service: the address (host:port) of the expansion service + """ + super().__init__() + self._table = table + self._create_disposition = create_disposition + self._write_disposition = write_disposition + self._triggering_frequency = triggering_frequency + self._use_at_least_once = use_at_least_once + self._expansion_service = ( + expansion_service or _default_io_expansion_service()) + self.schematransform_config = SchemaAwareExternalTransform.discover_config( + self._expansion_service, self.URN) + + def expand(self, input): + opts = input.pipeline.options.view_as(StandardOptions) + # TODO(https://github.com/apache/beam/issues/21307): Add support for + # OnWindowExpiration to more runners. Storage Write API requires + # `beam:requirement:pardo:on_window_expiration:v1` when unbounded + available_runners = ['DataflowRunner', 'TestDataflowRunner'] + if not input.is_bounded and opts.runner not in available_runners: + raise NotImplementedError( + "Storage API Streaming Writes via xlang is not yet available for %s." + " Available runners are %s", + opts.runner, + available_runners) + + external_storage_write = SchemaAwareExternalTransform( + self.schematransform_config.identifier, + expansion_service=self._expansion_service, + table=self._table, + createDisposition=self._create_disposition, + writeDisposition=self._write_disposition, + triggeringFrequencySeconds=self._triggering_frequency, + useAtLeastOnceSemantics=self._use_at_least_once) + + input_tag = self.schematransform_config.inputs[0] + + return {input_tag: input} | external_storage_write + + class ReadFromBigQuery(PTransform): """Read data from BigQuery. diff --git a/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py index a307e06ac5b8..84408f0daf45 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_write_it_test.py @@ -23,6 +23,7 @@ import base64 import datetime import logging +import os import secrets import time import unittest @@ -32,6 +33,7 @@ import mock import pytest import pytz +from hamcrest.core import assert_that as hamcrest_assert from parameterized import param from parameterized import parameterized @@ -44,6 +46,7 @@ from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.utils.timestamp import Timestamp # Protect against environments where bigquery library is not available. # pylint: disable=wrong-import-order, wrong-import-position @@ -543,6 +546,105 @@ def test_big_query_write_temp_table_append_schema_update(self, file_format): temp_file_format=file_format)) +class BigQueryXlangStorageWriteIT(unittest.TestCase): + BIGQUERY_DATASET = 'python_xlang_write_' + + ELEMENTS = [ + # (int, float, string, timestamp, bool, bytes) + ( + 1, + 0.1, + 'a', + Timestamp(seconds=100, micros=10), + False, + bytes('a', 'utf-8')), + ( + 2, + 0.2, + 'b', + Timestamp(seconds=200, micros=20), + True, + bytes('b', 'utf-8')), + ( + 3, + 0.3, + 'c', + Timestamp(seconds=300, micros=30), + False, + bytes('c', 'utf-8')), + ( + 4, + 0.4, + 'd', + Timestamp(seconds=400, micros=40), + True, + bytes('d', 'utf-8')), + ] + + def setUp(self): + self.test_pipeline = TestPipeline(is_integration_test=True) + self.project = self.test_pipeline.get_option('project') + + self.bigquery_client = BigQueryWrapper() + self.dataset_id = '%s%s%s' % ( + self.BIGQUERY_DATASET, str(int(time.time())), secrets.token_hex(3)) + self.bigquery_client.get_or_create_dataset(self.project, self.dataset_id) + _LOGGER.info( + "Created dataset %s in project %s", self.dataset_id, self.project) + + self.expansion_service = ('localhost:%s' % os.environ.get('EXPANSION_PORT')) + self.row_elements = [ + beam.Row( + my_int=e[0], + my_float=e[1], + my_string=e[2], + my_timestamp=e[3], + my_bool=e[4], + my_bytes=e[5]) for e in self.ELEMENTS + ] + + # BigQuery matcher query returns a datetime.datetime object + self.expected_elements = [( + e[:3] + + (e[3].to_utc_datetime().replace(tzinfo=datetime.timezone.utc), ) + + e[4:]) for e in self.ELEMENTS] + + def tearDown(self): + request = bigquery.BigqueryDatasetsDeleteRequest( + projectId=self.project, datasetId=self.dataset_id, deleteContents=True) + try: + _LOGGER.info( + "Deleting dataset %s in project %s", self.dataset_id, self.project) + self.bigquery_client.client.datasets.Delete(request) + except HttpError: + _LOGGER.debug( + 'Failed to clean up dataset %s in project %s', + self.dataset_id, + self.project) + + @pytest.mark.it_postcommit + @pytest.mark.bq_xlang + @pytest.mark.uses_python_expansion_service + @pytest.mark.uses_java_expansion_service + def test_xlang_storage_write(self): + table_id = '{}:{}.python_xlang_storage_write'.format( + self.project, self.dataset_id) + + bq_matcher = BigqueryFullResultMatcher( + project=self.project, + query="SELECT * FROM %s" % + '{}.python_xlang_storage_write'.format(self.dataset_id), + data=self.expected_elements) + + with beam.Pipeline() as p: + _ = ( + p + | beam.Create(self.row_elements) + | beam.io.StorageWriteToBigQuery( + table=table_id, expansion_service=self.expansion_service)) + hamcrest_assert(p, bq_matcher) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index 1c4a6dd05197..b75e4c794625 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -367,6 +367,36 @@ def discover(expansion_service): inputs=proto_config.input_pcollection_names, outputs=proto_config.output_pcollection_names) + @staticmethod + def discover_config(expansion_service, name): + """Discover one SchemaTransform by name in the given expansion service. + + :return: one SchemaTransformConfig that represents the discovered + SchemaTransform + + :raises: + ValueError: if more than one SchemaTransform is discovered, or if none + are discovered + """ + + schematransforms = SchemaAwareExternalTransform.discover(expansion_service) + matched = [] + + for st in schematransforms: + if name in st.identifier: + matched.append(st) + + if not matched: + raise ValueError( + "Did not discover any SchemaTransforms resembling the name '%s'" % + name) + elif len(matched) > 1: + raise ValueError( + "Found multiple SchemaTransforms with the name '%s':\n%s\n" % + (name, [st.identifier for st in matched])) + + return matched[0] + class JavaExternalTransform(ptransform.PTransform): """A proxy for Java-implemented external transforms. diff --git a/sdks/python/apache_beam/transforms/external_test.py b/sdks/python/apache_beam/transforms/external_test.py index f38876367c39..83ca0a609a78 100644 --- a/sdks/python/apache_beam/transforms/external_test.py +++ b/sdks/python/apache_beam/transforms/external_test.py @@ -32,7 +32,9 @@ from apache_beam import Pipeline from apache_beam.coders import RowCoder from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.portability.api import beam_expansion_api_pb2 from apache_beam.portability.api import external_transforms_pb2 +from apache_beam.portability.api import schema_pb2 from apache_beam.runners import pipeline_context from apache_beam.runners.portability import expansion_service from apache_beam.runners.portability.expansion_service_test import FibTransform @@ -475,6 +477,47 @@ def test_build_payload(self): self.assertEqual(456, schema_transform_config.object_field.int_sub_field) +class SchemaAwareExternalTransformTest(unittest.TestCase): + class MockService: + # define context manager enter and exit functions + def __enter__(self): + return self + + def __exit__(self, unusued1, unused2, unused3): + pass + + def DiscoverSchemaTransform(self, unused_request=None): + test_config = beam_expansion_api_pb2.SchemaTransformConfig( + config_schema=schema_pb2.Schema( + fields=[ + schema_pb2.Field( + name="test_field", + type=schema_pb2.FieldType(atomic_type="STRING")) + ], + id="test-id"), + input_pcollection_names=["input"], + output_pcollection_names=["output"]) + return beam_expansion_api_pb2.DiscoverSchemaTransformResponse( + schema_transform_configs={"test_schematransform": test_config}) + + @mock.patch("apache_beam.transforms.external.ExternalTransform.service") + def test_discover_one_config(self, mock_service): + _mock = self.MockService() + mock_service.return_value = _mock + config = beam.SchemaAwareExternalTransform.discover_config( + "test_service", name="test_schematransform") + self.assertEqual(config.outputs[0], "output") + self.assertEqual(config.inputs[0], "input") + self.assertEqual(config.identifier, "test_schematransform") + + @mock.patch("apache_beam.transforms.external.ExternalTransform.service") + def test_discover_one_config_fails_with_no_configs_found(self, mock_service): + mock_service.return_value = self.MockService() + with self.assertRaises(ValueError): + beam.SchemaAwareExternalTransform.discover_config( + "test_service", name="non_existent") + + class JavaClassLookupPayloadBuilderTest(unittest.TestCase): def _verify_row(self, schema, row_payload, expected_values): row = RowCoder(schema).decode(row_payload) diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index a879421e9394..cdf05bfa0454 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -105,7 +105,8 @@ task postCommitIT { "test_opts": testOpts, "sdk_location": files(configurations.distTarBall.files).singleFile, "suite": "postCommitIT-df${pythonVersionSuffix}", - "collect": "it_postcommit" +// "collect": "it_postcommit", + "collect": "bq_xlang" ] def cmdArgs = mapToArgString(argMap) exec { diff --git a/sdks/python/test-suites/portable/common.gradle b/sdks/python/test-suites/portable/common.gradle index c87696793f0f..01eba033bd34 100644 --- a/sdks/python/test-suites/portable/common.gradle +++ b/sdks/python/test-suites/portable/common.gradle @@ -296,11 +296,12 @@ project.tasks.register("postCommitPy${pythonVersionSuffix}IT") { doLast { def tests = [ - "apache_beam/io/gcp/bigquery_read_it_test.py", - "apache_beam/io/external/xlang_jdbcio_it_test.py", - "apache_beam/io/external/xlang_kafkaio_it_test.py", - "apache_beam/io/external/xlang_kinesisio_it_test.py", - "apache_beam/io/external/xlang_debeziumio_it_test.py", +// "apache_beam/io/gcp/bigquery_read_it_test.py", +// "apache_beam/io/external/xlang_jdbcio_it_test.py", +// "apache_beam/io/external/xlang_kafkaio_it_test.py", +// "apache_beam/io/external/xlang_kinesisio_it_test.py", +// "apache_beam/io/external/xlang_debeziumio_it_test.py", + "apache_beam/io/gcp/bigquery_write_it_test.py::BigQueryXlangStorageWriteIT", ] def testOpts = ["${tests.join(' ')}"] + ["--log-cli-level=INFO"] def pipelineOpts = [