Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] bind delete_keys parameter on tf_client update_priority #38

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions reverb/cc/ops/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ REGISTER_OP("ReverbClientUpdatePriorities")
.Input("table: string")
.Input("keys: uint64")
.Input("priorities: double")
.Input("keys_to_delete: uint64")
.Doc(R"doc(
Blocking call to update the priorities of a collection of items. Keys that could
not be found in table `table` on server are ignored and does not impact the rest
Expand Down Expand Up @@ -187,7 +188,9 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel {
const tensorflow::Tensor* keys;
OP_REQUIRES_OK(context, context->input("keys", &keys));
const tensorflow::Tensor* priorities;
OP_REQUIRES_OK(context, context->input("priorities", &priorities));
OP_REQUIRES_OK(context, context->input("priorities", &priorities));
const tensorflow::Tensor* keys_to_delete;
OP_REQUIRES_OK(context, context->input("keys_to_delete", &keys_to_delete));

OP_REQUIRES(
context, keys->dims() == 1,
Expand All @@ -197,6 +200,9 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel {
"Tensors `keys` and `priorities` do not match in shape (",
keys->shape().DebugString(), " vs. ",
priorities->shape().DebugString(), ")"));
OP_REQUIRES(
context, keys_to_delete->dims() == 1,
InvalidArgument("Tensors `keys_to_delete` must be of rank 1."));

std::string table_str = table->scalar<tstring>()();
std::vector<KeyWithPriority> updates;
Expand All @@ -207,14 +213,19 @@ class UpdatePrioritiesOp : public tensorflow::OpKernel {
updates.push_back(std::move(update));
}

std::vector<uint64_t> deletes;
for (int i = 0; i < keys_to_delete->dim_size(0); i++) {
deletes.push_back(keys_to_delete->flat<tensorflow::uint64>()(i));
}

// The call will only fail if the Reverb-server is brought down during an
// active call (e.g preempted). When this happens the request is retried and
// since MutatePriorities sets `wait_for_ready` the request will no be sent
// before the server is brought up again. It is therefore no problem to have
// this retry in this tight loop.
absl::Status status;
do {
status = resource->client()->MutatePriorities(table_str, updates, {});
status = resource->client()->MutatePriorities(table_str, updates, deletes);
} while (absl::IsUnavailable(status) || absl::IsDeadlineExceeded(status));
OP_REQUIRES_OK(context, ToTensorflowStatus(status));
}
Expand Down
9 changes: 7 additions & 2 deletions reverb/tf_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def update_priorities(self,
table: str,
keys: tf.Tensor,
priorities: tf.Tensor,
keys_to_delete: Optional[tf.Tensor] = None,
name: Optional[str] = None):
"""Creates op for updating priorities of existing items in the replay.

Expand All @@ -126,16 +127,20 @@ def update_priorities(self,
table: Probability table to update.
keys: Keys of the items to update. Must be same length as `priorities`.
priorities: New priorities for `keys`. Must be same length as `keys`.
keys_to_delete: Keys of the items to delete
name: Optional name for the operation.

Returns:
A tf-op for performing the update.
"""

if keys_to_delete is None:
keys_to_delete = tf.constant([], dtype=tf.uint64)

with tf.name_scope(name, f'{self._name}_update_priorities',
['update_priorities']) as scope:
return gen_reverb_ops.reverb_client_update_priorities(
self._handle, table, keys, priorities, name=scope)
return gen_client_ops.reverb_client_update_priorities(
self._handle, table, keys, priorities, keys_to_delete, name=scope)

def dataset(self,
table: str,
Expand Down
33 changes: 33 additions & 0 deletions reverb/tf_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,39 @@ def test_priority_update_is_applied(self):
self.fail('Updated item was not found')


def test_delete_key_is_applied(self):
# Start with 4 items
for i in range(4):
self._client.insert([np.array([i], dtype=np.uint32)], {'dist': 1})

# Until we have recieved all 4 items.
items = {}
while len(items) < 4:
item = next(self._client.sample('dist'))[0]
items[item.info.key] = item.info.probability

# remove 2 items
items_to_keep = [*items.keys()][:2]
items_to_remove = [*items.keys()][2:]
with self.session() as session:
client = tf_client.TFClient(self._client.server_address)
for key in items_to_remove:
update_op = client.update_priorities(
table=tf.constant('dist'),
keys=tf.constant([], dtype=tf.uint64),
priorities=tf.constant([], dtype=tf.float64),
keys_to_delete=tf.constant([key], dtype=tf.uint64))
self.assertIsNone(session.run(update_op))

# 2 remaining items must persist
final_items = {}
for _ in range(1000):
item = next(self._client.sample('dist'))[0]
self.assertTrue(item.info.key in items_to_keep)
final_items[item.info.key] = item.info.probability
self.assertEqual(len(final_items), 2)


class InsertOpTest(tf.test.TestCase):

@classmethod
Expand Down