diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProvider.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProvider.java index 8d74f2f6117a2..ead1fa67dc98d 100644 --- a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProvider.java +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProvider.java @@ -129,8 +129,7 @@ public PTransform getTransform(FunctionSpec spec) { Row configRow; try { configRow = - RowCoder.of(provider.configurationSchema()) - .decode(payload.getConfigurationRow().newInput()); + RowCoder.of(configSchemaFromRequest).decode(payload.getConfigurationRow().newInput()); } catch (IOException e) { throw new RuntimeException("Error decoding payload", e); } diff --git a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java index 141d2b48b105a..d7a665eabe0f9 100644 --- a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java +++ b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java @@ -19,9 +19,9 @@ import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; import com.google.auto.service.AutoService; -import java.io.IOException; import java.util.ArrayList; import java.util.List; import org.apache.beam.model.expansion.v1.ExpansionApi; @@ -32,6 +32,7 @@ import org.apache.beam.runners.core.construction.ParDoTranslation; import 
org.apache.beam.runners.core.construction.PipelineTranslation; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.schemas.JavaFieldSchema; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; @@ -48,10 +49,11 @@ import org.apache.beam.sdk.transforms.InferableFunction; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p54p0.com.google.protobuf.InvalidProtocolBufferException; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; @@ -74,6 +76,13 @@ public class ExpansionServiceSchemaTransformProviderTest { Field.of("str1", FieldType.STRING), Field.of("str2", FieldType.STRING)); + /* Config schema declaring str2, str1, int2, int1 in that order (two STRING fields, then two INT32 fields). Presumably a reordering of TEST_SCHEMATRANSFORM_CONFIG_SCHEMA, whose full declaration is not visible in this hunk; confirm against it. */ private static final Schema TEST_SCHEMATRANSFORM_EQUIVALENT_CONFIG_SCHEMA = + Schema.of( + Field.of("str2", FieldType.STRING), + Field.of("str1", FieldType.STRING), + Field.of("int2", FieldType.INT32), + Field.of("int1", FieldType.INT32)); + private ExpansionService expansionService = new ExpansionService(); @DefaultSchema(JavaFieldSchema.class) @@ -344,31 +353,13 @@ public void testSchemaTransformExpansion() { .withFieldValue("str2", "bbb") .build(); - ByteStringOutputStream outputStream = new ByteStringOutputStream(); - try { - SchemaCoder.of(configRow.getSchema()).encode(configRow, outputStream); - } catch (IOException e) { - throw new RuntimeException(e); - } - - ExternalTransforms.SchemaTransformPayload payload = - ExternalTransforms.SchemaTransformPayload.newBuilder() - 
.setIdentifier("dummy_id") - .setConfigurationRow(outputStream.toByteString()) - .setConfigurationSchema( - SchemaTranslation.schemaToProto(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA, true)) - .build(); - ExpansionApi.ExpansionRequest request = ExpansionApi.ExpansionRequest.newBuilder() .setComponents(pipelineProto.getComponents()) .setTransform( RunnerApi.PTransform.newBuilder() .setUniqueName(TEST_NAME) - .setSpec( - RunnerApi.FunctionSpec.newBuilder() - .setUrn(getUrn(ExpansionMethods.Enum.SCHEMA_TRANSFORM)) - .setPayload(payload.toByteString())) + .setSpec(createSpec("dummy_id", configRow)) .putInputs("input1", inputPcollId)) .setNamespace(TEST_NAMESPACE) .build(); @@ -403,35 +394,18 @@ public void testSchemaTransformExpansionMultiInputMultiOutput() { .withFieldValue("str2", "bbb") .build(); - ByteStringOutputStream outputStream = new ByteStringOutputStream(); - try { - SchemaCoder.of(configRow.getSchema()).encode(configRow, outputStream); - } catch (IOException e) { - throw new RuntimeException(e); - } - - ExternalTransforms.SchemaTransformPayload payload = - ExternalTransforms.SchemaTransformPayload.newBuilder() - .setIdentifier("dummy_id_multi_input_multi_output") - .setConfigurationRow(outputStream.toByteString()) - .setConfigurationSchema( - SchemaTranslation.schemaToProto(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA, true)) - .build(); - ExpansionApi.ExpansionRequest request = ExpansionApi.ExpansionRequest.newBuilder() .setComponents(pipelineProto.getComponents()) .setTransform( RunnerApi.PTransform.newBuilder() .setUniqueName(TEST_NAME) - .setSpec( - RunnerApi.FunctionSpec.newBuilder() - .setUrn(getUrn(ExpansionMethods.Enum.SCHEMA_TRANSFORM)) - .setPayload(payload.toByteString())) + .setSpec(createSpec("dummy_id_multi_input_multi_output", configRow)) .putInputs("input1", inputPcollIds.get(0)) .putInputs("input2", inputPcollIds.get(1))) .setNamespace(TEST_NAMESPACE) .build(); + ExpansionApi.ExpansionResponse response = expansionService.expand(request); RunnerApi.PTransform 
expandedTransform = response.getTransform(); @@ -440,4 +414,61 @@ public void testSchemaTransformExpansionMultiInputMultiOutput() { assertEquals(2, expandedTransform.getOutputsCount()); verifyLeafTransforms(response, 2); } + + /** Builds the same four config values (int1, int2, str1, str2) under TEST_SCHEMATRANSFORM_CONFIG_SCHEMA and under TEST_SCHEMATRANSFORM_EQUIVALENT_CONFIG_SCHEMA, asserts that the two encoded payloads differ, then asserts that getTransform returns TestSchemaTransform instances whose int1/int2/str1/str2 fields match. NOTE(review): no statement visible in this body throws CoderException (createSpec rethrows it as RuntimeException), so the throws clause looks redundant; confirm. NOTE(review): the name spells "Schematransform" while sibling tests spell "SchemaTransform"; confirm whether that is intended. */ @Test + public void testSchematransformEquivalentConfigSchema() throws CoderException { + Row configRow = + Row.withSchema(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA) + .withFieldValue("int1", 111) + .withFieldValue("int2", 222) + .withFieldValue("str1", "aaa") + .withFieldValue("str2", "bbb") + .build(); + + RunnerApi.FunctionSpec spec = createSpec("dummy_id", configRow); + + Row equivalentConfigRow = + Row.withSchema(TEST_SCHEMATRANSFORM_EQUIVALENT_CONFIG_SCHEMA) + .withFieldValue("int1", 111) + .withFieldValue("int2", 222) + .withFieldValue("str1", "aaa") + .withFieldValue("str2", "bbb") + .build(); + + RunnerApi.FunctionSpec equivalentSpec = createSpec("dummy_id", equivalentConfigRow); + + /* Guard: the two payloads must not be equal, otherwise the field comparisons below would pass trivially. */ assertNotEquals(spec.getPayload(), equivalentSpec.getPayload()); + + TestSchemaTransform transform = + (TestSchemaTransform) ExpansionServiceSchemaTransformProvider.of().getTransform(spec); + TestSchemaTransform equivalentTransform = + (TestSchemaTransform) + ExpansionServiceSchemaTransformProvider.of().getTransform(equivalentSpec); + + assertEquals(transform.int1, equivalentTransform.int1); + assertEquals(transform.int2, equivalentTransform.int2); + assertEquals(transform.str1, equivalentTransform.str1); + assertEquals(transform.str2, equivalentTransform.str2); + } + + /** Test helper shared by the expansion tests. Encodes configRow to bytes with a SchemaCoder built from the row's own schema, places those bytes, the given identifier and the row's schema (converted by SchemaTranslation.schemaToProto) into a SchemaTransformPayload, and returns a FunctionSpec carrying the SCHEMA_TRANSFORM URN and the serialized payload. A CoderException raised while encoding is rethrown wrapped in a RuntimeException. */ private RunnerApi.FunctionSpec createSpec(String identifier, Row configRow) { + byte[] encodedRow; + try { + encodedRow = CoderUtils.encodeToByteArray(SchemaCoder.of(configRow.getSchema()), configRow); + } catch (CoderException e) { + throw new RuntimeException(e); + } + + ExternalTransforms.SchemaTransformPayload payload = + ExternalTransforms.SchemaTransformPayload.newBuilder() + .setIdentifier(identifier) + .setConfigurationRow(ByteString.copyFrom(encodedRow)) + 
.setConfigurationSchema(SchemaTranslation.schemaToProto(configRow.getSchema(), true)) + .build(); + + return RunnerApi.FunctionSpec.newBuilder() + .setUrn(getUrn(ExpansionMethods.Enum.SCHEMA_TRANSFORM)) + .setPayload(payload.toByteString()) + .build(); + } }