diff --git a/src/common/redis_module/ray_redis_module.cc b/src/common/redis_module/ray_redis_module.cc
index 6cb6c96909da..e38a69dcc216 100644
--- a/src/common/redis_module/ray_redis_module.cc
+++ b/src/common/redis_module/ray_redis_module.cc
@@ -53,6 +53,32 @@ static const char *table_prefixes[] = {
     NULL, "TASK:", "CLIENT:", "OBJECT:", "FUNCTION:",
 };
 
+/// Parse a Redis string into a TablePubsub channel.
+TablePubsub ParseTablePubsub(const RedisModuleString *pubsub_channel_str) {
+  long long pubsub_channel_long;
+  RAY_CHECK(RedisModule_StringToLongLong(
+                pubsub_channel_str, &pubsub_channel_long) == REDISMODULE_OK)
+      << "Pubsub channel must be a valid TablePubsub";
+  auto pubsub_channel = static_cast<TablePubsub>(pubsub_channel_long);
+  RAY_CHECK(pubsub_channel >= TablePubsub_MIN &&
+            pubsub_channel <= TablePubsub_MAX)
+      << "Pubsub channel must be a valid TablePubsub";
+  return pubsub_channel;
+}
+
+/// Format a pubsub channel for a specific key. pubsub_channel_str should
+/// contain a valid TablePubsub.
+RedisModuleString *FormatPubsubChannel(
+    RedisModuleCtx *ctx,
+    const RedisModuleString *pubsub_channel_str,
+    const RedisModuleString *id) {
+  // Format the pubsub channel enum to a string. TablePubsub_MAX should be more
+  // than enough digits, but add 1 just in case for the null terminator.
+  char pubsub_channel[TablePubsub_MAX + 1];
+  sprintf(pubsub_channel, "%d", ParseTablePubsub(pubsub_channel_str));
+  return RedisString_Format(ctx, "%s:%S", pubsub_channel, id);
+}
+
 // TODO(swang): This helper function should be deprecated by the version below,
 // which uses enums for table prefixes.
 RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
@@ -83,6 +109,23 @@ RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
   return OpenPrefixedKey(ctx, table_prefixes[prefix], keyname, mode);
 }
 
+/// Open the key used to store the channels that should be published to when an
+/// update happens at the given keyname.
+RedisModuleKey *OpenBroadcastKey(RedisModuleCtx *ctx,
+                                 RedisModuleString *pubsub_channel_str,
+                                 RedisModuleString *keyname,
+                                 int mode) {
+  RedisModuleString *channel =
+      FormatPubsubChannel(ctx, pubsub_channel_str, keyname);
+  RedisModuleString *prefixed_keyname =
+      RedisString_Format(ctx, "BCAST:%S", channel);
+  RedisModuleKey *key =
+      (RedisModuleKey *) RedisModule_OpenKey(ctx, prefixed_keyname, mode);
+  RedisModule_FreeString(ctx, prefixed_keyname);
+  RedisModule_FreeString(ctx, channel);
+  return key;
+}
+
 /**
  * This is a helper method to convert a redis module string to a flatbuffer
  * string.
@@ -411,8 +454,181 @@ bool PublishObjectNotification(RedisModuleCtx *ctx,
   return true;
 }
 
-// This is a temporary redis command that will be removed once
+// NOTE(pcmoritz): This is a temporary redis command that will be removed once
 // the GCS uses https://github.com/pcmoritz/credis.
+int TaskTableAdd(RedisModuleCtx *ctx,
+                 RedisModuleString *id,
+                 RedisModuleString *data) {
+  const char *buf = RedisModule_StringPtrLen(data, NULL);
+  auto message = flatbuffers::GetRoot<TaskTableData>(buf);
+
+  if (message->scheduling_state() == SchedulingState_WAITING ||
+      message->scheduling_state() == SchedulingState_SCHEDULED) {
+    /* Build the PUBLISH topic and message for task table subscribers. The
+     * topic is a string in the format
+     * "TASK_PREFIX:<local scheduler ID>:<state>". The message is a serialized
+     * SubscribeToTasksReply flatbuffer object. */
+    std::string state = std::to_string(message->scheduling_state());
+    RedisModuleString *publish_topic = RedisString_Format(
+        ctx, "%s%b:%s", TASK_PREFIX, message->scheduler_id()->str().data(),
+        sizeof(DBClientID), state.c_str());
+
+    /* Construct the flatbuffers object for the payload. */
+    flatbuffers::FlatBufferBuilder fbb;
+    /* Create the flatbuffers message. */
+    auto msg = CreateTaskReply(
+        fbb, RedisStringToFlatbuf(fbb, id), message->scheduling_state(),
+        fbb.CreateString(message->scheduler_id()),
+        fbb.CreateString(message->execution_dependencies()),
+        fbb.CreateString(message->task_info()), message->spillback_count(),
+        true /* not used */);
+    fbb.Finish(msg);
+
+    RedisModuleString *publish_message = RedisModule_CreateString(
+        ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize());
+
+    RedisModuleCallReply *reply =
+        RedisModule_Call(ctx, "PUBLISH", "ss", publish_topic, publish_message);
+
+    /* See how many clients received this publish. */
+    long long num_clients = RedisModule_CallReplyInteger(reply);
+    RAY_CHECK(num_clients <= 1) << "Published to " << num_clients
+                                << " clients.";
+
+    RedisModule_FreeString(ctx, publish_message);
+    RedisModule_FreeString(ctx, publish_topic);
+  }
+  return RedisModule_ReplyWithSimpleString(ctx, "OK");
+}
+
+// TODO(swang): Implement the client table as an append-only log so that we
+// don't need this special case for client table publication.
+int ClientTableAdd(RedisModuleCtx *ctx,
+                   RedisModuleString *pubsub_channel_str,
+                   RedisModuleString *data) {
+  const char *buf = RedisModule_StringPtrLen(data, NULL);
+  auto client_data = flatbuffers::GetRoot<ClientTableData>(buf);
+
+  RedisModuleKey *clients_key = (RedisModuleKey *) RedisModule_OpenKey(
+      ctx, pubsub_channel_str, REDISMODULE_READ | REDISMODULE_WRITE);
+  // If this is a client addition, send all previous notifications, in order.
+  // NOTE(swang): This will go to all clients, so some clients will get
+  // duplicate notifications.
+  if (client_data->is_insertion() &&
+      RedisModule_KeyType(clients_key) != REDISMODULE_KEYTYPE_EMPTY) {
+    // NOTE(swang): Sets are not implemented yet, so we use ZSETs instead.
+    CHECK_ERROR(RedisModule_ZsetFirstInScoreRange(
+                    clients_key, REDISMODULE_NEGATIVE_INFINITE,
+                    REDISMODULE_POSITIVE_INFINITE, 1, 1),
+                "Unable to initialize zset iterator");
+    do {
+      RedisModuleString *message =
+          RedisModule_ZsetRangeCurrentElement(clients_key, NULL);
+      RedisModuleCallReply *reply =
+          RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, message);
+      if (reply == NULL) {
+        RedisModule_CloseKey(clients_key);
+        return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
+      }
+    } while (RedisModule_ZsetRangeNext(clients_key));
+  }
+
+  // Append this notification to the past notifications so that it will get
+  // sent to new clients in the future.
+  size_t index = RedisModule_ValueLength(clients_key);
+  // Serialize the notification to send.
+  flatbuffers::FlatBufferBuilder fbb;
+  auto message = CreateGcsNotification(fbb, fbb.CreateString(""),
+                                       RedisStringToFlatbuf(fbb, data));
+  fbb.Finish(message);
+  auto notification = RedisModule_CreateString(
+      ctx, reinterpret_cast<const char *>(fbb.GetBufferPointer()),
+      fbb.GetSize());
+  RedisModule_ZsetAdd(clients_key, index, notification, NULL);
+  // Publish the notification about this client.
+  RedisModuleCallReply *reply =
+      RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, notification);
+  RedisModule_FreeString(ctx, notification);
+  if (reply == NULL) {
+    RedisModule_CloseKey(clients_key);
+    return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
+  }
+
+  RedisModule_CloseKey(clients_key);
+  return RedisModule_ReplyWithSimpleString(ctx, "OK");
+}
+
+/// Publish a notification for a new entry at a key. This publishes a
+/// notification to all subscribers of the table, as well as every client that
+/// has requested notifications for this key.
+///
+/// \param pubsub_channel_str The pubsub channel name that notifications for
+///        this key should be published to. When publishing to a specific
+///        client, the channel name should be <pubsub_channel>:<client_id>.
+/// \param id The ID of the key that the notification is about.
+/// \param data The data to publish.
+/// \return OK if there is no error during a publish.
+int PublishTableAdd(RedisModuleCtx *ctx,
+                    RedisModuleString *pubsub_channel_str,
+                    RedisModuleString *id,
+                    RedisModuleString *data) {
+  // Serialize the notification to send.
+  flatbuffers::FlatBufferBuilder fbb;
+  auto message = CreateGcsNotification(fbb, RedisStringToFlatbuf(fbb, id),
+                                       RedisStringToFlatbuf(fbb, data));
+  fbb.Finish(message);
+
+  // Write the data back to any subscribers that are listening to all table
+  // notifications.
+  RedisModuleCallReply *reply =
+      RedisModule_Call(ctx, "PUBLISH", "sb", pubsub_channel_str,
+                       fbb.GetBufferPointer(), fbb.GetSize());
+  if (reply == NULL) {
+    return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
+  }
+
+  // Publish the data to any clients who requested notifications on this key.
+  RedisModuleKey *notification_key = OpenBroadcastKey(
+      ctx, pubsub_channel_str, id, REDISMODULE_READ | REDISMODULE_WRITE);
+  if (RedisModule_KeyType(notification_key) != REDISMODULE_KEYTYPE_EMPTY) {
+    // NOTE(swang): Sets are not implemented yet, so we use ZSETs instead.
+    CHECK_ERROR(RedisModule_ZsetFirstInScoreRange(
+                    notification_key, REDISMODULE_NEGATIVE_INFINITE,
+                    REDISMODULE_POSITIVE_INFINITE, 1, 1),
+                "Unable to initialize zset iterator");
+    for (; !RedisModule_ZsetRangeEndReached(notification_key);
+         RedisModule_ZsetRangeNext(notification_key)) {
+      RedisModuleString *client_channel =
+          RedisModule_ZsetRangeCurrentElement(notification_key, NULL);
+      RedisModuleCallReply *reply =
+          RedisModule_Call(ctx, "PUBLISH", "sb", client_channel,
+                           fbb.GetBufferPointer(), fbb.GetSize());
+      if (reply == NULL) {
+        RedisModule_CloseKey(notification_key);
+        return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
+      }
+    }
+  }
+  RedisModule_CloseKey(notification_key);
+  return RedisModule_ReplyWithSimpleString(ctx, "OK");
+}
+
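A note on the naming scheme the helpers above agree on: the table-wide channel is just the decimal value of the TablePubsub enum, the per-client channel appends a client ID, and the broadcast ZSET key prefixes the per-client form with "BCAST:". Below is a minimal client-side sketch of the same formatting, under the assumption that the enum value and the IDs are already in hand; the helper names here are hypothetical and not part of the module:

    #include <string>

    // Channel used when subscribing to the whole table ("%d" of the
    // TablePubsub enum, as in FormatPubsubChannel).
    std::string TableChannel(int pubsub_channel) {
      return std::to_string(pubsub_channel);
    }

    // Channel for one client: "<pubsub_channel>:<id>", as produced by
    // FormatPubsubChannel(ctx, pubsub_channel_str, id).
    std::string ClientChannel(int pubsub_channel, const std::string &id) {
      return TableChannel(pubsub_channel) + ":" + id;
    }

    // Redis key of the ZSET that OpenBroadcastKey opens for a table key:
    // "BCAST:<pubsub_channel>:<keyname>".
    std::string BroadcastKey(int pubsub_channel, const std::string &keyname) {
      return "BCAST:" + ClientChannel(pubsub_channel, keyname);
    }
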
+/// Add an entry at a key. This overwrites any existing data at the key.
+/// Publishes a notification about the update to all subscribers, if a pubsub
+/// channel is provided.
+///
+/// This is called from a client with the command:
+//
+///    RAY.TABLE_ADD <table_prefix> <pubsub_channel> <id> <data>
+///
+/// \param table_prefix The prefix string for keys in this table.
+/// \param pubsub_channel The pubsub channel name that notifications for
+///        this key should be published to. When publishing to a specific
+///        client, the channel name should be <pubsub_channel>:<client_id>.
+/// \param id The ID of the key to set.
+/// \param data The data to insert at the key.
+/// \return The current value at the key, or OK if there is no value.
 int TableAdd_RedisCommand(RedisModuleCtx *ctx,
                           RedisModuleString **argv,
                           int argc) {
@@ -431,108 +647,22 @@ int TableAdd_RedisCommand(RedisModuleCtx *ctx,
   RedisModule_StringSet(key, data);
   RedisModule_CloseKey(key);
 
-  // Get the requested pubsub channel.
-  long long pubsub_channel_long;
-  RAY_CHECK(RedisModule_StringToLongLong(
-                pubsub_channel_str, &pubsub_channel_long) == REDISMODULE_OK)
-      << "Pubsub channel must be a valid TablePubsub";
-  auto pubsub_channel = static_cast<TablePubsub>(pubsub_channel_long);
-  RAY_CHECK(pubsub_channel >= TablePubsub_MIN &&
-            pubsub_channel <= TablePubsub_MAX)
-      << "Pubsub channel must be a valid TablePubsub";
-  // Publish a message on the requested pubsub channel if necessary.
+  TablePubsub pubsub_channel = ParseTablePubsub(pubsub_channel_str);
   if (pubsub_channel == TablePubsub_TASK) {
-    const char *buf = RedisModule_StringPtrLen(data, NULL);
-    auto message = flatbuffers::GetRoot<TaskTableData>(buf);
-
-    if (message->scheduling_state() == SchedulingState_WAITING ||
-        message->scheduling_state() == SchedulingState_SCHEDULED) {
-      /* Build the PUBLISH topic and message for task table subscribers. The
-       * topic is a string in the format
-       * "TASK_PREFIX:<local scheduler ID>:<state>". The message is a
-       * serialized SubscribeToTasksReply flatbuffer object. */
-      std::string state = std::to_string(message->scheduling_state());
-      RedisModuleString *publish_topic = RedisString_Format(
-          ctx, "%s%b:%s", TASK_PREFIX, message->scheduler_id()->str().data(),
-          sizeof(DBClientID), state.c_str());
-
-      /* Construct the flatbuffers object for the payload. */
-      flatbuffers::FlatBufferBuilder fbb;
-      /* Create the flatbuffers message. */
-      auto msg = CreateTaskReply(
-          fbb, RedisStringToFlatbuf(fbb, id), message->scheduling_state(),
-          fbb.CreateString(message->scheduler_id()),
-          fbb.CreateString(message->execution_dependencies()),
-          fbb.CreateString(message->task_info()), message->spillback_count(),
-          true /* not used */);
-      fbb.Finish(msg);
-
-      RedisModuleString *publish_message = RedisModule_CreateString(
-          ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize());
-
-      RedisModuleCallReply *reply = RedisModule_Call(
-          ctx, "PUBLISH", "ss", publish_topic, publish_message);
-
-      /* See how many clients received this publish. */
-      long long num_clients = RedisModule_CallReplyInteger(reply);
-      RAY_CHECK(num_clients <= 1) << "Published to " << num_clients
-                                  << " clients.";
-
-      RedisModule_FreeString(ctx, publish_message);
-      RedisModule_FreeString(ctx, publish_topic);
-    }
+    // Publish the task to its subscribers.
+    // TODO(swang): This is only necessary for legacy Ray and should be removed
+    // once we switch to using the new GCS API for the task table.
+    return TaskTableAdd(ctx, id, data);
   } else if (pubsub_channel == TablePubsub_CLIENT) {
-    const char *buf = RedisModule_StringPtrLen(data, NULL);
-    auto client_data = flatbuffers::GetRoot<ClientTableData>(buf);
-
-    RedisModuleKey *clients_key = (RedisModuleKey *) RedisModule_OpenKey(
-        ctx, pubsub_channel_str, REDISMODULE_READ | REDISMODULE_WRITE);
-    // If this is a client addition, send all previous notifications, in order.
-    // NOTE(swang): This will go to all clients, so some clients will get
-    // duplicate notifications.
-    if (client_data->is_insertion() &&
-        RedisModule_KeyType(clients_key) != REDISMODULE_KEYTYPE_EMPTY) {
-      // NOTE(swang): Sets are not implemented yet, so we use ZSETs instead.
-      CHECK_ERROR(RedisModule_ZsetFirstInScoreRange(
-                      clients_key, REDISMODULE_NEGATIVE_INFINITE,
-                      REDISMODULE_POSITIVE_INFINITE, 1, 1),
-                  "Unable to initialize zset iterator");
-      do {
-        RedisModuleString *message =
-            RedisModule_ZsetRangeCurrentElement(clients_key, NULL);
-        RedisModuleCallReply *reply =
-            RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, message);
-        if (reply == NULL) {
-          RedisModule_CloseKey(clients_key);
-          RedisModule_ReplyWithError(ctx, "error during PUBLISH");
-        }
-      } while (RedisModule_ZsetRangeNext(clients_key));
-    }
-
-    // Append this notification to the past notifications so that it will get
-    // sent to new clients in the future.
-    size_t index = RedisModule_ValueLength(key);
-    RedisModule_ZsetAdd(clients_key, index, data, NULL);
-    // Publish the notification about this client.
-    RedisModuleCallReply *reply =
-        RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, data);
-    if (reply == NULL) {
-      RedisModule_ReplyWithError(ctx, "error during PUBLISH");
-    }
-
-    RedisModule_CloseKey(clients_key);
+    // Publish all previous client table additions to the new client.
+    return ClientTableAdd(ctx, pubsub_channel_str, data);
   } else if (pubsub_channel != TablePubsub_NO_PUBLISH) {
     // All other pubsub channels write the data back directly onto the channel.
-    RedisModuleCallReply *reply =
-        RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, data);
-    if (reply == NULL) {
-      RedisModule_ReplyWithError(ctx, "error during PUBLISH");
-    }
+    return PublishTableAdd(ctx, pubsub_channel_str, id, data);
+  } else {
+    return RedisModule_ReplyWithSimpleString(ctx, "OK");
   }
-
-  return RedisModule_ReplyWithSimpleString(ctx, "OK");
 }
 
 // This is a temporary redis command that will be removed once
@@ -561,6 +691,114 @@ int TableLookup_RedisCommand(RedisModuleCtx *ctx,
   return REDISMODULE_OK;
 }
 
+/// Request notifications for changes to a key. Returns the current value or
+/// values at the key. Notifications will be sent to the requesting client for
+/// every subsequent TABLE_ADD to the key.
+///
+/// This is called from a client with the command:
+//
+///    RAY.TABLE_REQUEST_NOTIFICATIONS <table_prefix> <pubsub_channel> <id>
+///        <client_id>
+///
+/// \param table_prefix The prefix string for keys in this table.
+/// \param pubsub_channel The pubsub channel name that notifications for
+///        this key should be published to. When publishing to a specific
+///        client, the channel name should be <pubsub_channel>:<client_id>.
+/// \param id The ID of the key to publish notifications for.
+/// \param client_id The ID of the client that is being notified.
+/// \return The current value at the key, or OK if there is no value.
+int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx,
+                                           RedisModuleString **argv,
+                                           int argc) {
+  if (argc != 5) {
+    return RedisModule_WrongArity(ctx);
+  }
+
+  RedisModuleString *prefix_str = argv[1];
+  RedisModuleString *pubsub_channel_str = argv[2];
+  RedisModuleString *id = argv[3];
+  RedisModuleString *client_id = argv[4];
+  RedisModuleString *client_channel =
+      FormatPubsubChannel(ctx, pubsub_channel_str, client_id);
+
+  // Add this client to the set of clients that should be notified when there
+  // are changes to the key.
+  RedisModuleKey *notification_key = OpenBroadcastKey(
+      ctx, pubsub_channel_str, id, REDISMODULE_READ | REDISMODULE_WRITE);
+  CHECK_ERROR(RedisModule_ZsetAdd(notification_key, 0.0, client_channel, NULL),
+              "ZsetAdd failed.");
+  RedisModule_CloseKey(notification_key);
+  RedisModule_FreeString(ctx, client_channel);
+
+  // Return the current value at the key, if any, to the client that requested
+  // a notification.
+  RedisModuleKey *table_key =
+      OpenPrefixedKey(ctx, prefix_str, id, REDISMODULE_READ);
+  if (table_key != nullptr) {
+    // Serialize the notification to send.
+    size_t data_len = 0;
+    char *data_buf =
+        RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ);
+    flatbuffers::FlatBufferBuilder fbb;
+    auto message = CreateGcsNotification(fbb, RedisStringToFlatbuf(fbb, id),
+                                         fbb.CreateString(data_buf, data_len));
+    fbb.Finish(message);
+
+    int result = RedisModule_ReplyWithStringBuffer(
+        ctx, reinterpret_cast<const char *>(fbb.GetBufferPointer()),
+        fbb.GetSize());
+    RedisModule_CloseKey(table_key);
+    return result;
+  } else {
+    RedisModule_CloseKey(table_key);
+    RedisModule_ReplyWithSimpleString(ctx, "OK");
+    return REDISMODULE_OK;
+  }
+}
+
+/// Cancel notifications for changes to a key. The client will no longer
+/// receive notifications for this key.
+///
+/// This is called from a client with the command:
+//
+///    RAY.TABLE_CANCEL_NOTIFICATIONS <table_prefix> <pubsub_channel> <id>
+///        <client_id>
+///
+/// \param table_prefix The prefix string for keys in this table.
+/// \param pubsub_channel The pubsub channel name that notifications for
+///        this key should be published to. If publishing to a specific client,
+///        then the channel name should be <pubsub_channel>:<client_id>.
+/// \param id The ID of the key to publish notifications for.
+/// \param client_id The ID of the client that is being notified.
+/// \return OK if the requesting client was removed, or an error if the client
+///         was not found.
+int TableCancelNotifications_RedisCommand(RedisModuleCtx *ctx,
+                                          RedisModuleString **argv,
+                                          int argc) {
+  if (argc < 5) {
+    return RedisModule_WrongArity(ctx);
+  }
+
+  RedisModuleString *pubsub_channel_str = argv[2];
+  RedisModuleString *id = argv[3];
+  RedisModuleString *client_id = argv[4];
+  RedisModuleString *client_channel =
+      FormatPubsubChannel(ctx, pubsub_channel_str, client_id);
+
+  // Remove this client from the set of clients that should be notified when
+  // there are changes to the key.
+  RedisModuleKey *notification_key = OpenBroadcastKey(
+      ctx, pubsub_channel_str, id, REDISMODULE_READ | REDISMODULE_WRITE);
+  RAY_CHECK(RedisModule_KeyType(notification_key) !=
+            REDISMODULE_KEYTYPE_EMPTY);
+  int deleted;
+  RedisModule_ZsetRem(notification_key, client_channel, &deleted);
+  RAY_CHECK(deleted);
+  RedisModule_CloseKey(notification_key);
+
+  RedisModule_ReplyWithSimpleString(ctx, "OK");
+  return REDISMODULE_OK;
+}
+
 bool is_nil(const std::string &data) {
   RAY_CHECK(data.size() == kUniqueIDSize);
   const uint8_t *d = reinterpret_cast<const uint8_t *>(data.data());
@@ -1429,6 +1667,18 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx,
     return REDISMODULE_ERR;
   }
 
+  if (RedisModule_CreateCommand(ctx, "ray.table_request_notifications",
+                                TableRequestNotifications_RedisCommand,
+                                "write pubsub", 0, 0, 0) == REDISMODULE_ERR) {
+    return REDISMODULE_ERR;
+  }
+
+  if (RedisModule_CreateCommand(ctx, "ray.table_cancel_notifications",
+                                TableCancelNotifications_RedisCommand,
+                                "write pubsub", 0, 0, 0) == REDISMODULE_ERR) {
+    return REDISMODULE_ERR;
+  }
+
   if (RedisModule_CreateCommand(ctx, "ray.table_test_and_update",
                                 TableTestAndUpdate_RedisCommand, "write", 0, 0,
                                 0) == REDISMODULE_ERR) {
diff --git a/src/plasma/plasma_manager.cc b/src/plasma/plasma_manager.cc
index 3b4223d91ca4..be7b1aee86d1 100644
--- a/src/plasma/plasma_manager.cc
+++ b/src/plasma/plasma_manager.cc
@@ -1332,10 +1332,10 @@ void log_object_hash_mismatch_error_result_callback(ObjectID object_id,
   RAY_CHECK_OK(state->gcs_client.task_table().Lookup(
       ray::JobID::nil(), task_id,
       [user_context](gcs::AsyncGcsClient *, const TaskID &,
-                     std::shared_ptr<TaskTableDataT> t) {
+                     const TaskTableDataT &t) {
         Task *task = Task_alloc(
-            t->task_info.data(), t->task_info.size(), t->scheduling_state,
-            DBClientID::from_binary(t->scheduler_id), std::vector<ObjectID>());
+            t.task_info.data(), t.task_info.size(), t.scheduling_state,
+            DBClientID::from_binary(t.scheduler_id), std::vector<ObjectID>());
         log_object_hash_mismatch_error_task_callback(task, user_context);
         Task_free(task);
       },
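The plasma manager hunk above is representative of a mechanical change that recurs below in client_test.cc, tables.h, and task_table.cc: GCS table callbacks now receive the row as `const DataT &` rather than `std::shared_ptr<DataT>`, so callback bodies switch from `->` to `.` access and no longer share ownership of the unpacked flatbuffer. A before/after sketch with a hypothetical task-table lookup callback (not code from this patch):

    // Before: the callback co-owned the unpacked row.
    auto old_cb = [](gcs::AsyncGcsClient *, const TaskID &,
                     std::shared_ptr<TaskTableDataT> t) {
      auto state = t->scheduling_state;
      (void) state;
    };
    // After: the row is borrowed for the duration of the call, so any field
    // that must outlive the callback has to be copied out explicitly.
    auto new_cb = [](gcs::AsyncGcsClient *, const TaskID &,
                     const TaskTableDataT &t) {
      auto state = t.scheduling_state;
      (void) state;
    };
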
diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc
index 075142ff08a5..eeb6b143792b 100644
--- a/src/ray/gcs/client_test.cc
+++ b/src/ray/gcs/client_test.cc
@@ -21,7 +21,7 @@ static inline void flushall_redis(void) {
 
 class TestGcs : public ::testing::Test {
  public:
-  TestGcs() {
+  TestGcs() : num_callbacks_(0) {
    client_ = std::make_shared<gcs::AsyncGcsClient>();
    ClientTableDataT client_info;
    client_info.client_id = ClientID::from_random().binary();
@@ -42,7 +42,12 @@ class TestGcs : public ::testing::Test {
 
   virtual void Stop() = 0;
 
+  int64_t NumCallbacks() const { return num_callbacks_; }
+
+  void IncrementNumCallbacks() { num_callbacks_++; }
+
  protected:
+  int64_t num_callbacks_;
   std::shared_ptr<gcs::AsyncGcsClient> client_;
   JobID job_id_;
 };
@@ -87,14 +92,14 @@ class TestGcsWithAsio : public TestGcs {
 };
 
 void ObjectAdded(gcs::AsyncGcsClient *client, const UniqueID &id,
-                 std::shared_ptr<ObjectTableDataT> data) {
-  ASSERT_EQ(data->managers, std::vector<std::string>({"A", "B"}));
+                 const ObjectTableDataT &data) {
+  ASSERT_EQ(data.managers, std::vector<std::string>({"A", "B"}));
 }
 
 void Lookup(gcs::AsyncGcsClient *client, const UniqueID &id,
-            std::shared_ptr<ObjectTableDataT> data) {
+            const ObjectTableDataT &data) {
   // Check that the object entry was added.
-  ASSERT_EQ(data->managers, std::vector<std::string>({"A", "B"}));
+  ASSERT_EQ(data.managers, std::vector<std::string>({"A", "B"}));
   test->Stop();
 }
 
@@ -126,14 +131,37 @@ TEST_F(TestGcsWithAsio, TestObjectTable) {
   TestObjectTable(job_id_, client_);
 }
 
+void TestLookupFailure(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
+  auto object_id = ObjectID::from_random();
+  // Looking up an empty object ID should call the failure callback.
+  auto failure_callback = [](gcs::AsyncGcsClient *client, const UniqueID &id) {
+    test->Stop();
+  };
+  RAY_CHECK_OK(
+      client->object_table().Lookup(job_id, object_id, nullptr, failure_callback));
+  // Run the event loop. The loop will only stop if the failure callback is
+  // called.
+  test->Start();
+}
+
+TEST_F(TestGcsWithAe, TestLookupFailure) {
+  test = this;
+  TestLookupFailure(job_id_, client_);
+}
+
+TEST_F(TestGcsWithAsio, TestLookupFailure) {
+  test = this;
+  TestLookupFailure(job_id_, client_);
+}
+
 void TaskAdded(gcs::AsyncGcsClient *client, const TaskID &id,
-               std::shared_ptr<TaskTableDataT> data) {
-  ASSERT_EQ(data->scheduling_state, SchedulingState_SCHEDULED);
+               const TaskTableDataT &data) {
+  ASSERT_EQ(data.scheduling_state, SchedulingState_SCHEDULED);
 }
 
 void TaskLookup(gcs::AsyncGcsClient *client, const TaskID &id,
-                std::shared_ptr<TaskTableDataT> data) {
-  ASSERT_EQ(data->scheduling_state, SchedulingState_SCHEDULED);
+                const TaskTableDataT &data) {
+  ASSERT_EQ(data.scheduling_state, SchedulingState_SCHEDULED);
 }
 
 void TaskLookupFailure(gcs::AsyncGcsClient *client, const TaskID &id) {
@@ -141,8 +169,8 @@ void TaskLookupFailure(gcs::AsyncGcsClient *client, const TaskID &id) {
 }
 
 void TaskLookupAfterUpdate(gcs::AsyncGcsClient *client, const TaskID &id,
-                           std::shared_ptr<TaskTableDataT> data) {
-  ASSERT_EQ(data->scheduling_state, SchedulingState_LOST);
+                           const TaskTableDataT &data) {
+  ASSERT_EQ(data.scheduling_state, SchedulingState_LOST);
   test->Stop();
 }
 
@@ -153,8 +181,8 @@ void TaskLookupAfterUpdateFailure(gcs::AsyncGcsClient *client, const TaskID &id)
 
 void TaskUpdateCallback(gcs::AsyncGcsClient *client, const TaskID &task_id,
                         const TaskTableDataT &task, bool updated) {
-  RAY_CHECK_OK(client->task_table().Lookup(
-      DriverID::nil(), task_id, &TaskLookupAfterUpdate, &TaskLookupAfterUpdateFailure));
+  RAY_CHECK_OK(client->task_table().Lookup(DriverID::nil(), task_id,
                                           &TaskLookupAfterUpdate, &TaskLookupFailure));
 }
 
 void TestTaskTable(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
@@ -189,28 +217,40 @@ TEST_F(TestGcsWithAsio, TestTaskTable) {
   TestTaskTable(job_id_, client_);
 }
 
-void ObjectTableSubscribed(gcs::AsyncGcsClient *client, const UniqueID &id,
-                           std::shared_ptr<ObjectTableDataT> data) {
-  test->Stop();
-}
-
 void TestSubscribeAll(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
-  // Subscribe to all object table notifications. The registered callback for
-  // notifications will check whether the object below is added.
-  RAY_CHECK_OK(client->object_table().Subscribe(job_id, ClientID::nil(), &Lookup,
-                                                &ObjectTableSubscribed));
-  // Run the event loop. The loop will only stop if the subscription succeeds.
-  test->Start();
-
-  // We have subscribed. Add an object table entry.
-  auto data = std::make_shared<ObjectTableDataT>();
-  data->managers.push_back("A");
-  data->managers.push_back("B");
   ObjectID object_id = ObjectID::from_random();
-  RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, &ObjectAdded));
+  // Callback for a notification.
+  auto notification_callback = [object_id](
+      gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &data) {
+    ASSERT_EQ(id, object_id);
+    // Check that the object entry was added.
+    ASSERT_EQ(data.managers, std::vector<std::string>({"A", "B"}));
+    test->IncrementNumCallbacks();
+    test->Stop();
+  };
+
+  // Callback for subscription success. This should only be called once.
+  auto subscribe_callback = [job_id, object_id](gcs::AsyncGcsClient *client) {
+    test->IncrementNumCallbacks();
+    // We have subscribed. Add an object table entry.
+    auto data = std::make_shared<ObjectTableDataT>();
+    data->managers.push_back("A");
+    data->managers.push_back("B");
+    RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, &ObjectAdded));
+  };
+
+  // Subscribe to all object table notifications. Once we have successfully
+  // subscribed, we will add an object and check that we get notified of the
+  // operation.
+  RAY_CHECK_OK(client->object_table().Subscribe(
+      job_id, ClientID::nil(), notification_callback, subscribe_callback));
+
   // Run the event loop. The loop will only stop if the registered subscription
   // callback is called (or an assertion failure).
   test->Start();
+  // Check that we received one callback for subscription success and one for
+  // the Add notification.
+  ASSERT_EQ(test->NumCallbacks(), 2);
 }
 
 TEST_F(TestGcsWithAe, TestSubscribeAll) {
@@ -223,11 +263,152 @@ TEST_F(TestGcsWithAsio, TestSubscribeAll) {
   TestSubscribeAll(job_id_, client_);
 }
 
+void TestSubscribeId(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> client) {
+  // Add an object table entry.
+  ObjectID object_id1 = ObjectID::from_random();
+  auto data1 = std::make_shared<ObjectTableDataT>();
+  data1->managers.push_back("A");
+  data1->managers.push_back("B");
+  RAY_CHECK_OK(client->object_table().Add(job_id, object_id1, data1, nullptr));
+
+  // Add a second object table entry.
+  ObjectID object_id2 = ObjectID::from_random();
+  auto data2 = std::make_shared<ObjectTableDataT>();
+  data2->managers.push_back("C");
+  RAY_CHECK_OK(client->object_table().Add(job_id, object_id2, data2, nullptr));
+
+  // The callback for subscription success. Once we've subscribed, request
+  // notifications for the second object that was added.
+  auto subscribe_callback = [job_id, object_id2](gcs::AsyncGcsClient *client) {
+    test->IncrementNumCallbacks();
+    // Request notifications for the second object. Since we already added the
+    // entry to the table, we should receive an initial notification for its
+    // current value.
+    RAY_CHECK_OK(client->object_table().RequestNotifications(
+        job_id, object_id2, client->client_table().GetLocalClientId()));
+    // Overwrite the entry for the object. We should receive a second
+    // notification for its new value.
+    auto data = std::make_shared<ObjectTableDataT>();
+    data->managers.push_back("C");
+    data->managers.push_back("D");
+    RAY_CHECK_OK(client->object_table().Add(job_id, object_id2, data, nullptr));
+  };
+
+  // The callback for a notification from the object table. This should only be
+  // received for the object that we requested notifications for.
+  auto notification_callback = [data2, object_id2](
+      gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &data) {
+    ASSERT_EQ(id, object_id2);
+    // Check that we got a notification for the correct object.
+    ASSERT_EQ(data.managers.front(), "C");
+    test->IncrementNumCallbacks();
+    // Stop the loop once we've received notifications for both writes to the
+    // object key.
+    if (test->NumCallbacks() == 3) {
+      test->Stop();
+    }
+  };
+
+  RAY_CHECK_OK(
+      client->object_table().Subscribe(job_id, client->client_table().GetLocalClientId(),
+                                       notification_callback, subscribe_callback));
+
+  // Run the event loop. The loop will only stop if the registered subscription
+  // callback is called for both writes to the object key.
+  test->Start();
+  // Check that we received one callback for subscription success and two
+  // callbacks for the Add notifications.
+  ASSERT_EQ(test->NumCallbacks(), 3);
+}
+
+TEST_F(TestGcsWithAe, TestSubscribeId) {
+  test = this;
+  TestSubscribeId(job_id_, client_);
+}
+
+TEST_F(TestGcsWithAsio, TestSubscribeId) {
+  test = this;
+  TestSubscribeId(job_id_, client_);
+}
+
+void TestSubscribeCancel(const JobID &job_id,
+                         std::shared_ptr<gcs::AsyncGcsClient> client) {
+  // Write the object table once.
+  ObjectID object_id = ObjectID::from_random();
+  auto data = std::make_shared<ObjectTableDataT>();
+  data->managers.push_back("A");
+  RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, nullptr));
+
+  // The callback for subscription success. Once we've subscribed, request
+  // notifications for the object that was added.
+  auto subscribe_callback = [job_id, object_id](gcs::AsyncGcsClient *client) {
+    test->IncrementNumCallbacks();
+    // Request notifications for the object. We should receive a notification
+    // for the current value at the key.
+    RAY_CHECK_OK(client->object_table().RequestNotifications(
+        job_id, object_id, client->client_table().GetLocalClientId()));
+    // Cancel notifications.
+    RAY_CHECK_OK(client->object_table().CancelNotifications(
+        job_id, object_id, client->client_table().GetLocalClientId()));
+    // Write the object table entry twice. Since we canceled notifications, we
+    // should not get notifications for either of these writes.
+    auto data = std::make_shared<ObjectTableDataT>();
+    data->managers.push_back("B");
+    RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, nullptr));
+    data = std::make_shared<ObjectTableDataT>();
+    data->managers.push_back("C");
+    RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, nullptr));
+    // Request notifications for the object again. We should only receive a
+    // notification for the current value at the key.
+    RAY_CHECK_OK(client->object_table().RequestNotifications(
+        job_id, object_id, client->client_table().GetLocalClientId()));
+  };
+
+  // The callback for a notification from the object table.
+  auto notification_callback = [object_id](
+      gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &data) {
+    ASSERT_EQ(id, object_id);
+    // Check that we only receive notifications for the key when we have
+    // requested notifications for it. We should not get a notification for the
+    // entry that began with "B" since we canceled notifications then.
+    if (test->NumCallbacks() == 1) {
+      ASSERT_EQ(data.managers.front(), "A");
+    } else {
+      ASSERT_EQ(data.managers.front(), "C");
+    }
+    test->IncrementNumCallbacks();
+    if (test->NumCallbacks() == 3) {
+      test->Stop();
+    }
+  };
+
+  RAY_CHECK_OK(
+      client->object_table().Subscribe(job_id, client->client_table().GetLocalClientId(),
+                                       notification_callback, subscribe_callback));
+
+  // Run the event loop. The loop will only stop if the registered subscription
+  // callback is called (or an assertion failure).
+  test->Start();
+  // Check that we received one callback for subscription success and two
+  // callbacks for the requested notifications.
+  ASSERT_EQ(test->NumCallbacks(), 3);
+}
+
+TEST_F(TestGcsWithAe, TestSubscribeCancel) {
+  test = this;
+  TestSubscribeCancel(job_id_, client_);
+}
+
+TEST_F(TestGcsWithAsio, TestSubscribeCancel) {
+  test = this;
+  TestSubscribeCancel(job_id_, client_);
+}
+
 void ClientTableNotification(gcs::AsyncGcsClient *client, const UniqueID &id,
-                             std::shared_ptr<ClientTableDataT> data, bool is_insertion) {
+                             const ClientTableDataT &data, bool is_insertion) {
   ClientID added_id = client->client_table().GetLocalClientId();
-  ASSERT_EQ(ClientID::from_binary(data->client_id), added_id);
-  ASSERT_EQ(data->is_insertion, is_insertion);
+  ASSERT_EQ(ClientID::from_binary(data.client_id), added_id);
+  ASSERT_EQ(data.is_insertion, is_insertion);
 
   auto cached_client = client->client_table().GetClient(added_id);
   ASSERT_EQ(ClientID::from_binary(cached_client.client_id), added_id);
@@ -239,8 +420,7 @@ void TestClientTableConnect(const JobID &job_id,
   // Register callbacks for when a client gets added and removed. The latter
   // event will stop the event loop.
   client->client_table().RegisterClientAddedCallback(
-      [](gcs::AsyncGcsClient *client, const UniqueID &id,
-         std::shared_ptr<ClientTableDataT> data) {
+      [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) {
        ClientTableNotification(client, id, data, true);
        test->Stop();
      });
@@ -260,13 +440,11 @@ void TestClientTableDisconnect(const JobID &job_id,
   // Register callbacks for when a client gets added and removed. The latter
   // event will stop the event loop.
   client->client_table().RegisterClientAddedCallback(
-      [](gcs::AsyncGcsClient *client, const UniqueID &id,
-         std::shared_ptr<ClientTableDataT> data) {
+      [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) {
        ClientTableNotification(client, id, data, true);
      });
   client->client_table().RegisterClientRemovedCallback(
-      [](gcs::AsyncGcsClient *client, const UniqueID &id,
-         std::shared_ptr<ClientTableDataT> data) {
+      [](gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) {
        ClientTableNotification(client, id, data, false);
        test->Stop();
      });
diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs
index e6483179bf50..e9aa5169ac1c 100644
--- a/src/ray/gcs/format/gcs.fbs
+++ b/src/ray/gcs/format/gcs.fbs
@@ -21,6 +21,11 @@ enum TablePubsub:int {
   ACTOR
 }
 
+table GcsNotification {
+  id: string;
+  data: string;
+}
+
 table FunctionTableData {
   language: Language;
   name: string;
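The GcsNotification table added above is the envelope for every message published by RAY.TABLE_ADD and RAY.TABLE_REQUEST_NOTIFICATIONS: `id` carries the key the notification is about (left empty for client-table notifications) and `data` carries the serialized row. A sketch of packing and unpacking it with the generated flatbuffers API, assuming the generated gcs header is included; `key_bytes` and `row_bytes` are hypothetical stand-ins:

    flatbuffers::FlatBufferBuilder fbb;
    // Pack: wrap the key and the serialized row, as PublishTableAdd does.
    auto envelope = CreateGcsNotification(fbb, fbb.CreateString(key_bytes),
                                          fbb.CreateString(row_bytes));
    fbb.Finish(envelope);

    // Unpack: what the subscriber side does in Table::Subscribe below.
    auto notification =
        flatbuffers::GetRoot<GcsNotification>(fbb.GetBufferPointer());
    std::string key = notification->id()->str();
    std::string row = notification->data()->str();
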
diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc
index 67e6b4aef331..3a5cf38f9e18 100644
--- a/src/ray/gcs/redis_context.cc
+++ b/src/ray/gcs/redis_context.cc
@@ -11,6 +11,22 @@ extern "C" {
 // TODO(pcm): Integrate into the C++ tree.
 #include "state/ray_config.h"
 
+namespace {
+
+/// A helper function to call the callback and delete it from the callback
+/// manager if necessary.
+void ProcessCallback(int64_t callback_index, const std::vector<std::string> &data) {
+  if (callback_index >= 0) {
+    bool delete_callback =
+        ray::gcs::RedisCallbackManager::instance().get(callback_index)(data);
+    // Delete the callback if necessary.
+    if (delete_callback) {
+      ray::gcs::RedisCallbackManager::instance().remove(callback_index);
+    }
+  }
+}
+}
+
 namespace ray {
 
 namespace gcs {
@@ -24,24 +40,25 @@ void GlobalRedisCallback(void *c, void *r, void *privdata) {
   }
   int64_t callback_index = reinterpret_cast<int64_t>(privdata);
   redisReply *reply = reinterpret_cast<redisReply *>(r);
-  std::string data = "";
-  if (reply->type == REDIS_REPLY_NIL) {
-    // Respond with blank string, which triggers a failure callback for lookups.
-  } else if (reply->type == REDIS_REPLY_STRING) {
-    data = std::string(reply->str, reply->len);
-  } else if (reply->type == REDIS_REPLY_ARRAY) {
-    reply = reply->element[reply->elements - 1];
-    data = std::string(reply->str, reply->len);
-  } else if (reply->type == REDIS_REPLY_STATUS) {
-  } else if (reply->type == REDIS_REPLY_ERROR) {
+  std::vector<std::string> data;
+  // Parse the response.
+  switch (reply->type) {
+  case (REDIS_REPLY_NIL): {
+    // Do not add any data for a nil response.
+  } break;
+  case (REDIS_REPLY_STRING): {
+    data.push_back(std::string(reply->str, reply->len));
+  } break;
+  case (REDIS_REPLY_STATUS): {
+  } break;
+  case (REDIS_REPLY_ERROR): {
     RAY_LOG(ERROR) << "Redis error " << reply->str;
-  } else {
+  } break;
+  default:
     RAY_LOG(FATAL) << "Fatal redis error of type " << reply->type
                    << " and with string " << reply->str;
   }
-  RedisCallbackManager::instance().get(callback_index)(data);
-  // Delete the callback.
-  RedisCallbackManager::instance().remove(callback_index);
+  ProcessCallback(callback_index, data);
 }
 
 void SubscribeRedisCallback(void *c, void *r, void *privdata) {
@@ -50,31 +67,35 @@ void SubscribeRedisCallback(void *c, void *r, void *privdata) {
   }
   int64_t callback_index = reinterpret_cast<int64_t>(privdata);
   redisReply *reply = reinterpret_cast<redisReply *>(r);
-  std::string data = "";
-  if (reply->type == REDIS_REPLY_ARRAY) {
-    // Parse the message.
+  std::vector<std::string> data;
+  // Parse the response.
+  switch (reply->type) {
+  case (REDIS_REPLY_ARRAY): {
+    // Parse the published message.
    redisReply *message_type = reply->element[0];
    if (strcmp(message_type->str, "subscribe") == 0) {
-      // If the message is for the initial subscription call, do not fill in
-      // data.
+      // If the message is for the initial subscription call, return the empty
+      // string as a response to signify that subscription was successful.
+      data.push_back("");
    } else if (strcmp(message_type->str, "message") == 0) {
      // If the message is from a PUBLISH, make sure the data is nonempty.
      redisReply *message = reply->element[reply->elements - 1];
-      data = std::string(message->str, message->len);
-      RAY_CHECK(!data.empty()) << "Empty message received on subscribe channel";
+      auto notification = std::string(message->str, message->len);
+      RAY_CHECK(!notification.empty()) << "Empty message received on subscribe channel";
+      data.push_back(notification);
    } else {
      RAY_LOG(FATAL) << "Fatal redis error during subscribe" << message_type->str;
    }
-    // NOTE(swang): We do not delete the callback after calling it since there
-    // may be more subscription messages.
-    RedisCallbackManager::instance().get(callback_index)(data);
-  } else if (reply->type == REDIS_REPLY_ERROR) {
+  } break;
+  case (REDIS_REPLY_ERROR): {
    RAY_LOG(ERROR) << "Redis error " << reply->str;
-  } else {
+  } break;
+  default:
    RAY_LOG(FATAL) << "Fatal redis error of type " << reply->type
                   << " and with string " << reply->str;
  }
+  ProcessCallback(callback_index, data);
 }
 
 int64_t RedisCallbackManager::add(const RedisCallback &function) {
@@ -161,8 +182,9 @@ Status RedisContext::AttachToEventLoop(aeEventLoop *loop) {
 }
 
 Status RedisContext::RunAsync(const std::string &command, const UniqueID &id,
-                              uint8_t *data, int64_t length, const TablePrefix prefix,
-                              const TablePubsub pubsub_channel, int64_t callback_index) {
+                              const uint8_t *data, int64_t length,
+                              const TablePrefix prefix, const TablePubsub pubsub_channel,
+                              int64_t callback_index) {
   if (length > 0) {
     std::string redis_command = command + " %d %d %b %b";
     int status = redisAsyncCommand(
@@ -200,7 +222,6 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id,
         reinterpret_cast<void *>(callback_index), redis_command.c_str(), pubsub_channel);
   } else {
     // Subscribe only to messages sent to this client.
-    // TODO(swang): Nobody sends on this channel yet.
     std::string redis_command = "SUBSCRIBE %d:%b";
     status = redisAsyncCommand(
         subscribe_context_, reinterpret_cast<redisCallbackFn *>(&SubscribeRedisCallback),
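The redis_context.h hunk below completes the contract that ProcessCallback above relies on: a callback's boolean return value now decides whether RedisCallbackManager keeps it. One-shot command callbacks (Add, Lookup) return true and are removed after the first reply; subscription callbacks return false and stay registered for later PUBLISH messages. A sketch of the two styles against the RedisCallbackManager interface in this patch:

    // One-shot: ProcessCallback removes it after the first invocation.
    int64_t lookup_index = ray::gcs::RedisCallbackManager::instance().add(
        [](const std::vector<std::string> &data) {
          // ... consume the single reply ...
          return true;  // delete this callback
        });

    // Long-lived: survives across notifications on a subscribed channel.
    int64_t subscribe_index = ray::gcs::RedisCallbackManager::instance().add(
        [](const std::vector<std::string> &data) {
          // ... handle one notification ...
          return false;  // keep this callback registered
        });
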
diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h
index af0117d7dbbc..5b90265b0fbb 100644
--- a/src/ray/gcs/redis_context.h
+++ b/src/ray/gcs/redis_context.h
@@ -21,7 +21,10 @@ namespace gcs {
 
 class RedisCallbackManager {
  public:
-  using RedisCallback = std::function<void(const std::string &)>;
+  /// Every callback should take in a vector of the results from the Redis
+  /// operation and return a bool indicating whether the callback should be
+  /// deleted once called.
+  using RedisCallback = std::function<bool(const std::vector<std::string> &)>;
 
   static RedisCallbackManager &instance() {
     static RedisCallbackManager instance;
@@ -50,7 +53,7 @@ class RedisContext {
   ~RedisContext();
   Status Connect(const std::string &address, int port);
   Status AttachToEventLoop(aeEventLoop *loop);
-  Status RunAsync(const std::string &command, const UniqueID &id, uint8_t *data,
+  Status RunAsync(const std::string &command, const UniqueID &id, const uint8_t *data,
                   int64_t length, const TablePrefix prefix,
                   const TablePubsub pubsub_channel, int64_t callback_index);
   Status SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel,
diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc
index 7426e80d7d2a..e84e8d1d8814 100644
--- a/src/ray/gcs/tables.cc
+++ b/src/ray/gcs/tables.cc
@@ -1,36 +1,141 @@
 #include "ray/gcs/tables.h"
 
+#include "common_protocol.h"
 #include "ray/gcs/client.h"
 
 namespace ray {
 
 namespace gcs {
 
-void ClientTable::RegisterClientAddedCallback(const Callback &callback) {
+template <typename ID, typename Data>
+Status Table<ID, Data>::Add(const JobID &job_id, const ID &id,
+                            std::shared_ptr<DataT> data, const Callback &done) {
+  auto d = std::shared_ptr<CallbackData>(
+      new CallbackData({id, data, done, nullptr, nullptr, this, client_}));
+  int64_t callback_index =
+      RedisCallbackManager::instance().add([d](const std::vector<std::string> &data) {
+        if (d->callback != nullptr) {
+          (d->callback)(d->client, d->id, *d->data);
+        }
+        return true;
+      });
+  flatbuffers::FlatBufferBuilder fbb;
+  fbb.ForceDefaults(true);
+  fbb.Finish(Data::Pack(fbb, data.get()));
+  return context_->RunAsync("RAY.TABLE_ADD", id, fbb.GetBufferPointer(), fbb.GetSize(),
+                            prefix_, pubsub_channel_, callback_index);
+}
+
+template <typename ID, typename Data>
+Status Table<ID, Data>::Lookup(const JobID &job_id, const ID &id, const Callback &lookup,
+                               const FailureCallback &failure) {
+  auto d = std::shared_ptr<CallbackData>(
+      new CallbackData({id, nullptr, lookup, failure, nullptr, this, client_}));
+  int64_t callback_index =
+      RedisCallbackManager::instance().add([d](const std::vector<std::string> &data) {
+        if (data.empty()) {
+          if (d->failure != nullptr) {
+            (d->failure)(d->client, d->id);
+          }
+        } else {
+          RAY_CHECK(data.size() == 1);
+          if (d->callback != nullptr) {
+            DataT result;
+            auto root = flatbuffers::GetRoot<Data>(data[0].data());
+            root->UnPackTo(&result);
+            (d->callback)(d->client, d->id, result);
+          }
+        }
+        return true;
+      });
+  std::vector<uint8_t> nil;
+  return context_->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(), prefix_,
+                            pubsub_channel_, callback_index);
+}
+
+template <typename ID, typename Data>
+Status Table<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
+                                  const Callback &subscribe,
+                                  const SubscriptionCallback &done) {
+  RAY_CHECK(subscribe_callback_index_ == -1)
+      << "Client called Subscribe twice on the same table";
+  auto d = std::shared_ptr<CallbackData>(
+      new CallbackData({client_id, nullptr, subscribe, nullptr, done, this, client_}));
+  int64_t callback_index = RedisCallbackManager::instance().add(
+      [this, d](const std::vector<std::string> &data) {
+        if (data.size() == 1 && data[0] == "") {
+          // No notification data is provided. This is the callback for the
+          // initial subscription request.
+          if (d->subscription_callback != nullptr) {
+            (d->subscription_callback)(d->client);
+          }
+        } else {
+          // Data is provided. This is the callback for a message.
+          RAY_CHECK(data.size() == 1);
+          if (d->callback != nullptr) {
+            // Parse the notification.
+            auto notification = flatbuffers::GetRoot<GcsNotification>(data[0].data());
+            ID id = UniqueID::nil();
+            if (notification->id()->size() > 0) {
+              id = from_flatbuf(*notification->id());
+            }
+            DataT result;
+            auto root = flatbuffers::GetRoot<Data>(notification->data()->data());
+            root->UnPackTo(&result);
+            (d->callback)(d->client, id, result);
+          }
+        }
+        // We do not delete the callback after calling it since there may be
+        // more subscription messages.
+        return false;
+      });
+  subscribe_callback_index_ = callback_index;
+  return context_->SubscribeAsync(client_id, pubsub_channel_, callback_index);
+}
+
+template <typename ID, typename Data>
+Status Table<ID, Data>::RequestNotifications(const JobID &job_id, const ID &id,
+                                             const ClientID &client_id) {
+  RAY_CHECK(subscribe_callback_index_ >= 0)
+      << "Client requested notifications on a key before Subscribe completed";
+  return context_->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id, client_id.data(),
+                            client_id.size(), prefix_, pubsub_channel_,
+                            subscribe_callback_index_);
+}
+
+template <typename ID, typename Data>
+Status Table<ID, Data>::CancelNotifications(const JobID &job_id, const ID &id,
+                                            const ClientID &client_id) {
+  RAY_CHECK(subscribe_callback_index_ >= 0)
+      << "Client canceled notifications on a key before Subscribe completed";
+  return context_->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id, client_id.data(),
+                            client_id.size(), prefix_, pubsub_channel_,
+                            /*callback_index=*/-1);
+}
+
+void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callback) {
   client_added_callback_ = callback;
   // Call the callback for any added clients that are cached.
   for (const auto &entry : client_cache_) {
     if (!entry.first.is_nil() && entry.second.is_insertion) {
-      auto data = std::make_shared<ClientTableDataT>(entry.second);
-      client_added_callback_(client_, entry.first, data);
+      client_added_callback_(client_, ClientID::nil(), entry.second);
     }
   }
 }
 
-void ClientTable::RegisterClientRemovedCallback(const Callback &callback) {
+void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callback) {
   client_removed_callback_ = callback;
   // Call the callback for any removed clients that are cached.
   for (const auto &entry : client_cache_) {
     if (!entry.first.is_nil() && !entry.second.is_insertion) {
-      auto data = std::make_shared<ClientTableDataT>(entry.second);
-      client_removed_callback_(client_, entry.first, data);
+      client_removed_callback_(client_, ClientID::nil(), entry.second);
     }
   }
 }
 
 void ClientTable::HandleNotification(AsyncGcsClient *client, const ClientID &channel_id,
-                                     std::shared_ptr<ClientTableDataT> data) {
-  ClientID client_id = ClientID::from_binary(data->client_id);
+                                     const ClientTableDataT &data) {
+  ClientID client_id = ClientID::from_binary(data.client_id);
   // It's possible to get duplicate notifications from the client table, so
   // check whether this notification is new.
   auto entry = client_cache_.find(client_id);
@@ -42,24 +147,24 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, const ClientID &cha
     // If the entry is in the cache, then the notification is new if the client
     // was alive and is now dead.
     bool was_inserted = entry->second.is_insertion;
-    bool is_deleted = !data->is_insertion;
+    bool is_deleted = !data.is_insertion;
     is_new = (was_inserted && is_deleted);
     // Once a client with a given ID has been removed, it should never be added
     // again. If the entry was in the cache and the client was deleted, check
     // that this new notification is not an insertion.
     if (!entry->second.is_insertion) {
-      RAY_CHECK(!data->is_insertion)
+      RAY_CHECK(!data.is_insertion)
          << "Notification for addition of a client that was already removed:"
          << client_id.hex();
    }
  }
 
   // Add the notification to our cache. Notifications are idempotent.
-  client_cache_[client_id] = *data;
+  client_cache_[client_id] = data;
 
   // If the notification is new, call any registered callbacks.
   if (is_new) {
-    if (data->is_insertion) {
+    if (data.is_insertion) {
       if (client_added_callback_ != nullptr) {
         client_added_callback_(client, client_id, data);
       }
@@ -72,7 +177,7 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, const ClientID &cha
 }
 
 void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientID &client_id,
-                                  std::shared_ptr<ClientTableDataT> data) {
+                                  const ClientTableDataT &data) {
   RAY_CHECK(client_id == client_id_) << client_id.hex() << " " << client_id_.hex();
 }
 
@@ -87,18 +192,17 @@ Status ClientTable::Connect() {
   data->is_insertion = true;
   // Callback for a notification from the client table.
   auto notification_callback = [this](AsyncGcsClient *client, const ClientID &channel_id,
-                                      std::shared_ptr<ClientTableDataT> data) {
+                                      const ClientTableDataT &data) {
     return HandleNotification(client, channel_id, data);
   };
   // Callback to handle our own successful connection once we've added
   // ourselves.
   auto add_callback = [this](AsyncGcsClient *client, const ClientID &id,
-                             std::shared_ptr<ClientTableDataT> data) {
+                             const ClientTableDataT &data) {
     HandleConnected(client, id, data);
   };
   // Callback to add ourselves once we've successfully subscribed.
-  auto subscription_callback = [this, data, add_callback](
-      AsyncGcsClient *c, const ClientID &id, std::shared_ptr<ClientTableDataT> d) {
+  auto subscription_callback = [this, data, add_callback](AsyncGcsClient *c) {
     // Mark ourselves as deleted if we called Disconnect() since the last
     // Connect() call.
     if (disconnected_) {
@@ -114,7 +218,7 @@ Status ClientTable::Disconnect() {
   auto data = std::make_shared<ClientTableDataT>(local_client_);
   data->is_insertion = true;
   auto add_callback = [this](AsyncGcsClient *client, const ClientID &id,
-                             std::shared_ptr<ClientTableDataT> data) {
+                             const ClientTableDataT &data) {
     HandleConnected(client, id, data);
   };
   RAY_RETURN_NOT_OK(Add(JobID::nil(), client_id_, data, add_callback));
@@ -135,6 +239,9 @@ const ClientTableDataT &ClientTable::GetClient(const ClientID &client_id) {
   }
 }
 
+template class Table<ObjectID, ObjectTableData>;
+template class Table<TaskID, TaskTableData>;
+
 }  // namespace gcs
 
 }  // namespace ray
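The tables.cc definitions above are the client half of the protocol: Subscribe installs one long-lived callback on the channel, and RequestNotifications/CancelNotifications scope which keys fan out to this client. A condensed usage sketch mirroring the TestSubscribeId test earlier in this patch; `client`, `job_id`, and `object_id` are assumed to already exist:

    // Subscribe with our own client ID, then ask for one key's updates once
    // the subscription is confirmed to be live.
    RAY_CHECK_OK(client->object_table().Subscribe(
        job_id, client->client_table().GetLocalClientId(),
        [](gcs::AsyncGcsClient *c, const UniqueID &id,
           const ObjectTableDataT &data) {
          // One call per notification on keys we requested.
        },
        [job_id, object_id](gcs::AsyncGcsClient *c) {
          RAY_CHECK_OK(c->object_table().RequestNotifications(
              job_id, object_id, c->client_table().GetLocalClientId()));
        }));
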
diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h
index 5ac9ad2651fe..7f38055fee72 100644
--- a/src/ray/gcs/tables.h
+++ b/src/ray/gcs/tables.h
@@ -30,9 +30,14 @@ template <typename ID, typename Data>
 class Table {
  public:
   using DataT = typename Data::NativeTableType;
-  using Callback = std::function<void(AsyncGcsClient *client, const ID &id,
-                                      std::shared_ptr<DataT> data)>;
+  using Callback =
+      std::function<void(AsyncGcsClient *client, const ID &id, const DataT &data)>;
+  /// The callback to call when a lookup fails because there is no entry at the
+  /// key.
   using FailureCallback = std::function<void(AsyncGcsClient *client, const ID &id)>;
+  /// The callback to call when a SUBSCRIBE call completes and we are ready to
+  /// request and receive notifications.
+  using SubscriptionCallback = std::function<void(AsyncGcsClient *client)>;
 
   struct CallbackData {
     ID id;
@@ -41,7 +46,7 @@ class Table {
     FailureCallback failure;
     // An optional callback to call for subscription operations, where the
     // first message is a notification of subscription success.
-    Callback subscription_callback;
+    SubscriptionCallback subscription_callback;
     Table *table;
     AsyncGcsClient *client;
   };
@@ -50,7 +55,8 @@ class Table {
       : context_(context),
         client_(client),
         pubsub_channel_(TablePubsub_NO_PUBLISH),
-        prefix_(TablePrefix_UNUSED){};
+        prefix_(TablePrefix_UNUSED),
+        subscribe_callback_index_(-1){};
 
   /// Add an entry to the table.
   ///
@@ -61,91 +67,63 @@ class Table {
   ///        GCS.
   /// \return Status
   Status Add(const JobID &job_id, const ID &id, std::shared_ptr<DataT> data,
-             const Callback &done) {
-    auto d = std::shared_ptr<CallbackData>(
-        new CallbackData({id, data, done, nullptr, nullptr, this, client_}));
-    int64_t callback_index =
-        RedisCallbackManager::instance().add([d](const std::string &data) {
-          if (d->callback != nullptr) {
-            (d->callback)(d->client, d->id, d->data);
-          }
-        });
-    flatbuffers::FlatBufferBuilder fbb;
-    fbb.ForceDefaults(true);
-    fbb.Finish(Data::Pack(fbb, data.get()));
-    RAY_RETURN_NOT_OK(context_->RunAsync("RAY.TABLE_ADD", id, fbb.GetBufferPointer(),
-                                         fbb.GetSize(), prefix_, pubsub_channel_,
-                                         callback_index));
-    return Status::OK();
-  }
+             const Callback &done);
 
   /// Lookup an entry asynchronously.
   ///
   /// \param job_id The ID of the job (= driver).
   /// \param id The ID of the data that is looked up in the GCS.
-  /// \param lookup Callback that is called after lookup.
+  /// \param lookup Callback that is called after lookup. If the callback is
+  ///        called with an empty vector, then there was no data at the key.
   /// \return Status
   Status Lookup(const JobID &job_id, const ID &id, const Callback &lookup,
-                const FailureCallback &failure) {
-    auto d = std::shared_ptr<CallbackData>(
-        new CallbackData({id, nullptr, lookup, failure, nullptr, this, client_}));
-    int64_t callback_index =
-        RedisCallbackManager::instance().add([d](const std::string &data) {
-          if (data.empty()) {
-            if (d->failure != nullptr) {
-              (d->failure)(d->client, d->id);
-            }
-          } else {
-            auto result = std::make_shared<DataT>();
-            auto root = flatbuffers::GetRoot<Data>(data.data());
-            root->UnPackTo(result.get());
-            if (d->callback != nullptr) {
-              (d->callback)(d->client, d->id, result);
-            }
-          }
-        });
-    std::vector<uint8_t> nil;
-    RAY_RETURN_NOT_OK(context_->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(),
-                                         prefix_, pubsub_channel_, callback_index));
-    return Status::OK();
-  }
+                const FailureCallback &failure);
 
-  /// Subscribe to updates of this table
+  /// Subscribe to any Add operations to this table. The caller may choose to
+  /// subscribe to all Adds, or to subscribe only to keys that it requests
+  /// notifications for. This may only be called once per Table instance.
   ///
   /// \param job_id The ID of the job (= driver).
   /// \param client_id The type of update to listen to. If this is nil, then a
   ///        message for each Add to the table will be received. Else, only
-  ///        messages for the given client will be received.
-  /// \param subscribe Callback that is called on each received message.
+  ///        messages for the given client will be received. In the latter
+  ///        case, the client may request notifications on specific keys in the
+  ///        table via `RequestNotifications`.
+  /// \param subscribe Callback that is called on each received message. If the
+  ///        callback is called with an empty vector, then there was no data at
+  ///        the key.
   /// \param done Callback that is called when subscription is complete and we
-  ///        are ready to receive messages..
+  ///        are ready to receive messages.
   /// \return Status
   Status Subscribe(const JobID &job_id, const ClientID &client_id,
-                   const Callback &subscribe, const Callback &done) {
-    auto d = std::shared_ptr<CallbackData>(
-        new CallbackData({client_id, nullptr, subscribe, nullptr, done, this, client_}));
-    int64_t callback_index =
-        RedisCallbackManager::instance().add([d](const std::string &data) {
-          if (data.empty()) {
-            // No data is provided. This is the callback for the initial
-            // subscription request.
-            if (d->subscription_callback != nullptr) {
-              (d->subscription_callback)(d->client, d->id, nullptr);
-            }
-          } else {
-            // Data is provided. This is the callback for a message.
-            auto result = std::make_shared<DataT>();
-            auto root = flatbuffers::GetRoot<Data>(data.data());
-            root->UnPackTo(result.get());
-            (d->callback)(d->client, d->id, result);
-          }
-        });
-    std::vector<uint8_t> nil;
-    return context_->SubscribeAsync(client_id, pubsub_channel_, callback_index);
-  }
+                   const Callback &subscribe, const SubscriptionCallback &done);
+
+  /// Request notifications about a key in this table.
+  ///
+  /// The notifications will be returned via the subscribe callback that was
+  /// registered by `Subscribe`. An initial notification will be returned for
+  /// the current value(s) at the key, if any, and a subsequent notification
+  /// will be published for every following `Add` to the key. Before
+  /// notifications can be requested, the caller must first call `Subscribe`,
+  /// with the same `client_id`.
+  ///
+  /// \param job_id The ID of the job (= driver).
+  /// \param id The ID of the key to request notifications for.
+  /// \param client_id The client who is requesting notifications. Before
+  ///        notifications can be requested, a call to `Subscribe` to this
+  ///        table with the same `client_id` must complete successfully.
+  /// \return Status
+  Status RequestNotifications(const JobID &job_id, const ID &id,
+                              const ClientID &client_id);
 
-  /// Remove and entry from the table
-  Status Remove(const JobID &job_id, const ID &id, const Callback &done);
+  /// Cancel notifications about a key in this table.
+  ///
+  /// \param job_id The ID of the job (= driver).
+  /// \param id The ID of the key to request notifications for.
+  /// \param client_id The client who originally requested notifications.
+  /// \return Status
+  Status CancelNotifications(const JobID &job_id, const ID &id,
+                             const ClientID &client_id);
 
  protected:
   /// The connection to the GCS.
@@ -158,6 +136,10 @@ class Table {
   TablePubsub pubsub_channel_;
   /// The prefix to use for keys in this table.
   TablePrefix prefix_;
+  /// The index in the RedisCallbackManager for the callback that is called
+  /// when we receive notifications. This is >= 0 iff we have subscribed to the
+  /// table, otherwise -1.
+  int64_t subscribe_callback_index_;
 };
 
 class ObjectTable : public Table<ObjectID, ObjectTableData> {
@@ -167,31 +149,6 @@ class ObjectTable : public Table<ObjectID, ObjectTableData> {
     pubsub_channel_ = TablePubsub_OBJECT;
     prefix_ = TablePrefix_OBJECT;
   };
-
-  /// Set up a client-specific channel for receiving notifications about
-  /// available
-  /// objects from the object table. The callback will be called once per
-  /// notification received on this channel.
-  ///
-  /// \param subscribe_all
-  /// \param object_available_callback Callback to be called when new object
-  ///        becomes available.
-  /// \param done_callback Callback to be called when subscription is installed.
-  ///        This is only used for the tests.
-  /// \return Status
-  Status SubscribeToNotifications(const JobID &job_id, bool subscribe_all,
-                                  const Callback &object_available, const Callback &done);
-
-  /// Request notifications about the availability of some objects from the
-  /// object
-  /// table. The notifications will be published to this client's object
-  /// notification channel, which was set up by the method
-  /// ObjectTableSubscribeToNotifications.
-  ///
-  /// \param object_ids The object IDs to receive notifications about.
-  /// \return Status
-  Status RequestNotifications(const JobID &job_id,
-                              const std::vector<ObjectID> &object_ids);
 };
 
 class FunctionTable : public Table<ObjectID, FunctionTableData> {
@@ -240,11 +197,13 @@ class TaskTable : public Table<TaskID, TaskTableData> {
                           std::shared_ptr<TaskTableTestAndUpdateT> data,
                           const TestAndUpdateCallback &callback) {
     int64_t callback_index = RedisCallbackManager::instance().add(
-        [this, callback, id](const std::string &data) {
+        [this, callback, id](const std::vector<std::string> &data) {
+          RAY_CHECK(data.size() == 1);
           auto result = std::make_shared<TaskTableDataT>();
-          auto root = flatbuffers::GetRoot<TaskTableData>(data.data());
+          auto root = flatbuffers::GetRoot<TaskTableData>(data[0].data());
           root->UnPackTo(result.get());
           callback(client_, id, *result, root->updated());
+          return true;
         });
     flatbuffers::FlatBufferBuilder fbb;
     fbb.Finish(TaskTableTestAndUpdate::Pack(fbb, data.get()));
@@ -293,6 +252,8 @@ Status TaskTableTestAndUpdate(AsyncGcsClient *gcs_client, const TaskID &task_id,
 
 class ClientTable : private Table<ClientID, ClientTableData> {
  public:
+  using ClientTableCallback = std::function<void(
+      AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data)>;
   ClientTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client,
               const ClientTableDataT &local_client)
       : Table(context, client),
@@ -324,12 +285,12 @@ class ClientTable : private Table<ClientID, ClientTableData> {
   /// Register a callback to call when a new client is added.
   ///
   /// \param callback The callback to register.
-  void RegisterClientAddedCallback(const Callback &callback);
+  void RegisterClientAddedCallback(const ClientTableCallback &callback);
 
   /// Register a callback to call when a client is removed.
   ///
   /// \param callback The callback to register.
-  void RegisterClientRemovedCallback(const Callback &callback);
+  void RegisterClientRemovedCallback(const ClientTableCallback &callback);
 
   /// Get a client's information from the cache. The cache only contains
   /// information for clients that we've heard a notification for.
@@ -352,10 +313,10 @@ class ClientTable : private Table<ClientID, ClientTableData> {
  private:
   /// Handle a client table notification.
   void HandleNotification(AsyncGcsClient *client, const ClientID &channel_id,
-                          std::shared_ptr<ClientTableDataT>);
+                          const ClientTableDataT &notifications);
   /// Handle this client's successful connection to the GCS.
   void HandleConnected(AsyncGcsClient *client, const ClientID &client_id,
-                       std::shared_ptr<ClientTableDataT>);
+                       const ClientTableDataT &notifications);
 
   /// Whether this client has called Disconnect().
   bool disconnected_;
@@ -364,9 +325,9 @@ class ClientTable : private Table<ClientID, ClientTableData> {
   /// Information about this client.
   ClientTableDataT local_client_;
   /// The callback to call when a new client is added.
-  Callback client_added_callback_;
+  ClientTableCallback client_added_callback_;
   /// The callback to call when a client is removed.
-  Callback client_removed_callback_;
+  ClientTableCallback client_removed_callback_;
   /// A cache for information about all clients.
   std::unordered_map<ClientID, ClientTableDataT, UniqueIDHasher> client_cache_;
 };
 
diff --git a/src/ray/gcs/task_table.cc b/src/ray/gcs/task_table.cc
index a60ab148e732..ba36d0cd10e4 100644
--- a/src/ray/gcs/task_table.cc
+++ b/src/ray/gcs/task_table.cc
@@ -44,9 +44,9 @@ Status TaskTableAdd(AsyncGcsClient *gcs_client, Task *task) {
   TaskSpec *spec = execution_spec.Spec();
   auto data = MakeTaskTableData(execution_spec, Task_local_scheduler(task),
                                 static_cast<SchedulingState>(Task_state(task)));
-  return gcs_client->task_table().Add(ray::JobID::nil(), TaskSpec_task_id(spec), data,
-                                      [](gcs::AsyncGcsClient *client, const TaskID &id,
-                                         std::shared_ptr<TaskTableDataT> data) {});
+  return gcs_client->task_table().Add(
+      ray::JobID::nil(), TaskSpec_task_id(spec), data,
+      [](gcs::AsyncGcsClient *client, const TaskID &id, const TaskTableDataT &data) {});
 }
 
 // TODO(pcm): This is a helper method that should go away once we get rid of
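Downstream consumers of the ClientTable see the same by-reference convention through the new ClientTableCallback alias. A short registration sketch in the style of the client_test.cc changes above, assuming `client` is a connected gcs::AsyncGcsClient:

    client->client_table().RegisterClientAddedCallback(
        [](gcs::AsyncGcsClient *c, const ClientID &id,
           const ClientTableDataT &data) {
          // `data` is only valid for the duration of the callback; copy any
          // fields that need to outlive it.
          std::string added = ClientID::from_binary(data.client_id).hex();
          (void) added;
        });
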