diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProvider.java index d9dfc2a90bd8..997b45a45729 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProvider.java @@ -40,10 +40,14 @@ import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptors; import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @AutoService(SchemaTransformProvider.class) public class GenerateSequenceSchemaTransformProvider extends TypedSchemaTransformProvider { + private static final Logger LOG = + LoggerFactory.getLogger(GenerateSequenceSchemaTransformProvider.class); public static final String OUTPUT_ROWS_TAG = "output"; public static final Schema OUTPUT_SCHEMA = Schema.builder().addInt64Field("value").build(); @@ -130,6 +134,18 @@ public static Builder builder() { @Nullable public abstract Rate getRate(); + @SchemaFieldDescription( + "Number of elements to generate per period. Alternative to using the 'rate' object. " + + "If set, 'period' must also be set. Takes precedence over 'rate'.") + @Nullable + public abstract Long getElementsPerPeriod(); + + @SchemaFieldDescription( + "The period in seconds for generating elements. Alternative to using the 'rate' object. " + + "If set, 'elementsPerPeriod' must also be set. Takes precedence over 'rate'.") + @Nullable + public abstract Long getPeriod(); + @AutoValue.Builder public abstract static class Builder { @@ -139,6 +155,10 @@ public abstract static class Builder { public abstract Builder setRate(Rate rate); + public abstract Builder setElementsPerPeriod(Long elementsPerPeriod); + + public abstract Builder setPeriod(Long period); + public abstract GenerateSequenceConfiguration build(); } @@ -149,8 +169,33 @@ public void validate() { if (end != null) { checkArgument(end == -1 || end >= start, "Invalid range [%s, %s)", start, end); } - Rate rate = this.getRate(); - if (rate != null) { + + Long elementsPerPeriod = getElementsPerPeriod(); + Long period = getPeriod(); + Rate rate = getRate(); + + if (elementsPerPeriod != null || period != null) { + // Ensure both are specified if one is. + if (elementsPerPeriod == null || period == null) { + throw new IllegalArgumentException( + "If either 'elementsPerPeriod' or 'period' is specified, both must be specified."); + } + // At this point, both elementsPerPeriod and period are guaranteed to be non-null. + checkArgument( + elementsPerPeriod > 0, + "Invalid 'elementsPerPeriod' specification. Expected positive value but received %s.", + elementsPerPeriod); + checkArgument( + period > 0, + "Invalid 'period' specification. Expected positive value but received %s.", + period); + if (rate != null) { + // Consider logging a warning if rate is also set, as it will be ignored. + // For now, we just prioritize elementsPerPeriod/period. + LOG.warn( + "Configuration includes both 'elementsPerPeriod'/'period' and 'rate'. 'rate' will be ignored."); + } + } else if (rate != null) { checkArgument( rate.getElements() > 0, "Invalid rate specification. Expected positive elements component but received %s.", @@ -159,6 +204,10 @@ public void validate() { Optional.ofNullable(rate.getSeconds()).orElse(1L) > 0, "Invalid rate specification. Expected positive seconds component but received %s.", rate.getSeconds()); + // Ensure seconds is present if elements is, to match the original issue's concern + checkArgument( + !(rate.getElements() != null && rate.getSeconds() == null), + "Invalid rate specification. If rate.elements is specified, rate.seconds must also be specified."); } } } @@ -177,14 +226,25 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { input.getAll().isEmpty(), "Expected no inputs but got: %s", input.getAll().keySet()); Long end = Optional.ofNullable(configuration.getEnd()).orElse(-1L); - GenerateSequenceConfiguration.Rate rate = configuration.getRate(); - GenerateSequence sequence = GenerateSequence.from(configuration.getStart()).to(end); - if (rate != null) { - sequence = - sequence.withRate( - rate.getElements(), - Duration.standardSeconds(Optional.ofNullable(rate.getSeconds()).orElse(1L))); + + Long elementsPerPeriod = configuration.getElementsPerPeriod(); + Long period = configuration.getPeriod(); + + if (elementsPerPeriod != null && period != null) { + // elementsPerPeriod and period are validated to be non-null and positive by validate() + sequence = sequence.withRate(elementsPerPeriod, Duration.standardSeconds(period)); + } else { + GenerateSequenceConfiguration.Rate rate = configuration.getRate(); + if (rate != null) { + // rate.getElements() is validated to be positive. + // rate.getSeconds() is validated to be positive if present, defaults to 1L if null. + // The additional check in validate() ensures getSeconds() is present if getElements() is. + sequence = + sequence.withRate( + rate.getElements(), + Duration.standardSeconds(Optional.ofNullable(rate.getSeconds()).orElse(1L))); + } } return PCollectionRowTuple.of( diff --git a/sdks/python/apache_beam/transforms/external_transform_provider_it_test.py b/sdks/python/apache_beam/transforms/external_transform_provider_it_test.py index d1c5cbfa8e9a..68a01b5bac18 100644 --- a/sdks/python/apache_beam/transforms/external_transform_provider_it_test.py +++ b/sdks/python/apache_beam/transforms/external_transform_provider_it_test.py @@ -122,6 +122,34 @@ def test_run_generate_sequence(self): assert_that(numbers, equal_to([i for i in range(10)])) + def test_run_generate_sequence_with_elements_per_period(self): + provider = ExternalTransformProvider( + BeamJarExpansionService(":sdks:java:io:expansion-service:shadowJar")) + + # We expect this to produce 0, 1, 2, 3. + # The rate limiting (2 elements per 1 second) is primarily to ensure + # these parameters are accepted and the pipeline runs. + # Exact timing is hard to assert in an IT. + # The end parameter ensures the sequence is bounded for the test. + with beam.Pipeline() as p: + numbers = p | provider.GenerateSequence( + start=0, end=4, elements_per_period=2, + period=1) | beam.Map(lambda row: row.value) + + assert_that(numbers, equal_to([0, 1, 2, 3])) + + def test_run_generate_sequence_with_rate(self): + provider = ExternalTransformProvider( + BeamJarExpansionService(":sdks:java:io:expansion-service:shadowJar")) + + with beam.Pipeline() as p: + numbers = p | provider.GenerateSequence( + start=0, end=3, rate={ + 'elements': 1, 'seconds': 1 + }) | beam.Map(lambda row: row.value) + + assert_that(numbers, equal_to([0, 1, 2])) + @pytest.mark.xlang_wrapper_generation @unittest.skipUnless( diff --git a/sdks/standard_external_transforms.yaml b/sdks/standard_external_transforms.yaml index f5d71830145a..858fad1bd14f 100644 --- a/sdks/standard_external_transforms.yaml +++ b/sdks/standard_external_transforms.yaml @@ -19,7 +19,7 @@ # configuration in /sdks/standard_expansion_services.yaml. # Refer to gen_xlang_wrappers.py for more info. # -# Last updated on: 2025-04-24 +# Last updated on: 2025-05-30 - default_service: sdks:java:io:expansion-service:shadowJar description: 'Outputs a PCollection of Beam Rows, each containing a single INT64 @@ -34,11 +34,22 @@ destinations: python: apache_beam/io fields: + - description: Number of elements to generate per period. Alternative to using the + 'rate' object. If set, 'period' must also be set. Takes precedence over 'rate'. + name: elements_per_period + nullable: true + type: int64 - description: The maximum number to generate (exclusive). Will be an unbounded sequence if left unspecified. name: end nullable: true type: int64 + - description: The period in seconds for generating elements. Alternative to using + the 'rate' object. If set, 'elementsPerPeriod' must also be set. Takes precedence + over 'rate'. + name: period + nullable: true + type: int64 - description: Specifies the rate to generate a given number of elements per a given number of seconds. Applicable only to unbounded sequences. name: rate