Skip to content

Commit

Permalink
Create class to manage tf.data service dispatcher snapshots.
Browse files Browse the repository at this point in the history
This is mostly a 1:1 restructuring with the following changes:
1) Added simple snapshot recovery from on-disk state.
2) Removed all members tracking snapshot, stream, and source completion. These may have been structured incorrectly, and in any case they were neither tested nor used. I'll reevaluate them when stream completion is implemented.
3) Removed some validations that weren't tested and/or were related to item 1 above. They will be added back after item 1 is addressed.
4) Renamed `directory` to `path` throughout (proto fields, client API, and tests).

PiperOrigin-RevId: 502934739
  • Loading branch information
mpcallanan authored and tensorflower-gardener committed Jan 18, 2023
1 parent 711cdfd commit eb08bba
Show file tree
Hide file tree
Showing 17 changed files with 654 additions and 303 deletions.
3 changes: 3 additions & 0 deletions tensorflow/core/data/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -394,12 +394,15 @@ cc_library(
":journal_proto_cc",
":split_provider",
":task_remover",
":utils",
":validate_utils",
":worker_cc_grpc_proto",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"//tensorflow/core/data/service/snapshot:snapshot_manager",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
Expand Down
8 changes: 4 additions & 4 deletions tensorflow/core/data/service/dispatcher.proto
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ message SnapshotRequest {
// The dataset to snapshot.
DatasetDef dataset = 1;

// The directory to which to materialize the snapshot.
string directory = 2;
// The path to which to materialize the snapshot.
string path = 2;

// The metadata for the snapshot.
experimental.DistributedSnapshotMetadata metadata = 3;
Expand All @@ -265,8 +265,8 @@ message SnapshotResponse {}

// Next tag: 4
message GetSnapshotSplitRequest {
// The directory in which the snapshot is being materialized.
string directory = 1;
// The base path of the snapshot materialization.
string base_path = 1;

// The index of the snapshot stream from which to get the split.
int64 stream_index = 2;
Expand Down
8 changes: 4 additions & 4 deletions tensorflow/core/data/service/dispatcher_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,13 @@ Status DataServiceDispatcherClient::GetSplit(int64_t iteration_id,
}

Status DataServiceDispatcherClient::Snapshot(
const DatasetDef& dataset, const std::string& directory,
const DatasetDef& dataset, const std::string& path,
const experimental::DistributedSnapshotMetadata& metadata) {
TF_RETURN_IF_ERROR(EnsureInitialized());

SnapshotRequest req;
*req.mutable_dataset() = dataset;
req.set_directory(directory);
req.set_path(path);
*req.mutable_metadata() = metadata;

SnapshotResponse resp;
Expand All @@ -134,12 +134,12 @@ Status DataServiceDispatcherClient::Snapshot(
}

Status DataServiceDispatcherClient::GetSnapshotSplit(
const std::string& directory, int64_t stream_index, int64_t source_index,
const std::string& base_path, int64_t stream_index, int64_t source_index,
Tensor& split, bool& end_of_splits) {
TF_RETURN_IF_ERROR(EnsureInitialized());

GetSnapshotSplitRequest req;
req.set_directory(directory);
req.set_base_path(base_path);
req.set_stream_index(stream_index);
req.set_source_index(source_index);

Expand Down
8 changes: 4 additions & 4 deletions tensorflow/core/data/service/dispatcher_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
bool& end_of_splits);

// Gets the next split for the specified source of a stream of the snapshot in
// `directory`. If `end_of_splits` returns true, then there are no more splits
// `base_path`. If `end_of_splits` returns true, then there are no more splits
// to be processed for the specified stream source.
Status GetSnapshotSplit(const std::string& directory, int64_t stream_index,
Status GetSnapshotSplit(const std::string& base_path, int64_t stream_index,
int64_t source_index, Tensor& split,
bool& end_of_splits);

// Initiates the process of materializing `dataset`'s output to `directory`.
Status Snapshot(const DatasetDef& dataset, const std::string& directory,
// Initiates the process of materializing `dataset`'s output to `path`.
Status Snapshot(const DatasetDef& dataset, const std::string& path,
const experimental::DistributedSnapshotMetadata& metadata);

// Registers a dataset with the tf.data service, and stores the generated
Expand Down
31 changes: 15 additions & 16 deletions tensorflow/core/data/service/dispatcher_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,44 +139,43 @@ TEST_F(DispatcherClientTest, GetDataServiceConfig) {
EXPECT_EQ(config.deployment_mode(), DEPLOYMENT_MODE_COLOCATED);
}

TEST_F(DispatcherClientTest, SnapshotMetadataAndDatasetDefWritten) {
TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set<std::string> directories,
TEST_F(DispatcherClientTest, SkeletonWritten) {
TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set<std::string> paths,
StartDummySnapshots());
for (const auto& directory : directories) {
TF_ASSERT_OK(Env::Default()->FileExists(
io::JoinPath(directory, "snapshot.metadata")));
TF_ASSERT_OK(Env::Default()->FileExists(
io::JoinPath(directory, "dataset_def.proto")));
for (const auto& path : paths) {
TF_ASSERT_OK(Env::Default()->FileExists(CommittedChunksDirectory(path)));
TF_ASSERT_OK(Env::Default()->FileExists(StreamsDirectory(path)));
}
}

TEST_F(DispatcherClientTest, CreateCommittedChunksDirectory) {
TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set<std::string> directories,
TEST_F(DispatcherClientTest, SnapshotMetadataAndDatasetDefWritten) {
TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set<std::string> paths,
StartDummySnapshots());
for (const auto& directory : directories) {
for (const auto& path : paths) {
TF_ASSERT_OK(
Env::Default()->FileExists(io::JoinPath(path, "snapshot.metadata")));
TF_ASSERT_OK(
Env::Default()->FileExists(CommittedChunksDirectory(directory)));
Env::Default()->FileExists(io::JoinPath(path, "dataset_def.proto")));
}
}

TEST_F(DispatcherClientTest, SnapshotsInHeartbeat) {
TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set<std::string> directories,
TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set<std::string> paths,
StartDummySnapshots());
WorkerHeartbeatRequest worker_heartbeat_request;
worker_heartbeat_request.set_worker_address(test_cluster_->WorkerAddress(0));
TF_ASSERT_OK_AND_ASSIGN(
WorkerHeartbeatResponse worker_heartbeat_response,
dispatcher_client_->WorkerHeartbeat(worker_heartbeat_request));
ASSERT_EQ(worker_heartbeat_response.snapshot_tasks_size(),
directories.size());
ASSERT_EQ(worker_heartbeat_response.snapshot_tasks_size(), paths.size());
for (const auto& snapshot_task : worker_heartbeat_response.snapshot_tasks()) {
ASSERT_TRUE(directories.count(snapshot_task.base_path()));
ASSERT_TRUE(paths.count(snapshot_task.base_path()));
ASSERT_EQ(snapshot_task.stream_index(), 0);
}
}

TEST_F(DispatcherClientTest, GetSnapshotSplit) {
TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set<std::string> directories,
TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set<std::string> paths,
StartDummySnapshots());
WorkerHeartbeatRequest worker_heartbeat_request;
worker_heartbeat_request.set_worker_address(test_cluster_->WorkerAddress(0));
Expand Down
Loading

0 comments on commit eb08bba

Please sign in to comment.