From e90a7cdd27b62851e0297773854d003af68aa610 Mon Sep 17 00:00:00 2001 From: macasado86 Date: Tue, 11 Jun 2024 09:27:51 +0200 Subject: [PATCH] feat: add oauthbearer configuration and refresh token callback --- CONFIGURATION.md | 7 ++ c_src/erlkaf_consumer.cc | 127 +++++++++++++++++++++++++++++ c_src/erlkaf_consumer.h | 2 + c_src/erlkaf_nif.cc | 8 +- c_src/erlkaf_nif.h | 1 + c_src/erlkaf_producer.cc | 129 ++++++++++++++++++++++++++++++ c_src/erlkaf_producer.h | 2 + src/erlkaf_config.erl | 14 ++++ src/erlkaf_consumer_callbacks.erl | 6 +- src/erlkaf_consumer_group.erl | 23 +++++- src/erlkaf_nif.erl | 18 ++++- src/erlkaf_producer.erl | 24 +++++- src/erlkaf_utils.erl | 16 +++- 13 files changed, 366 insertions(+), 11 deletions(-) diff --git a/CONFIGURATION.md b/CONFIGURATION.md index 15e0840..de0ae0c 100644 --- a/CONFIGURATION.md +++ b/CONFIGURATION.md @@ -67,6 +67,13 @@ sasl_username | * | | sasl_password | * | | | SASL password for use with the PLAIN and SASL-SCRAM-.. mechanism sasl_oauthbearer_config | * | | | SASL/OAUTHBEARER configuration. The format is implementation-dependent and must be parsed accordingly. The default unsecured token implementation (see https://tools.ietf.org/html/rfc7515#appendix-A.5) recognizes space-separated name=value pairs with valid names including principalClaimName, principal, scopeClaimName, scope, and lifeSeconds. The default value for principalClaimName is "sub", the default value for scopeClaimName is "scope", and the default value for lifeSeconds is 3600. The scope value is CSV format with the default value being no/empty scope. For example: `principalClaimName=azp principal=admin scopeClaimName=roles scope=role1,role2 lifeSeconds=600`. In addition, SASL extensions can be communicated to the broker via `extension_NAME=value`. For example: `principal=admin extension_traceId=123`. enable_sasl_oauthbearer_unsecure_jwt | * | true, false | false | Enable the builtin unsecure JWT OAUTHBEARER token handler if no oauthbearer_refresh_cb has been set. This builtin handler should only be used for development or testing, and not in production. +oauthbearer_token_refresh_callback | * | module or fun/2 | undefined | A callback to implement SASL/OAUTHBEARER token refresh. +sasl_oauthbearer_method | * | default, oidc | default | Set to "default" or "oidc" to control which login method to be used. If set to "oidc", the following properties must also be be specified: `sasl_oauthbearer_client_id`, `sasl_oauthbearer_client_secret`, and `sasl_oauthbearer_token_endpoint_url`. +sasl_oauthbearer_client_id | * | | | Public identifier for the application. Must be unique across all clients that the authorization server handles. Only used when `sasl_oauthbearer_method` is set to "oidc". +sasl_oauthbearer_client_secret | * | | | Client secret only known to the application and the authorization server. This should be a sufficiently random string that is not guessable. Only used when `sasl_oauthbearer_method` is set to "oidc". +sasl_oauthbearer_scope | * | | | Client use this to specify the scope of the access request to the broker. Only used when `sasl_oauthbearer_method` is set to "oidc". +sasl.oauthbearer.extensions | * | | | Allow additional information to be provided to the broker. Comma-separated list of key=value pairs. E.g., "supportFeatureX=true,organizationId=sales-emea".Only used when `sasl_oauthbearer_method` is set to "oidc". +sasl_oauthbearer_token_endpoint_url | * | | | OAuth/OIDC issuer token endpoint HTTP(S) URI used to retrieve token. Only used when `sasl_oauthbearer_method` is set to "oidc". plugin_library_paths | * | | undefined| Path where `librdkafka` plugins are located group_instance_id | C | | | Enable static group membership. Static group members are able to leave and rejoin a group within the configured `session.timeout.ms` without prompting a group rebalance. This should be used in combination with a larger `session.timeout.ms` to avoid group rebalances caused by transient unavailability (e.g. process restarts). Requires broker version >= 2.3.0. partition_assignment_strategy | C | | range, roundrobin | Name of partition assignment strategy to use when elected group leader assigns partitions to group members diff --git a/c_src/erlkaf_consumer.cc b/c_src/erlkaf_consumer.cc index 4ca0108..1e257e8 100644 --- a/c_src/erlkaf_consumer.cc +++ b/c_src/erlkaf_consumer.cc @@ -15,6 +15,8 @@ #include #include #include +#include +#include namespace { @@ -165,6 +167,26 @@ int stats_callback(rd_kafka_t *rk, char *json, size_t json_len, void *opaque) return 0; } +void oauthbearer_token_refresh_callback(rd_kafka_t *rk, const char *oauthbearer_config, void *opaque) +{ + UNUSED(rk); + + enif_consumer* consumer = static_cast(opaque); + ErlNifEnv* env = enif_alloc_env(); + + if (oauthbearer_config == NULL) + { + enif_send(NULL, &consumer->owner, env, enif_make_tuple2(env, ATOMS.atomOauthbearerTokenRefresh, ATOMS.atomUndefined)); + } + else + { + ERL_NIF_TERM config = make_binary(env, oauthbearer_config, strlen(oauthbearer_config)); + enif_send(NULL, &consumer->owner, env, enif_make_tuple2(env, ATOMS.atomOauthbearerTokenRefresh, config)); + } + + enif_free_env(env); +} + rd_kafka_topic_partition_list_t* topic_subscribe(ErlNifEnv* env, ERL_NIF_TERM list) { uint32_t length; @@ -270,6 +292,8 @@ ERL_NIF_TERM enif_consumer_new(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv rd_kafka_conf_set_log_cb(client_conf.get(), logger_callback); rd_kafka_conf_set_rebalance_cb(client_conf.get(), rebalance_cb); rd_kafka_conf_set_stats_cb(client_conf.get(), stats_callback); + rd_kafka_conf_set_oauthbearer_token_refresh_cb(client_conf.get(), oauthbearer_token_refresh_callback); + rd_kafka_conf_enable_sasl_queue(client_conf.get(), 1); scoped_ptr(rk, rd_kafka_t, rd_kafka_new(RD_KAFKA_CONSUMER, client_conf.get(), errstr, sizeof(errstr)), rd_kafka_destroy); @@ -299,6 +323,8 @@ ERL_NIF_TERM enif_consumer_new(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv data->notifier_->watch(consumer->kf, true); + rd_kafka_sasl_background_callbacks_enable(consumer->kf); + ERL_NIF_TERM term = enif_make_resource(env, consumer.get()); return enif_make_tuple2(env, ATOMS.atomOk, term); } @@ -441,3 +467,104 @@ ERL_NIF_TERM enif_consumer_cleanup(ErlNifEnv* env, int argc, const ERL_NIF_TERM return ATOMS.atomOk; } +char** split_consumer_extensions(const std::string extensions_str, size_t* length) +{ + std::stringstream extensions_stream(extensions_str); + std::string extension_tmp; + std::string kv_tmp; + std::vector extensions_vector; + + while (getline(extensions_stream, extension_tmp, ',')) + { + std::stringstream kv_stream(extension_tmp); + while (getline(kv_stream, kv_tmp, '=')) + extensions_vector.push_back(kv_tmp); + } + + *length = extensions_vector.size(); + char ** extensions = new char*[*length]; + + for(size_t i = 0; i < *length; ++i) + { + extensions[i] = new char[extensions_vector[i].size() + 1]; + strcpy(extensions[i], extensions_vector[i].c_str()); + } + + return extensions; +} + +void free_consumer_extensions(char** extensions, size_t length) +{ + if (extensions != nullptr) + { + for (size_t i = 0; i < length; ++i) + if (extensions[i] != nullptr) + delete[] extensions[i]; + + delete[] extensions; + } +} + +ERL_NIF_TERM enif_consumer_oauthbearer_set_token(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) +{ + UNUSED(argc); + + std::string token; + long lifetime; + std::string principal; + std::string extensions_str; + + erlkaf_data* data = static_cast(enif_priv_data(env)); + enif_consumer* consumer; + + if(!enif_get_resource(env, argv[0], data->res_consumer, reinterpret_cast(&consumer))) + return make_badarg(env); + + if(!get_string(env, argv[1], &token)) + return make_badarg(env); + + if(!enif_get_long(env, argv[2], &lifetime)) + return make_badarg(env); + + if(!get_string(env, argv[3], &principal)) + return make_badarg(env); + + if(!get_string(env, argv[4], &extensions_str)) + return make_badarg(env); + + char set_token_errstr[512]; + size_t extension_key_value_cnt = 0; + char **extension_key_value = NULL; + + if (extensions_str != "") + extension_key_value = split_consumer_extensions(extensions_str, &extension_key_value_cnt); + + if (rd_kafka_oauthbearer_set_token(consumer->kf, token.c_str(), lifetime * 1000, principal.c_str(), + (const char **)extension_key_value, extension_key_value_cnt, + set_token_errstr, sizeof(set_token_errstr)) != RD_KAFKA_RESP_ERR_NO_ERROR) + { + rd_kafka_oauthbearer_set_token_failure(consumer->kf, set_token_errstr); + free_consumer_extensions(extension_key_value, extension_key_value_cnt); + return ATOMS.atomError; + } + else + { + free_consumer_extensions(extension_key_value, extension_key_value_cnt); + return ATOMS.atomOk; + } +} + +ERL_NIF_TERM enif_consumer_oauthbearer_set_token_failure(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) +{ + UNUSED(argc); + + erlkaf_data* data = static_cast(enif_priv_data(env)); + enif_consumer* consumer; + + if(!enif_get_resource(env, argv[0], data->res_consumer, reinterpret_cast(&consumer))) + return make_badarg(env); + + char set_token_errstr[512]; + rd_kafka_oauthbearer_set_token_failure(consumer->kf, set_token_errstr); + return ATOMS.atomOk; +} diff --git a/c_src/erlkaf_consumer.h b/c_src/erlkaf_consumer.h index 8e95d66..0ca2dfa 100644 --- a/c_src/erlkaf_consumer.h +++ b/c_src/erlkaf_consumer.h @@ -11,5 +11,7 @@ ERL_NIF_TERM enif_consumer_queue_poll(ErlNifEnv* env, int argc, const ERL_NIF_TE ERL_NIF_TERM enif_consumer_queue_cleanup(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM enif_consumer_offset_store(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM enif_consumer_cleanup(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); +ERL_NIF_TERM enif_consumer_oauthbearer_set_token(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); +ERL_NIF_TERM enif_consumer_oauthbearer_set_token_failure(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); #endif // C_SRC_ERLKAF_CONSUMER_H_ diff --git a/c_src/erlkaf_nif.cc b/c_src/erlkaf_nif.cc index 41b0b01..7d45098 100755 --- a/c_src/erlkaf_nif.cc +++ b/c_src/erlkaf_nif.cc @@ -33,6 +33,7 @@ const char kAtomClientStopped[] = "client_stopped"; const char kAtomNotAvailable[] = "not_available"; const char kAtomCreateTime[] = "create_time"; const char kAtomLogAppendTime[] = "log_append_time"; +const char kAtomOauthbearerTokenRefresh[] = "oauthbearer_token_refresh"; atoms ATOMS; @@ -75,6 +76,7 @@ int on_nif_load(ErlNifEnv* env, void** priv_data, ERL_NIF_TERM load_info) ATOMS.atomNotAvailable = make_atom(env, kAtomNotAvailable); ATOMS.atomCreateTime = make_atom(env, kAtomCreateTime); ATOMS.atomLogAppendTime = make_atom(env, kAtomLogAppendTime); + ATOMS.atomOauthbearerTokenRefresh = make_atom(env, kAtomOauthbearerTokenRefresh); erlkaf_data* data = static_cast(enif_alloc(sizeof(erlkaf_data))); open_resources(env, data); @@ -122,13 +124,17 @@ static ErlNifFunc nif_funcs[] = {"producer_cleanup", 1, enif_producer_cleanup}, {"produce", 7, enif_produce}, {"get_metadata", 1, enif_get_metadata, ERL_NIF_DIRTY_JOB_IO_BOUND}, + {"producer_oauthbearer_set_token", 5, enif_producer_oauthbearer_set_token}, + {"producer_oauthbearer_set_token_failure", 1, enif_producer_oauthbearer_set_token_failure}, {"consumer_new", 4, enif_consumer_new}, {"consumer_partition_revoke_completed", 1, enif_consumer_partition_revoke_completed}, {"consumer_queue_poll", 2, enif_consumer_queue_poll}, {"consumer_queue_cleanup", 1, enif_consumer_queue_cleanup}, {"consumer_offset_store", 4, enif_consumer_offset_store}, - {"consumer_cleanup", 1, enif_consumer_cleanup} + {"consumer_cleanup", 1, enif_consumer_cleanup}, + {"consumer_oauthbearer_set_token", 5, enif_consumer_oauthbearer_set_token}, + {"consumer_oauthbearer_set_token_failure", 1, enif_consumer_oauthbearer_set_token_failure} }; ERL_NIF_INIT(erlkaf_nif, nif_funcs, on_nif_load, NULL, on_nif_upgrade, on_nif_unload) diff --git a/c_src/erlkaf_nif.h b/c_src/erlkaf_nif.h index eb67db2..caa2d4a 100644 --- a/c_src/erlkaf_nif.h +++ b/c_src/erlkaf_nif.h @@ -34,6 +34,7 @@ struct atoms ERL_NIF_TERM atomNotAvailable; ERL_NIF_TERM atomCreateTime; ERL_NIF_TERM atomLogAppendTime; + ERL_NIF_TERM atomOauthbearerTokenRefresh; }; struct erlkaf_data diff --git a/c_src/erlkaf_producer.cc b/c_src/erlkaf_producer.cc index 749585c..4e6b8a1 100644 --- a/c_src/erlkaf_producer.cc +++ b/c_src/erlkaf_producer.cc @@ -13,6 +13,9 @@ #include #include #include +#include +#include +#include namespace { @@ -65,6 +68,26 @@ int stats_callback(rd_kafka_t *rk, char *json, size_t json_len, void *opaque) return 0; } +void oauthbearer_token_refresh_callback(rd_kafka_t *rk, const char *oauthbearer_config, void *opaque) +{ + UNUSED(rk); + + enif_producer* producer = static_cast(opaque); + ErlNifEnv* env = enif_alloc_env(); + + if (oauthbearer_config == NULL) + { + enif_send(NULL, &producer->owner_pid, env, enif_make_tuple2(env, ATOMS.atomOauthbearerTokenRefresh, ATOMS.atomUndefined)); + } + else + { + ERL_NIF_TERM config = make_binary(env, oauthbearer_config, strlen(oauthbearer_config)); + enif_send(NULL, &producer->owner_pid, env, enif_make_tuple2(env, ATOMS.atomOauthbearerTokenRefresh, config)); + } + + enif_free_env(env); +} + bool populate_headers(ErlNifEnv* env, ERL_NIF_TERM headers_term, rd_kafka_headers_t* out) { ERL_NIF_TERM head; @@ -164,6 +187,8 @@ ERL_NIF_TERM enif_producer_new(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv rd_kafka_conf_set_log_cb(config.get(), logger_callback); rd_kafka_conf_set_stats_cb(config.get(), stats_callback); + rd_kafka_conf_set_oauthbearer_token_refresh_cb(config.get(), oauthbearer_token_refresh_callback); + rd_kafka_conf_enable_sasl_queue(config.get(), 1); if(has_dr_callback) rd_kafka_conf_set_dr_msg_cb(config.get(), delivery_report_callback); @@ -214,6 +239,8 @@ ERL_NIF_TERM enif_producer_set_owner(ErlNifEnv* env, int argc, const ERL_NIF_TER if(!enif_get_local_pid(env, argv[1], &producer->owner_pid)) return make_badarg(env); + rd_kafka_sasl_background_callbacks_enable(producer->kf); + return ATOMS.atomOk; } @@ -417,3 +444,105 @@ ERL_NIF_TERM enif_produce(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) return ATOMS.atomOk; } + +char** split_producer_extensions(const std::string extensions_str, size_t* length) +{ + std::stringstream extensions_stream(extensions_str); + std::string extension_tmp; + std::string kv_tmp; + std::vector extensions_vector; + + while (getline(extensions_stream, extension_tmp, ',')) + { + std::stringstream kv_stream(extension_tmp); + while (getline(kv_stream, kv_tmp, '=')) + extensions_vector.push_back(kv_tmp); + } + + *length = extensions_vector.size(); + char ** extensions = new char*[*length]; + + for(size_t i = 0; i < *length; ++i) + { + extensions[i] = new char[extensions_vector[i].size() + 1]; + strcpy(extensions[i], extensions_vector[i].c_str()); + } + + return extensions; +} + +void free_producer_extensions(char** extensions, size_t length) +{ + if (extensions != nullptr) + { + for (size_t i = 0; i < length; ++i) + if (extensions[i] != nullptr) + delete[] extensions[i]; + + delete[] extensions; + } +} + +ERL_NIF_TERM enif_producer_oauthbearer_set_token(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) +{ + UNUSED(argc); + + std::string token; + long lifetime; + std::string principal; + std::string extensions_str; + + erlkaf_data* data = static_cast(enif_priv_data(env)); + enif_producer* producer; + + if(!enif_get_resource(env, argv[0], data->res_producer, reinterpret_cast(&producer))) + return make_badarg(env); + + if(!get_string(env, argv[1], &token)) + return make_badarg(env); + + if(!enif_get_long(env, argv[2], &lifetime)) + return make_badarg(env); + + if(!get_string(env, argv[3], &principal)) + return make_badarg(env); + + if(!get_string(env, argv[4], &extensions_str)) + return make_badarg(env); + + char set_token_errstr[512]; + size_t extension_key_value_cnt = 0; + char **extension_key_value = NULL; + + if (extensions_str != "") + extension_key_value = split_producer_extensions(extensions_str, &extension_key_value_cnt); + + if (rd_kafka_oauthbearer_set_token(producer->kf, token.c_str(), lifetime * 1000, principal.c_str(), + (const char **)extension_key_value, extension_key_value_cnt, + set_token_errstr, sizeof(set_token_errstr)) != RD_KAFKA_RESP_ERR_NO_ERROR) + { + rd_kafka_oauthbearer_set_token_failure(producer->kf, set_token_errstr); + free_producer_extensions(extension_key_value, extension_key_value_cnt); + return ATOMS.atomError; + } + else + { + free_producer_extensions(extension_key_value, extension_key_value_cnt); + return ATOMS.atomOk; + } +} + +ERL_NIF_TERM enif_producer_oauthbearer_set_token_failure(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) +{ + UNUSED(argc); + + erlkaf_data* data = static_cast(enif_priv_data(env)); + enif_producer* producer; + + if(!enif_get_resource(env, argv[0], data->res_producer, reinterpret_cast(&producer))) + return make_badarg(env); + + char set_token_errstr[512]; + rd_kafka_oauthbearer_set_token_failure(producer->kf, set_token_errstr); + return ATOMS.atomOk; +} diff --git a/c_src/erlkaf_producer.h b/c_src/erlkaf_producer.h index a51ae93..eab52f3 100644 --- a/c_src/erlkaf_producer.h +++ b/c_src/erlkaf_producer.h @@ -11,5 +11,7 @@ ERL_NIF_TERM enif_producer_set_owner(ErlNifEnv* env, int argc, const ERL_NIF_TER ERL_NIF_TERM enif_producer_cleanup(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM enif_produce(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM enif_get_metadata(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); +ERL_NIF_TERM enif_producer_oauthbearer_set_token(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); +ERL_NIF_TERM enif_producer_oauthbearer_set_token_failure(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); #endif // C_SRC_ERLKAF_PRODUCER_H_ diff --git a/src/erlkaf_config.erl b/src/erlkaf_config.erl index 120f3a3..2e281c6 100644 --- a/src/erlkaf_config.erl +++ b/src/erlkaf_config.erl @@ -80,6 +80,8 @@ is_erlkaf_config(delivery_report_callback = K, V) -> check_callback(K, V, 2); is_erlkaf_config(stats_callback = K, V) -> check_callback(K, V, 2); +is_erlkaf_config(oauthbearer_token_refresh_callback = K, V) -> + check_callback(K, V, 1); is_erlkaf_config(queue_buffering_overflow_strategy = K, V) -> case V of local_disk_queue -> @@ -218,6 +220,18 @@ to_librdkafka_config(sasl_oauthbearer_config, V) -> {<<"sasl.oauthbearer.config">>, erlkaf_utils:to_binary(V)}; to_librdkafka_config(enable_sasl_oauthbearer_unsecure_jwt, V) -> {<<"enable.sasl.oauthbearer.unsecure.jwt">>, erlkaf_utils:to_binary(V)}; +to_librdkafka_config(sasl_oauthbearer_method, V) -> + {<<"sasl.oauthbearer.method">>, erlkaf_utils:to_binary(V)}; +to_librdkafka_config(sasl_oauthbearer_client_id, V) -> + {<<"sasl.oauthbearer.client.id">>, erlkaf_utils:to_binary(V)}; +to_librdkafka_config(sasl_oauthbearer_client_secret, V) -> + {<<"sasl.oauthbearer.client.secret">>, erlkaf_utils:to_binary(V)}; +to_librdkafka_config(sasl_oauthbearer_scope, V) -> + {<<"sasl.oauthbearer.scope">>, erlkaf_utils:to_binary(V)}; +to_librdkafka_config(sasl_oauthbearer_extensions, V) -> + {<<"sasl.oauthbearer.extensions">>, erlkaf_utils:to_binary(V)}; +to_librdkafka_config(sasl_oauthbearer_token_endpoint_url, V) -> + {<<"sasl.oauthbearer.token.endpoint.url">>, erlkaf_utils:to_binary(V)}; to_librdkafka_config(group_instance_id, V) -> {<<"group.instance.id">>, erlkaf_utils:to_binary(V)}; to_librdkafka_config(session_timeout_ms, V) -> diff --git a/src/erlkaf_consumer_callbacks.erl b/src/erlkaf_consumer_callbacks.erl index 265599f..6a2ea4c 100644 --- a/src/erlkaf_consumer_callbacks.erl +++ b/src/erlkaf_consumer_callbacks.erl @@ -11,6 +11,10 @@ -callback stats_callback(client_id(), map()) -> ok. +-callback oauthbearer_token_refresh_callback(binary()) -> + ok. + -optional_callbacks([ - stats_callback/2 + stats_callback/2, + oauthbearer_token_refresh_callback/1 ]). diff --git a/src/erlkaf_consumer_group.erl b/src/erlkaf_consumer_group.erl index 04b72cd..832040f 100644 --- a/src/erlkaf_consumer_group.erl +++ b/src/erlkaf_consumer_group.erl @@ -23,7 +23,8 @@ topics_settings = #{}, active_topics_map = #{}, stats_cb, - stats = [] + stats = [], + oauthbearer_token_refresh_cb }). start_link(ClientId, GroupId, Topics, EkClientConfig, RdkClientConfig, EkTopicConfig, RdkTopicConfig) -> @@ -42,7 +43,8 @@ init([ClientId, GroupId, Topics, EkClientConfig, RdkClientConfig, _EkTopicConfig client_id = ClientId, client_ref = ClientRef, topics_settings = maps:from_list(Topics), - stats_cb = erlkaf_utils:lookup(stats_callback, EkClientConfig) + stats_cb = erlkaf_utils:lookup(stats_callback, EkClientConfig), + oauthbearer_token_refresh_cb = erlkaf_utils:lookup(oauthbearer_token_refresh_callback, EkClientConfig) }}; Error -> {stop, Error} @@ -68,6 +70,23 @@ handle_info({stats, Stats0}, #state{stats_cb = StatsCb, client_id = ClientId} = end, {noreply, State#state{stats = Stats}}; +handle_info({oauthbearer_token_refresh, OauthBearerConfig}, #state{ + oauthbearer_token_refresh_cb = OauthbearerTokenRefreshCb, + client_id = ClientId, + client_ref = ClientRef} = State) -> + + case catch erlkaf_utils:call_oauthbearer_token_refresh_callback(OauthbearerTokenRefreshCb, OauthBearerConfig) of + {ok, Token, LifeTime, Principal} -> + erlkaf_nif:consumer_oauthbearer_set_token(ClientRef, Token, LifeTime, Principal, ""); + {ok, Token, LifeTime, Principal, Extensions} -> + erlkaf_nif:consumer_oauthbearer_set_token(ClientRef, Token, LifeTime, Principal, Extensions); + {error, Error} -> + erlkaf_nif:consumer_oauthbearer_set_token_failure(ClientRef), + ?LOG_ERROR("~p:oauthbearer_token_refresh_callback client_id: ~p error: ~p", [OauthbearerTokenRefreshCb, ClientId, Error]) + end, + + {noreply, State}; + handle_info({assign_partitions, Partitions}, #state{ client_ref = ClientRef, topics_settings = TopicsSettingsMap, diff --git a/src/erlkaf_nif.erl b/src/erlkaf_nif.erl index 5846397..bb842fc 100644 --- a/src/erlkaf_nif.erl +++ b/src/erlkaf_nif.erl @@ -13,6 +13,8 @@ producer_set_owner/2, producer_topic_new/3, produce/7, + producer_oauthbearer_set_token/5, + producer_oauthbearer_set_token_failure/1, get_metadata/1, consumer_new/4, @@ -20,7 +22,9 @@ consumer_queue_poll/2, consumer_queue_cleanup/1, consumer_offset_store/4, - consumer_cleanup/1 + consumer_cleanup/1, + consumer_oauthbearer_set_token/5, + consumer_oauthbearer_set_token_failure/1 ]). %% nif functions @@ -50,6 +54,12 @@ producer_topic_new(_ClientRef, _TopicName, _TopicConfig) -> produce(_ClientRef, _TopicRef, _Partition, _Key, _Value, _Headers, _Timestamp) -> ?NOT_LOADED. +producer_oauthbearer_set_token(_ClientRef, _Token, _LifeTime, _Principal, _Extensions) -> + ?NOT_LOADED. + +producer_oauthbearer_set_token_failure(_ClientRef) -> + ?NOT_LOADED. + get_metadata(_ClientRef) -> ?NOT_LOADED. @@ -70,3 +80,9 @@ consumer_offset_store(_ClientRef, _TopicName, _Partition, _Offset) -> consumer_cleanup(_ClientRef) -> ?NOT_LOADED. + +consumer_oauthbearer_set_token(_ClientRef, _Token, _LifeTime, _Principal, _Extensions) -> + ?NOT_LOADED. + +consumer_oauthbearer_set_token_failure(_ClientRef) -> + ?NOT_LOADED. diff --git a/src/erlkaf_producer.erl b/src/erlkaf_producer.erl index ca658f5..576371d 100644 --- a/src/erlkaf_producer.erl +++ b/src/erlkaf_producer.erl @@ -31,7 +31,8 @@ stats = [], overflow_method, pqueue, - pqueue_sch = true + pqueue_sch = true, + oauthbearer_token_refresh_cb }). start_link(ClientId, DrCallback, ErlkafConfig, ProducerRef) -> @@ -44,6 +45,7 @@ init([ClientId, DrCallback, ErlkafConfig, ProducerRef]) -> Pid = self(), OverflowStrategy = erlkaf_utils:lookup(queue_buffering_overflow_strategy, ErlkafConfig, local_disk_queue), StatsCallback = erlkaf_utils:lookup(stats_callback, ErlkafConfig), + OauthbearerTokenRefreshCb = erlkaf_utils:lookup(oauthbearer_token_refresh_callback, ErlkafConfig), ok = erlkaf_nif:producer_set_owner(ProducerRef, Pid), ok = erlkaf_cache_client:set(ClientId, ProducerRef, Pid), {ok, Queue} = erlkaf_local_queue:new(ClientId), @@ -62,7 +64,8 @@ init([ClientId, DrCallback, ErlkafConfig, ProducerRef]) -> dr_cb = DrCallback, stats_cb = StatsCallback, overflow_method = OverflowStrategy, - pqueue = Queue}}. + pqueue = Queue, + oauthbearer_token_refresh_cb = OauthbearerTokenRefreshCb}}. handle_call({queue_event, TopicName, Partition, Key, Value, Headers, Timestamp}, _From, #state{ pqueue = Queue, @@ -116,6 +119,23 @@ handle_info({stats, Stats0}, #state{stats_cb = StatsCb, client_id = ClientId} = end, {noreply, State#state{stats = Stats}}; +handle_info({oauthbearer_token_refresh, OauthBearerConfig}, #state{ + oauthbearer_token_refresh_cb = OauthbearerTokenRefreshCb, + client_id = ClientId, + ref = ClientRef} = State) -> + + case catch erlkaf_utils:call_oauthbearer_token_refresh_callback(OauthbearerTokenRefreshCb, OauthBearerConfig) of + {ok, Token, LifeTime, Principal} -> + erlkaf_nif:producer_oauthbearer_set_token(ClientRef, Token, LifeTime, Principal, ""); + {ok, Token, LifeTime, Principal, Extensions} -> + erlkaf_nif:producer_oauthbearer_set_token(ClientRef, Token, LifeTime, Principal, Extensions); + {error, Error} -> + erlkaf_nif:producer_oauthbearer_set_token_failure(ClientRef), + ?LOG_ERROR("~p:oauthbearer_token_refresh_callback client_id: ~p error: ~p", [OauthbearerTokenRefreshCb, ClientId, Error]) + end, + + {noreply, State}; + handle_info(Info, State) -> ?LOG_ERROR("received unknown message: ~p", [Info]), {noreply, State}. diff --git a/src/erlkaf_utils.erl b/src/erlkaf_utils.erl index 54eb67c..7eaf81f 100644 --- a/src/erlkaf_utils.erl +++ b/src/erlkaf_utils.erl @@ -11,6 +11,7 @@ safe_call/2, safe_call/3, call_stats_callback/3, + call_oauthbearer_token_refresh_callback/2, parralel_exec/2 ]). @@ -84,10 +85,17 @@ safe_call(Receiver, Message, Timeout) -> call_stats_callback(undefined, _ClientId, _Stats) -> ok; -call_stats_callback(C, ClientId, Stats) when is_function(C, 2) -> - C(ClientId, Stats); -call_stats_callback(C, ClientId, Stats) -> - C:stats_callback(ClientId, Stats). +call_stats_callback(C, OAuthBearerConfig, Stats) when is_function(C, 2) -> + C(OAuthBearerConfig, Stats); +call_stats_callback(C, OAuthBearerConfig, Stats) -> + C:stats_callback(OAuthBearerConfig, Stats). + +call_oauthbearer_token_refresh_callback(undefined, _OauthbearerConfig) -> + ok; +call_oauthbearer_token_refresh_callback(C, OauthbearerConfig) when is_function(C, 1) -> + C(OauthbearerConfig); +call_oauthbearer_token_refresh_callback(C, OauthbearerConfig) -> + C:oauthbearer_token_refresh_callback(OauthbearerConfig). parralel_exec(Fun, List) -> Parent = self(),