diff --git a/sdks/java/io/kafka/build.gradle b/sdks/java/io/kafka/build.gradle
index e300999063914..e6df08d41199d 100644
--- a/sdks/java/io/kafka/build.gradle
+++ b/sdks/java/io/kafka/build.gradle
@@ -72,6 +72,7 @@ dependencies {
   implementation library.java.jackson_annotations
   implementation library.java.jackson_databind
   implementation "org.springframework:spring-expression:5.3.27"
+  implementation group: 'com.google.cloud.hosted.kafka', name: 'managed-kafka-auth-login-handler', version: '1.0.2'
   implementation ("io.confluent:kafka-avro-serializer:${confluentVersion}") {
     // zookeeper depends on "spotbugs-annotations:3.1.9" which clashes with current
     // "spotbugs-annotations:3.1.12" used in Beam. Not required.
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index 0f28edf19dd81..cb7b3020c66a5 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -109,6 +109,7 @@
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Comparators;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.apache.kafka.clients.CommonClientConfigs;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
@@ -118,6 +119,7 @@
 import org.apache.kafka.clients.producer.ProducerRecord;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.config.SaslConfigs;
 import org.apache.kafka.common.serialization.ByteArrayDeserializer;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.Serializer;
@@ -1453,6 +1455,24 @@ public Read<K, V> withConsumerPollingTimeout(long duration) {
       return toBuilder().setConsumerPollingTimeout(duration).build();
     }
 
+    /**
+     * Creates and sets the Application Default Credentials for a Kafka consumer. This allows the
+     * consumer to be authenticated with a Google Kafka Server using OAuth.
+     */
+    public Read<K, V> withGCPApplicationDefaultCredentials() {
+
+      return withConsumerConfigUpdates(
+          ImmutableMap.of(
+              CommonClientConfigs.SECURITY_PROTOCOL_CONFIG,
+              "SASL_SSL",
+              SaslConfigs.SASL_MECHANISM,
+              "OAUTHBEARER",
+              SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS,
+              "com.google.cloud.hosted.kafka.auth.GcpLoginCallbackHandler",
+              SaslConfigs.SASL_JAAS_CONFIG,
+              "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"));
+    }
+
     /** Returns a {@link PTransform} for PCollection of {@link KV}, dropping Kafka metatdata. */
     public PTransform<PBegin, PCollection<KV<K, V>>> withoutMetadata() {
       return new TypedWithoutMetadata<>(this);
@@ -3362,6 +3382,23 @@ public Write<K, V> withBadRecordErrorHandler(ErrorHandler<BadRecord, ?> badRecor
           getWriteRecordsTransform().withBadRecordErrorHandler(badRecordErrorHandler));
     }
 
+    /**
+     * Creates and sets the Application Default Credentials for a Kafka producer. This allows the
+     * producer to be authenticated with a Google Kafka Server using OAuth.
+     */
+    public Write<K, V> withGCPApplicationDefaultCredentials() {
+      return withProducerConfigUpdates(
+          ImmutableMap.of(
+              CommonClientConfigs.SECURITY_PROTOCOL_CONFIG,
+              "SASL_SSL",
+              SaslConfigs.SASL_MECHANISM,
+              "OAUTHBEARER",
+              SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS,
+              "com.google.cloud.hosted.kafka.auth.GcpLoginCallbackHandler",
+              SaslConfigs.SASL_JAAS_CONFIG,
+              "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"));
+    }
+
     @Override
     public PDone expand(PCollection<KV<K, V>> input) {
       final String topic = Preconditions.checkStateNotNull(getTopic(), "withTopic() is required");
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
index fba81c51130df..c6edaf7761b58 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
@@ -103,6 +103,7 @@
 import org.joda.time.Duration;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
+import org.junit.Ignore;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -815,6 +816,62 @@ public void testWatermarkUpdateWithSparseMessages() throws IOException, Interrup
     }
   }
 
+  @Ignore(
+      "Test is ignored until GMK is utilized as part of this test suite (https://github.com/apache/beam/issues/32721).")
+  @Test
+  public void testReadAndWriteFromKafkaIOWithGCPApplicationDefaultCredentials() throws IOException {
+    AdminClient client =
+        AdminClient.create(
+            ImmutableMap.of("bootstrap.servers", options.getKafkaBootstrapServerAddresses()));
+
+    String topicName = "TestApplicationDefaultCreds-" + UUID.randomUUID();
+    Map<Integer, String> records = new HashMap<>();
+    for (int i = 0; i < 5; i++) {
+      records.put(i, String.valueOf(i));
+    }
+
+    try {
+      client.createTopics(ImmutableSet.of(new NewTopic(topicName, 1, (short) 1)));
+
+      writePipeline
+          .apply("Generate Write Elements", Create.of(records))
+          .apply(
+              "Write to Kafka",
+              KafkaIO.<Integer, String>write()
+                  .withBootstrapServers(options.getKafkaBootstrapServerAddresses())
+                  .withTopic(topicName)
+                  .withKeySerializer(IntegerSerializer.class)
+                  .withValueSerializer(StringSerializer.class)
+                  .withGCPApplicationDefaultCredentials());
+
+      writePipeline.run().waitUntilFinish(Duration.standardSeconds(15));
+
+      client.createPartitions(ImmutableMap.of(topicName, NewPartitions.increaseTo(3)));
+
+      sdfReadPipeline.apply(
+          "Read from Kafka",
+          KafkaIO.<Integer, String>read()
+              .withBootstrapServers(options.getKafkaBootstrapServerAddresses())
+              .withConsumerConfigUpdates(ImmutableMap.of("auto.offset.reset", "earliest"))
+              .withTopic(topicName)
+              .withKeyDeserializer(IntegerDeserializer.class)
+              .withValueDeserializer(StringDeserializer.class)
+              .withGCPApplicationDefaultCredentials()
+              .withoutMetadata());
+
+      PipelineResult readResult = sdfReadPipeline.run();
+
+      // Only waiting 5 seconds here because we don't expect any processing at this point
+      PipelineResult.State readState = readResult.waitUntilFinish(Duration.standardSeconds(5));
+
+      cancelIfTimeouted(readResult, readState);
+      // Fail the test if pipeline failed.
+      assertNotEquals(readState, PipelineResult.State.FAILED);
+    } finally {
+      client.deleteTopics(ImmutableSet.of(topicName));
+    }
+  }
+
   private static class KeyByPartition
       extends DoFn<KafkaRecord<byte[], byte[]>, KV<Integer, KafkaRecord<byte[], byte[]>>> {
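For reference, a minimal sketch of how a pipeline might use the new read-side API against a Google Managed Service for Apache Kafka cluster. The bootstrap address and topic name below are hypothetical placeholders, and the pipeline setup is not taken from the patch; only the KafkaIO builder calls come from the methods added above.

    // Hypothetical usage sketch; cluster address and topic are placeholders.
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.io.kafka.KafkaIO;
    import org.apache.beam.sdk.values.KV;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.kafka.common.serialization.StringDeserializer;

    public class GmkReadExample {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create();

        // withGCPApplicationDefaultCredentials() expands to the SASL_SSL /
        // OAUTHBEARER consumer properties added in the patch, so no JAAS
        // configuration or login callback handler needs to be supplied by
        // hand; the identity comes from the environment's Application
        // Default Credentials.
        PCollection<KV<String, String>> records =
            p.apply(
                KafkaIO.<String, String>read()
                    .withBootstrapServers("bootstrap.example-gmk-cluster:9092") // placeholder
                    .withTopic("example-topic") // placeholder
                    .withKeyDeserializer(StringDeserializer.class)
                    .withValueDeserializer(StringDeserializer.class)
                    .withGCPApplicationDefaultCredentials()
                    .withoutMetadata());

        p.run().waitUntilFinish();
      }
    }

The write path is symmetric: calling withGCPApplicationDefaultCredentials() on KafkaIO.write() applies the same four security properties through withProducerConfigUpdates().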