Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions src/common/redis_module/ray_redis_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@
return RedisModule_ReplyWithError(ctx, (MESSAGE)); \
}

static const char *table_prefixes[] = {
NULL, "TASK:", "CLIENT:", "OBJECT:", "FUNCTION:",
};

// TODO(swang): This helper function should be deprecated by the version below,
// which uses enums for table prefixes.
RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
const char *prefix,
RedisModuleString *keyname,
Expand All @@ -61,6 +67,22 @@ RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
return key;
}

RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
RedisModuleString *prefix_enum,
RedisModuleString *keyname,
int mode) {
long long prefix_long;
RAY_CHECK(RedisModule_StringToLongLong(prefix_enum, &prefix_long) ==
REDISMODULE_OK)
<< "Prefix must be a valid TablePrefix";
auto prefix = static_cast<TablePrefix>(prefix_long);
RAY_CHECK(prefix != TablePrefix_UNUSED)
<< "This table has no prefix registered";
RAY_CHECK(prefix >= TablePrefix_MIN && prefix <= TablePrefix_MAX)
<< "Prefix must be a valid TablePrefix";
return OpenPrefixedKey(ctx, table_prefixes[prefix], keyname, mode);
}

/**
* This is a helper method to convert a redis module string to a flatbuffer
* string.
Expand Down Expand Up @@ -394,17 +416,18 @@ bool PublishObjectNotification(RedisModuleCtx *ctx,
int TableAdd_RedisCommand(RedisModuleCtx *ctx,
RedisModuleString **argv,
int argc) {
if (argc != 4) {
if (argc != 5) {
return RedisModule_WrongArity(ctx);
}

RedisModuleString *pubsub_channel_str = argv[1];
RedisModuleString *id = argv[2];
RedisModuleString *data = argv[3];
RedisModuleString *prefix_str = argv[1];
RedisModuleString *pubsub_channel_str = argv[2];
RedisModuleString *id = argv[3];
RedisModuleString *data = argv[4];

// Set the keys in the table.
RedisModuleKey *key =
OpenPrefixedKey(ctx, "T:", id, REDISMODULE_READ | REDISMODULE_WRITE);
RedisModuleKey *key = OpenPrefixedKey(ctx, prefix_str, id,
REDISMODULE_READ | REDISMODULE_WRITE);
RedisModule_StringSet(key, data);
RedisModule_CloseKey(key);

Expand Down Expand Up @@ -517,13 +540,14 @@ int TableAdd_RedisCommand(RedisModuleCtx *ctx,
int TableLookup_RedisCommand(RedisModuleCtx *ctx,
RedisModuleString **argv,
int argc) {
if (argc != 3) {
if (argc != 4) {
return RedisModule_WrongArity(ctx);
}

RedisModuleString *id = argv[2];
RedisModuleString *prefix_str = argv[1];
RedisModuleString *id = argv[3];

RedisModuleKey *key = OpenPrefixedKey(ctx, "T:", id, REDISMODULE_READ);
RedisModuleKey *key = OpenPrefixedKey(ctx, prefix_str, id, REDISMODULE_READ);
if (key == nullptr) {
return RedisModule_ReplyWithNull(ctx);
}
Expand Down Expand Up @@ -554,14 +578,15 @@ bool is_nil(const std::string &data) {
int TableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx,
RedisModuleString **argv,
int argc) {
if (argc != 4) {
if (argc != 5) {
return RedisModule_WrongArity(ctx);
}
RedisModuleString *id = argv[2];
RedisModuleString *update_data = argv[3];
RedisModuleString *prefix_str = argv[1];
RedisModuleString *id = argv[3];
RedisModuleString *update_data = argv[4];

RedisModuleKey *key =
OpenPrefixedKey(ctx, "T:", id, REDISMODULE_READ | REDISMODULE_WRITE);
RedisModuleKey *key = OpenPrefixedKey(ctx, prefix_str, id,
REDISMODULE_READ | REDISMODULE_WRITE);

size_t value_len = 0;
char *value_buf = RedisModule_StringDMA(key, &value_len, REDISMODULE_READ);
Expand Down
8 changes: 8 additions & 0 deletions src/ray/gcs/format/gcs.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ enum Language:int {
JAVA = 2
}

enum TablePrefix:int {
UNUSED = 0,
TASK,
CLIENT,
OBJECT,
FUNCTION
}

// The channel that Add operations to the Table should be published on, if any.
enum TablePubsub:int {
NO_PUBLISH = 0,
Expand Down
14 changes: 7 additions & 7 deletions src/ray/gcs/redis_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,23 +161,23 @@ Status RedisContext::AttachToEventLoop(aeEventLoop *loop) {
}

Status RedisContext::RunAsync(const std::string &command, const UniqueID &id,
uint8_t *data, int64_t length,
uint8_t *data, int64_t length, const TablePrefix prefix,
const TablePubsub pubsub_channel, int64_t callback_index) {
if (length > 0) {
std::string redis_command = command + " %d %b %b";
std::string redis_command = command + " %d %d %b %b";
int status = redisAsyncCommand(
async_context_, reinterpret_cast<redisCallbackFn *>(&GlobalRedisCallback),
reinterpret_cast<void *>(callback_index), redis_command.c_str(), pubsub_channel,
id.data(), id.size(), data, length);
reinterpret_cast<void *>(callback_index), redis_command.c_str(), prefix,
pubsub_channel, id.data(), id.size(), data, length);
if (status == REDIS_ERR) {
return Status::RedisError(std::string(async_context_->errstr));
}
} else {
std::string redis_command = command + " %d %b";
std::string redis_command = command + " %d %d %b";
int status = redisAsyncCommand(
async_context_, reinterpret_cast<redisCallbackFn *>(&GlobalRedisCallback),
reinterpret_cast<void *>(callback_index), redis_command.c_str(), pubsub_channel,
id.data(), id.size());
reinterpret_cast<void *>(callback_index), redis_command.c_str(), prefix,
pubsub_channel, id.data(), id.size());
if (status == REDIS_ERR) {
return Status::RedisError(std::string(async_context_->errstr));
}
Expand Down
4 changes: 2 additions & 2 deletions src/ray/gcs/redis_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ class RedisContext {
Status Connect(const std::string &address, int port);
Status AttachToEventLoop(aeEventLoop *loop);
Status RunAsync(const std::string &command, const UniqueID &id, uint8_t *data,
int64_t length, const TablePubsub pubsub_channel,
int64_t callback_index);
int64_t length, const TablePrefix prefix,
const TablePubsub pubsub_channel, int64_t callback_index);
Status SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel,
int64_t callback_index);
redisAsyncContext *async_context() { return async_context_; }
Expand Down
32 changes: 26 additions & 6 deletions src/ray/gcs/tables.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ class Table {
};

Table(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: context_(context), client_(client), pubsub_channel_(TablePubsub_NO_PUBLISH){};
: context_(context),
client_(client),
pubsub_channel_(TablePubsub_NO_PUBLISH),
prefix_(TablePrefix_UNUSED){};

/// Add an entry to the table.
///
Expand All @@ -71,7 +74,8 @@ class Table {
fbb.ForceDefaults(true);
fbb.Finish(Data::Pack(fbb, data.get()));
RAY_RETURN_NOT_OK(context_->RunAsync("RAY.TABLE_ADD", id, fbb.GetBufferPointer(),
fbb.GetSize(), pubsub_channel_, callback_index));
fbb.GetSize(), prefix_, pubsub_channel_,
callback_index));
return Status::OK();
}

Expand Down Expand Up @@ -102,7 +106,7 @@ class Table {
});
std::vector<uint8_t> nil;
RAY_RETURN_NOT_OK(context_->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(),
pubsub_channel_, callback_index));
prefix_, pubsub_channel_, callback_index));
return Status::OK();
}

Expand Down Expand Up @@ -144,17 +148,24 @@ class Table {
Status Remove(const JobID &job_id, const ID &id, const Callback &done);

protected:
std::unordered_map<ID, std::unique_ptr<CallbackData>, UniqueIDHasher> callback_data_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch :)

/// The connection to the GCS.
std::shared_ptr<RedisContext> context_;
/// The GCS client.
AsyncGcsClient *client_;
/// The pubsub channel to subscribe to for notifications about keys in this
/// table. If no notifications are required, this may be set to
/// TablePubsub_NO_PUBLISH.
TablePubsub pubsub_channel_;
/// The prefix to use for keys in this table.
TablePrefix prefix_;
};

class ObjectTable : public Table<ObjectID, ObjectTableData> {
public:
ObjectTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Table(context, client) {
pubsub_channel_ = TablePubsub_OBJECT;
prefix_ = TablePrefix_OBJECT;
};

/// Set up a client-specific channel for receiving notifications about
Expand Down Expand Up @@ -183,7 +194,14 @@ class ObjectTable : public Table<ObjectID, ObjectTableData> {
const std::vector<ObjectID> &object_ids);
};

using FunctionTable = Table<FunctionID, FunctionTableData>;
class FunctionTable : public Table<ObjectID, FunctionTableData> {
public:
FunctionTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Table(context, client) {
pubsub_channel_ = TablePubsub_NO_PUBLISH;
prefix_ = TablePrefix_FUNCTION;
};
};

using ClassTable = Table<ClassID, ClassTableData>;

Expand All @@ -195,6 +213,7 @@ class TaskTable : public Table<TaskID, TaskTableData> {
TaskTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Table(context, client) {
pubsub_channel_ = TablePubsub_TASK;
prefix_ = TablePrefix_TASK;
};

using TestAndUpdateCallback =
Expand Down Expand Up @@ -230,7 +249,7 @@ class TaskTable : public Table<TaskID, TaskTableData> {
flatbuffers::FlatBufferBuilder fbb;
fbb.Finish(TaskTableTestAndUpdate::Pack(fbb, data.get()));
RAY_RETURN_NOT_OK(context_->RunAsync("RAY.TABLE_TEST_AND_UPDATE", id,
fbb.GetBufferPointer(), fbb.GetSize(),
fbb.GetBufferPointer(), fbb.GetSize(), prefix_,
pubsub_channel_, callback_index));
return Status::OK();
}
Expand Down Expand Up @@ -281,6 +300,7 @@ class ClientTable : private Table<ClientID, ClientTableData> {
client_id_(ClientID::from_binary(local_client.client_id)),
local_client_(local_client) {
pubsub_channel_ = TablePubsub_CLIENT;
prefix_ = TablePrefix_CLIENT;

// Add a nil client to the cache so that we can serve requests for clients
// that we have not heard about.
Expand Down