From eb08bbae62fb689843befe858df12db17425d8e1 Mon Sep 17 00:00:00 2001 From: Matt Callanan Date: Wed, 18 Jan 2023 11:36:16 -0800 Subject: [PATCH] Create class to manage tf.data service dispatcher snapshots. This is mostly a 1:1 restructuring with the following changes: 1) Added simple snapshot recovery from on-disk state. 2) Removed all members tracking snapshot, stream, and source completion. I think these may have been structured incorrectly, and either way they weren't tested or used. I'll reevaluate when stream completion is implemented. 3) Removed some validations that weren't tested and/or were related to #1. Will add back after addressing #1. 4) Renamed directory -> path. PiperOrigin-RevId: 502934739 --- tensorflow/core/data/service/BUILD | 3 + tensorflow/core/data/service/dispatcher.proto | 8 +- .../core/data/service/dispatcher_client.cc | 8 +- .../core/data/service/dispatcher_client.h | 8 +- .../data/service/dispatcher_client_test.cc | 31 +- .../core/data/service/dispatcher_impl.cc | 187 ++--------- .../core/data/service/dispatcher_impl.h | 98 +----- .../core/data/service/dispatcher_state.cc | 2 +- .../core/data/service/dispatcher_state.h | 11 +- .../data/service/dispatcher_state_test.cc | 14 +- tensorflow/core/data/service/journal.proto | 4 +- tensorflow/core/data/service/snapshot/BUILD | 21 ++ .../data/service/snapshot/snapshot_manager.cc | 295 ++++++++++++++++++ .../data/service/snapshot/snapshot_manager.h | 129 ++++++++ tensorflow/core/data/service/worker_impl.cc | 2 +- .../kernel_tests/distributed_save_test.py | 14 +- .../service/fault_tolerance_test.py | 122 ++++++++ 17 files changed, 654 insertions(+), 303 deletions(-) create mode 100644 tensorflow/core/data/service/snapshot/snapshot_manager.cc create mode 100644 tensorflow/core/data/service/snapshot/snapshot_manager.h diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 193a0e3add43a2..06290f1d96b39e 100644 --- a/tensorflow/core/data/service/BUILD +++ 
b/tensorflow/core/data/service/BUILD @@ -394,12 +394,15 @@ cc_library( ":journal_proto_cc", ":split_provider", ":task_remover", + ":utils", ":validate_utils", ":worker_cc_grpc_proto", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/time", + "//tensorflow/core/data/service/snapshot:snapshot_manager", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/core/data/service/dispatcher.proto b/tensorflow/core/data/service/dispatcher.proto index e0d006b58ed507..0e5f56414140f2 100644 --- a/tensorflow/core/data/service/dispatcher.proto +++ b/tensorflow/core/data/service/dispatcher.proto @@ -253,8 +253,8 @@ message SnapshotRequest { // The dataset to snapshot. DatasetDef dataset = 1; - // The directory to which to materialize the snapshot. - string directory = 2; + // The path to which to materialize the snapshot. + string path = 2; // The metadata for the snapshot. experimental.DistributedSnapshotMetadata metadata = 3; @@ -265,8 +265,8 @@ message SnapshotResponse {} // Next tag: 4 message GetSnapshotSplitRequest { - // The directory in which the snapshot is being materialized. - string directory = 1; + // The base path of the snapshot materialization. + string base_path = 1; // The index of the snapshot stream from which to get the split. 
int64 stream_index = 2; diff --git a/tensorflow/core/data/service/dispatcher_client.cc b/tensorflow/core/data/service/dispatcher_client.cc index 99e59393011947..daba27c69fb2fc 100644 --- a/tensorflow/core/data/service/dispatcher_client.cc +++ b/tensorflow/core/data/service/dispatcher_client.cc @@ -115,13 +115,13 @@ Status DataServiceDispatcherClient::GetSplit(int64_t iteration_id, } Status DataServiceDispatcherClient::Snapshot( - const DatasetDef& dataset, const std::string& directory, + const DatasetDef& dataset, const std::string& path, const experimental::DistributedSnapshotMetadata& metadata) { TF_RETURN_IF_ERROR(EnsureInitialized()); SnapshotRequest req; *req.mutable_dataset() = dataset; - req.set_directory(directory); + req.set_path(path); *req.mutable_metadata() = metadata; SnapshotResponse resp; @@ -134,12 +134,12 @@ Status DataServiceDispatcherClient::Snapshot( } Status DataServiceDispatcherClient::GetSnapshotSplit( - const std::string& directory, int64_t stream_index, int64_t source_index, + const std::string& base_path, int64_t stream_index, int64_t source_index, Tensor& split, bool& end_of_splits) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetSnapshotSplitRequest req; - req.set_directory(directory); + req.set_base_path(base_path); req.set_stream_index(stream_index); req.set_source_index(source_index); diff --git a/tensorflow/core/data/service/dispatcher_client.h b/tensorflow/core/data/service/dispatcher_client.h index 0f5fc71c2d5d4d..ab442d7d67466e 100644 --- a/tensorflow/core/data/service/dispatcher_client.h +++ b/tensorflow/core/data/service/dispatcher_client.h @@ -65,14 +65,14 @@ class DataServiceDispatcherClient : public DataServiceClientBase { bool& end_of_splits); // Gets the next split for the specified source of a stream of the snapshot in - // `directory`. If `end_of_splits` returns true, then there are no more splits + // `base_path`. 
If `end_of_splits` returns true, then there are no more splits // to be processed for the specified stream source. - Status GetSnapshotSplit(const std::string& directory, int64_t stream_index, + Status GetSnapshotSplit(const std::string& base_path, int64_t stream_index, int64_t source_index, Tensor& split, bool& end_of_splits); - // Initiates the process of materializing `dataset`'s output to `directory`. - Status Snapshot(const DatasetDef& dataset, const std::string& directory, + // Initiates the process of materializing `dataset`'s output to `path`. + Status Snapshot(const DatasetDef& dataset, const std::string& path, const experimental::DistributedSnapshotMetadata& metadata); // Registers a dataset with the tf.data service, and stores the generated diff --git a/tensorflow/core/data/service/dispatcher_client_test.cc b/tensorflow/core/data/service/dispatcher_client_test.cc index 64ad022713b90c..69d9026bf5ddd1 100644 --- a/tensorflow/core/data/service/dispatcher_client_test.cc +++ b/tensorflow/core/data/service/dispatcher_client_test.cc @@ -139,44 +139,43 @@ TEST_F(DispatcherClientTest, GetDataServiceConfig) { EXPECT_EQ(config.deployment_mode(), DEPLOYMENT_MODE_COLOCATED); } -TEST_F(DispatcherClientTest, SnapshotMetadataAndDatasetDefWritten) { - TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set directories, +TEST_F(DispatcherClientTest, SkeletonWritten) { + TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set paths, StartDummySnapshots()); - for (const auto& directory : directories) { - TF_ASSERT_OK(Env::Default()->FileExists( - io::JoinPath(directory, "snapshot.metadata"))); - TF_ASSERT_OK(Env::Default()->FileExists( - io::JoinPath(directory, "dataset_def.proto"))); + for (const auto& path : paths) { + TF_ASSERT_OK(Env::Default()->FileExists(CommittedChunksDirectory(path))); + TF_ASSERT_OK(Env::Default()->FileExists(StreamsDirectory(path))); } } -TEST_F(DispatcherClientTest, CreateCommittedChunksDirectory) { - TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set directories, 
+TEST_F(DispatcherClientTest, SnapshotMetadataAndDatasetDefWritten) { + TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set paths, StartDummySnapshots()); - for (const auto& directory : directories) { + for (const auto& path : paths) { + TF_ASSERT_OK( + Env::Default()->FileExists(io::JoinPath(path, "snapshot.metadata"))); TF_ASSERT_OK( - Env::Default()->FileExists(CommittedChunksDirectory(directory))); + Env::Default()->FileExists(io::JoinPath(path, "dataset_def.proto"))); } } TEST_F(DispatcherClientTest, SnapshotsInHeartbeat) { - TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set directories, + TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set paths, StartDummySnapshots()); WorkerHeartbeatRequest worker_heartbeat_request; worker_heartbeat_request.set_worker_address(test_cluster_->WorkerAddress(0)); TF_ASSERT_OK_AND_ASSIGN( WorkerHeartbeatResponse worker_heartbeat_response, dispatcher_client_->WorkerHeartbeat(worker_heartbeat_request)); - ASSERT_EQ(worker_heartbeat_response.snapshot_tasks_size(), - directories.size()); + ASSERT_EQ(worker_heartbeat_response.snapshot_tasks_size(), paths.size()); for (const auto& snapshot_task : worker_heartbeat_response.snapshot_tasks()) { - ASSERT_TRUE(directories.count(snapshot_task.base_path())); + ASSERT_TRUE(paths.count(snapshot_task.base_path())); ASSERT_EQ(snapshot_task.stream_index(), 0); } } TEST_F(DispatcherClientTest, GetSnapshotSplit) { - TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set directories, + TF_ASSERT_OK_AND_ASSIGN(absl::flat_hash_set paths, StartDummySnapshots()); WorkerHeartbeatRequest worker_heartbeat_request; worker_heartbeat_request.set_worker_address(test_cluster_->WorkerAddress(0)); diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc index a1d1a8e8cece14..67d4e0d887cc3c 100644 --- a/tensorflow/core/data/service/dispatcher_impl.cc +++ b/tensorflow/core/data/service/dispatcher_impl.cc @@ -32,6 +32,8 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "absl/time/time.h" #include "tensorflow/core/data/dataset_utils.h" #include "tensorflow/core/data/hash_utils.h" @@ -47,6 +49,7 @@ limitations under the License. #include "tensorflow/core/data/service/snapshot/file_utils.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" #include "tensorflow/core/data/service/split_provider.h" +#include "tensorflow/core/data/service/utils.h" #include "tensorflow/core/data/service/validate_utils.h" #include "tensorflow/core/data/service/worker.grpc.pb.h" #include "tensorflow/core/data/snapshot_utils.h" @@ -188,7 +191,6 @@ DataServiceDispatcherImpl::~DataServiceDispatcherImpl() { maintenance_thread_.reset(); } -// TODO(b/250921378): Recover snapshots. Status DataServiceDispatcherImpl::Start() { mutex_lock l(mu_); if (config_.job_gc_timeout_ms() >= 0) { @@ -243,6 +245,13 @@ Status DataServiceDispatcherImpl::Start() { // Initialize the journal writer in `Start` so that we fail fast in case it // can't be initialized. 
TF_RETURN_IF_ERROR(journal_writer_.value()->EnsureInitialized()); + + for (const auto& path : state_.ListSnapshotPaths()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr snapshot_manager, + SnapshotManager::Resume(path, env_)); + snapshots_.insert({path, std::move(snapshot_manager)}); + } + started_ = true; return OkStatus(); } @@ -331,50 +340,6 @@ Status DataServiceDispatcherImpl::FindNewTasks( return OkStatus(); } -Status DataServiceDispatcherImpl::CreateSnapshotStream( - absl::string_view snapshot_directory, absl::string_view worker_address, - SnapshotState& snapshot_state) { - for (int64_t source_index = 0; - source_index < snapshot_state.split_providers.size(); ++source_index) { - TF_RETURN_IF_ERROR(env_->RecursivelyCreateDir(SourceDirectory( - snapshot_directory, snapshot_state.streams.size(), source_index))); - } - snapshot_state.streams.push_back( - StreamState(snapshot_state.split_providers.size(), worker_address)); - return OkStatus(); -} - -Status DataServiceDispatcherImpl::PopulateSnapshotInfo( - absl::string_view worker_address, WorkerHeartbeatResponse* response) { - for (auto& [snapshot_directory, snapshot_state] : snapshots_) { - auto it = snapshot_state.assigned_streams.find(worker_address); - if (it == snapshot_state.assigned_streams.end() && - snapshot_state.mode != SnapshotState::Mode::kActive) { - // If new workers are starting but the snapshot is not active, do not add - // a snapshot task. - continue; - } - - SnapshotTaskDef* snapshot_task = response->add_snapshot_tasks(); - snapshot_task->set_base_path(snapshot_directory); - snapshot_task->set_num_sources(snapshot_state.split_providers.size()); - if (it != snapshot_state.assigned_streams.end()) { - snapshot_task->set_stream_index(it->second); - continue; - } - - // TODO(mpcallanan): Handle orphaned streams. 
- TF_RETURN_IF_ERROR(CreateSnapshotStream(snapshot_directory, worker_address, - snapshot_state)); - snapshot_task->set_stream_index(snapshot_state.streams.size() - 1); - snapshot_state.assigned_streams[worker_address] = - snapshot_task->stream_index(); - VLOG(1) << "creating stream #" << snapshot_task->stream_index() - << " and assigning to worker " << worker_address; - } - return OkStatus(); -} - Status DataServiceDispatcherImpl::WorkerHeartbeat( const WorkerHeartbeatRequest* request, WorkerHeartbeatResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); @@ -409,7 +374,10 @@ Status DataServiceDispatcherImpl::WorkerHeartbeat( FindTasksToDelete(current_tasks, assigned_tasks, response)); TF_RETURN_IF_ERROR( FindNewTasks(worker_address, current_tasks, assigned_tasks, response)); - TF_RETURN_IF_ERROR(PopulateSnapshotInfo(worker_address, response)); + + for (const auto& [path, snapshot_manager] : snapshots_) { + TF_RETURN_IF_ERROR(snapshot_manager->WorkerHeartbeat(*request, *response)); + } VLOG(4) << "Finished worker heartbeat for worker at address " << request->worker_address(); @@ -1082,83 +1050,25 @@ Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request, return OkStatus(); } -StatusOr DataServiceDispatcherImpl::CreateSnapshotState( - const std::string& snapshot_directory, const DatasetDef& dataset_def) { - auto [it, ignore] = snapshots_.insert({snapshot_directory, SnapshotState()}); - TF_RETURN_IF_ERROR( - CreateSplitProviders(dataset_def, it->second.split_providers)); - return &it->second; -} - Status DataServiceDispatcherImpl::Snapshot(const SnapshotRequest* request, SnapshotResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); mutex_lock l(mu_); - if (snapshots_.contains(request->directory())) { - return errors::InvalidArgument("a snapshot at \"", request->directory(), - "\" is already started or completed"); + if (snapshots_.contains(request->path())) { + return errors::InvalidArgument("a snapshot at ", request->path(), + " is 
already started or completed"); } - TF_RETURN_IF_ERROR(snapshot_util::WriteMetadataFile( - env_, request->directory(), &request->metadata())); - TF_RETURN_IF_ERROR(WriteTextProto( - env_, DatasetDefFilePath(request->directory()), request->dataset())); - TF_RETURN_IF_ERROR(env_->RecursivelyCreateDir( - CommittedChunksDirectory(request->directory()))); + TF_ASSIGN_OR_RETURN(std::unique_ptr snapshot_manager, + SnapshotManager::Start(*request, env_)); + snapshots_.insert({request->path(), std::move(snapshot_manager)}); Update update; SnapshotUpdate* snapshot = update.mutable_snapshot(); - snapshot->set_directory(request->directory()); + snapshot->set_path(request->path()); TF_RETURN_IF_ERROR(Apply(update)); - TF_RETURN_IF_ERROR( - CreateSnapshotState(request->directory(), request->dataset()).status()); - - return OkStatus(); -} - -Status DataServiceDispatcherImpl::ValidateGetSnapshotSplitRequest( - const GetSnapshotSplitRequest& request) { - auto snapshot_state_it = snapshots_.find(request.directory()); - if (snapshot_state_it == snapshots_.end()) { - return errors::InvalidArgument( - "the dispatcher does not know of a snapshot at ", request.directory()); - } - SnapshotState& snapshot_state = snapshot_state_it->second; - if (snapshot_state.mode == SnapshotState::Mode::kDone) { - return errors::InvalidArgument( - "the dispatcher considers all splits for the snapshot at ", - request.directory(), "to have already been processed"); - } - - if (request.stream_index() >= snapshot_state.streams.size()) { - return errors::InvalidArgument("the dispatcher does not know of a stream ", - absl::StrCat(request.stream_index()), - " for the snapshot at ", - request.directory()); - } - StreamState& stream_state = snapshot_state.streams[request.stream_index()]; - if (stream_state.mode == StreamState::Mode::kDone) { - return errors::InvalidArgument("the dispatcher considers the stream ", - absl::StrCat(request.stream_index()), - "for the snapshot at ", request.directory(), - " to be 
done"); - } - - if (request.source_index() >= stream_state.sources.size()) { - return errors::InvalidArgument(absl::StrCat( - "the dispatcher does not know of a dataset source at index ", - request.source_index(), " for the stream at ", request.stream_index(), - " for the snapshot at ", request.directory())); - } - if (stream_state.sources[request.source_index()].done) { - return errors::InvalidArgument(absl::StrCat( - "the dispatcher considers the source at index ", request.source_index(), - " for the stream at ", request.stream_index(), " for the snapshot at ", - request.directory(), " to be done")); - } - return OkStatus(); } @@ -1168,58 +1078,13 @@ Status DataServiceDispatcherImpl::GetSnapshotSplit( TF_RETURN_IF_ERROR(CheckStarted()); mutex_lock l(mu_); - TF_RETURN_IF_ERROR(ValidateGetSnapshotSplitRequest(*request)); - - SnapshotState& snapshot_state = snapshots_[request->directory()]; - if (snapshot_state.mode == SnapshotState::Mode::kWindingDown) { - response->set_end_of_splits(true); - return OkStatus(); - } - - Tensor split; - bool end_of_splits = true; - SplitProvider* split_provider = - snapshot_state.split_providers[request->source_index()].get(); - DCHECK(split_provider != nullptr); - TF_RETURN_IF_ERROR(split_provider->GetNext(&split, &end_of_splits)); - - StreamState& stream_state = snapshot_state.streams[request->stream_index()]; - SourceState& source_state = stream_state.sources[request->source_index()]; - if (end_of_splits) { - source_state.done = true; - stream_state.active_sources.erase(request->source_index()); - if (stream_state.active_sources.empty()) { - stream_state.mode = StreamState::Mode::kDone; - snapshot_state.assigned_streams.erase(stream_state.worker_address); - } - snapshot_state.mode = snapshot_state.assigned_streams.empty() - ? 
SnapshotState::Mode::kDone - : SnapshotState::Mode::kWindingDown; - - response->set_end_of_splits(true); - return OkStatus(); - } - - std::string unassigned_split_path; - if (!env_->LocalTempFilename(&unassigned_split_path)) { - return errors::Internal("failed to write split"); + auto it = snapshots_.find(request->base_path()); + if (it == snapshots_.end()) { + return errors::InvalidArgument( + "the dispatcher does not know of a snapshot at ", request->base_path()); } - snapshot_util::TFRecordWriter writer(unassigned_split_path, - tsl::io::compression::kNone); - TF_RETURN_IF_ERROR(writer.Initialize(env_)); - TF_RETURN_IF_ERROR(writer.WriteTensors({split})); - - std::string assigned_split_path = - SplitPath(request->directory(), request->stream_index(), - request->source_index(), source_state.next_local_split_index, - snapshot_state.next_global_split_index); - TF_RETURN_IF_ERROR( - env_->RenameFile(unassigned_split_path, assigned_split_path)); - ++source_state.next_local_split_index; - ++snapshot_state.next_global_split_index; - - split.AsProtoTensorContent(response->mutable_split()); + TF_RETURN_IF_ERROR(it->second->GetSnapshotSplit(*request, *response)); return OkStatus(); } diff --git a/tensorflow/core/data/service/dispatcher_impl.h b/tensorflow/core/data/service/dispatcher_impl.h index 17fc167c4f12ac..fd628fee8c8344 100644 --- a/tensorflow/core/data/service/dispatcher_impl.h +++ b/tensorflow/core/data/service/dispatcher_impl.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/dispatcher_state.h" #include "tensorflow/core/data/service/export.pb.h" +#include "tensorflow/core/data/service/snapshot/snapshot_manager.h" #include "tensorflow/core/data/service/task_remover.h" #include "tensorflow/core/data/service/worker.grpc.pb.h" #include "tensorflow/core/framework/dataset.h" @@ -44,78 +45,6 @@ limitations under the License. 
namespace tensorflow { namespace data { -// Structs for maintaining the in-memory state of `Snapshot`s. This state -// mirrors that which is on-disk. -struct SourceState { - SourceState() : next_local_split_index(0), done(false) {} - // A counter of all assigned splits for the source. - int64_t next_local_split_index; - // If true, there are no more splits to process for the source. - bool done; -}; - -struct StreamState { - enum class Mode { - // A worker is processing the stream and is heartbeating. - kAssigned, - // A worker was processing the stream but has stopped heartbeating. - kOrphan, - // The dispatcher restarted and has yet to get a heartbeat from a worker - // processing the stream. - kUnknown, - // There are no more splits to process for the stream. - kDone, - }; - - explicit StreamState(int64_t num_sources) - : sources(num_sources), mode(Mode::kUnknown), worker_address("") {} - - explicit StreamState(int64_t num_sources, absl::string_view worker_address) - : sources(num_sources), - mode(Mode::kAssigned), - worker_address(worker_address) {} - - // All sources whose splits have been assigned for the stream. - std::vector sources; - // If `kOrphan`, the stream is a candidate to be assigned to an unoccupied - // worker. `kUnknown` transitions to `kAssigned` or `kOrphan` depending on - // whether or not the dispatcher gets a heartbeat from a worker processing the - // stream. - Mode mode; - // If `mode` is `kAssigned`, the address of the worker processing the stream. - std::string worker_address; - // Indices of all unfinished sources. - absl::flat_hash_set active_sources; -}; - -struct SnapshotState { - enum class Mode { - // No streams are done. - kActive, - // Some streams are done, but not all. - kWindingDown, - // All streams are done. - kDone - }; - - SnapshotState() : next_global_split_index(0), mode(Mode::kActive) {} - - // Split providers for each input of the dataset being materialized. 
- std::vector> split_providers; - // All streams for the snapshot. - std::vector streams; - // Indices of all unfinished streams with a known worker assignment, keyed by - // worker address. - absl::flat_hash_map assigned_streams; - // Indices of all unfinished streams with an unknown worker assignment. - absl::flat_hash_set unassigned_streams; - // A counter of all assigned splits for the snapshot. - int64_t next_global_split_index; - // If not `kActive`, at least one source of one stream has finished processing - // and no new streams are created or assigned. - Mode mode; -}; - // A service which coordinates a pool of workers to serve dataset elements over // RPC. // @@ -329,24 +258,6 @@ class DataServiceDispatcherImpl { std::vector>& tasks) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - // Creates a `snapshots_` entry for `dataset_def` at `snapshot_directory`. - // Note that this does not read from `snapshot_directory`. - StatusOr CreateSnapshotState( - const std::string& snapshot_directory, const DatasetDef& dataset_def) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - // Populates `response.snapshots` with information from `snapshots_`. - Status PopulateSnapshotInfo(absl::string_view worker_address, - WorkerHeartbeatResponse* response) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - // Creates a new snapshot stream, both writing it on-disk to - // `snapshot_directory` and adding an entry in-memory to `snapshots_state`. - Status CreateSnapshotStream(absl::string_view snapshot_directory, - absl::string_view worker_address, - SnapshotState& snapshot_state); - // Validates `request` against `snapshots_`. - Status ValidateGetSnapshotSplitRequest(const GetSnapshotSplitRequest& request) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - // Creates a new task for an iteration. The created task may be either // pending or active. 
Status CreateTask(std::shared_ptr iteration, @@ -435,9 +346,10 @@ class DataServiceDispatcherImpl { absl::flat_hash_map latest_client_heartbeats_time_ TF_GUARDED_BY(mu_); - // Map from snapshot directory to state mirroring that of the - // materialization. - absl::flat_hash_map snapshots_ TF_GUARDED_BY(mu_); + // Managers for all snapshot processes created or recovered during the + // lifetime of this dispatcher instance. + absl::flat_hash_map> snapshots_ + TF_GUARDED_BY(mu_); std::optional> journal_writer_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/data/service/dispatcher_state.cc b/tensorflow/core/data/service/dispatcher_state.cc index b45e06f9dd7608..36536416c83d77 100644 --- a/tensorflow/core/data/service/dispatcher_state.cc +++ b/tensorflow/core/data/service/dispatcher_state.cc @@ -488,7 +488,7 @@ StatusOr DispatcherState::GetWorkerIndex( } void DispatcherState::Snapshot(const SnapshotUpdate& snapshot) { - snapshot_directories_.insert(snapshot.directory()); + snapshot_paths_.insert(snapshot.path()); } } // namespace data diff --git a/tensorflow/core/data/service/dispatcher_state.h b/tensorflow/core/data/service/dispatcher_state.h index c4070810cd4a34..7ce7f7b304d357 100644 --- a/tensorflow/core/data/service/dispatcher_state.h +++ b/tensorflow/core/data/service/dispatcher_state.h @@ -294,9 +294,10 @@ class DispatcherState { // deterministically sharding a dataset among a fixed set of workers. StatusOr GetWorkerIndex(absl::string_view worker_address) const; - // Returns the directories of all active or completed snapshots. - const absl::flat_hash_set& ListSnapshotDirectories() const { - return snapshot_directories_; + // Returns the paths of all snapshots initiated during the lifetime of this + // journal. + const absl::flat_hash_set& ListSnapshotPaths() const { + return snapshot_paths_; } private: @@ -362,8 +363,8 @@ class DispatcherState { // Tasks, keyed by worker addresses. The values are a map from task id to // task. 
absl::flat_hash_map tasks_by_worker_; - // Directories of all active or completed snapshots. - absl::flat_hash_set snapshot_directories_; + // Paths for all snapshots initiated during the lifetime of this journal. + absl::flat_hash_set snapshot_paths_; }; } // namespace data diff --git a/tensorflow/core/data/service/dispatcher_state_test.cc b/tensorflow/core/data/service/dispatcher_state_test.cc index 066875a5922cd1..1f63a00165e2c3 100644 --- a/tensorflow/core/data/service/dispatcher_state_test.cc +++ b/tensorflow/core/data/service/dispatcher_state_test.cc @@ -144,10 +144,10 @@ Status FinishTask(int64_t task_id, DispatcherState& state) { return state.Apply(update); } -Status Snapshot(const std::string& directory, DispatcherState& state) { +Status Snapshot(const std::string& path, DispatcherState& state) { Update update; SnapshotUpdate* snapshot = update.mutable_snapshot(); - snapshot->set_directory(directory); + snapshot->set_path(path); return state.Apply(update); } @@ -695,13 +695,13 @@ TEST(DispatcherState, ListActiveClients) { EXPECT_THAT(state.ListActiveClientIds(), UnorderedElementsAre(6, 8)); } -TEST(DispatcherState, ListSnapshotDirectories) { +TEST(DispatcherState, ListSnapshotPaths) { DispatcherState state; - absl::flat_hash_set snapshot_directories = {"p1", "p2"}; - for (const auto& snapshot_directory : snapshot_directories) { - TF_EXPECT_OK(Snapshot(snapshot_directory, state)); + absl::flat_hash_set snapshot_paths = {"p1", "p2"}; + for (const auto& snapshot_path : snapshot_paths) { + TF_EXPECT_OK(Snapshot(snapshot_path, state)); } - EXPECT_EQ(state.ListSnapshotDirectories(), snapshot_directories); + EXPECT_EQ(state.ListSnapshotPaths(), snapshot_paths); } } // namespace data diff --git a/tensorflow/core/data/service/journal.proto b/tensorflow/core/data/service/journal.proto index edab45720add7d..1b3a35f43ad733 100644 --- a/tensorflow/core/data/service/journal.proto +++ b/tensorflow/core/data/service/journal.proto @@ -145,7 +145,7 @@ message 
FinishTaskUpdate { int64 task_id = 1; } -// Next tag: 4 +// Next tag: 2 message SnapshotUpdate { - string directory = 1; + string path = 1; } diff --git a/tensorflow/core/data/service/snapshot/BUILD b/tensorflow/core/data/service/snapshot/BUILD index c97fa510d5a7f3..60f0b19c50a5c8 100644 --- a/tensorflow/core/data/service/snapshot/BUILD +++ b/tensorflow/core/data/service/snapshot/BUILD @@ -88,6 +88,27 @@ cc_library( ], ) +cc_library( + name = "snapshot_manager", + srcs = ["snapshot_manager.cc"], + hdrs = ["snapshot_manager.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":file_utils", + ":path_utils", + "//tensorflow/core:framework", + "//tensorflow/core/data:snapshot_utils", + "//tensorflow/core/data/service:common_proto_cc", + "//tensorflow/core/data/service:dispatcher_proto_cc", + "//tensorflow/core/data/service:split_provider", + "//tensorflow/core/platform:status", + "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/platform:statusor", + ], +) + cc_library( name = "snapshot_stream_writer", srcs = ["snapshot_stream_writer.cc"], diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager.cc b/tensorflow/core/data/service/snapshot/snapshot_manager.cc new file mode 100644 index 00000000000000..87b4c0ed055438 --- /dev/null +++ b/tensorflow/core/data/service/snapshot/snapshot_manager.cc @@ -0,0 +1,295 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/data/service/snapshot/snapshot_manager.h" + +#include +#include +#include + +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/dispatcher.pb.h" +#include "tensorflow/core/data/service/snapshot/file_utils.h" +#include "tensorflow/core/data/service/snapshot/path_utils.h" +#include "tensorflow/core/data/service/split_provider.h" +#include "tensorflow/core/data/snapshot_utils.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace tensorflow { +namespace data { + +using ::tsl::OkStatus; +using ::tsl::errors::InvalidArgument; + +StatusOr> SnapshotManager::Start( + const SnapshotRequest& request, Env* env) { + SnapshotManager* snapshot_manager = new SnapshotManager(request.path(), env); + TF_RETURN_IF_ERROR(snapshot_manager->Start(request)); + return absl::WrapUnique(snapshot_manager); +} + +Status SnapshotManager::Start(const SnapshotRequest& request) { + if (env_->FileExists(request.path()).ok()) { + return InvalidArgument(request.path(), " already exists"); + } + TF_RETURN_IF_ERROR(CreateSplitProviders(request.dataset(), split_providers_)); + TF_RETURN_IF_ERROR(WriteOnDiskSkeleton()); + TF_RETURN_IF_ERROR(WriteOnDiskMetadata(request)); + return OkStatus(); +} + +Status SnapshotManager::WriteOnDiskSkeleton() { + TF_RETURN_IF_ERROR( + env_->RecursivelyCreateDir(CommittedChunksDirectory(path_))); + TF_RETURN_IF_ERROR(env_->RecursivelyCreateDir(StreamsDirectory(path_))); + return OkStatus(); +} + +Status SnapshotManager::WriteOnDiskMetadata(const SnapshotRequest& request) { + TF_RETURN_IF_ERROR(WriteTextProto(env_, SnapshotMetadataFilePath(path_), + request.metadata())); + 
TF_RETURN_IF_ERROR( + WriteBinaryProto(env_, DatasetDefFilePath(path_), request.dataset())); + return OkStatus(); +} + +StatusOr> SnapshotManager::Resume( + absl::string_view path, Env* env) { + SnapshotManager* snapshot_manager = new SnapshotManager(path, env); + TF_RETURN_IF_ERROR(snapshot_manager->Resume()); + return absl::WrapUnique(snapshot_manager); +} + +Status SnapshotManager::Resume() { + if (!env_->FileExists(path_).ok()) { + return InvalidArgument("failed to recover snapshot at ", path_, + ": the snapshot path doesn't exist"); + } + TF_RETURN_IF_ERROR(ReadOnDiskMetadata()); + TF_RETURN_IF_ERROR(ReadOnDiskStreams()); + return OkStatus(); +} + +Status SnapshotManager::ReadOnDiskMetadata() { + if (!env_->FileExists(SnapshotMetadataFilePath(path_)).ok()) { + return InvalidArgument("failed to recover snapshot at ", path_, + ": snapshot has no snapshot.metadata"); + } + experimental::DistributedSnapshotMetadata metadata; + TF_RETURN_IF_ERROR( + ReadTextProto(env_, SnapshotMetadataFilePath(path_), &metadata)); + + if (!env_->FileExists(DatasetDefFilePath(path_)).ok()) { + return InvalidArgument("failed to recovery snapshot at ", path_, + ": snapshot has no dataset_def.proto"); + } + DatasetDef dataset_def; + TF_RETURN_IF_ERROR( + ReadBinaryProto(env_, DatasetDefFilePath(path_), &dataset_def)); + + TF_RETURN_IF_ERROR(CreateSplitProviders(dataset_def, split_providers_)); + return OkStatus(); +} + +Status SnapshotManager::ReadOnDiskStreams() { + std::string streams_path = StreamsDirectory(path_); + + std::vector stream_directories; + TF_RETURN_IF_ERROR(env_->GetChildren(streams_path, &stream_directories)); + streams_.resize(stream_directories.size(), Stream(num_sources())); + + absl::flat_hash_set global_split_indices; + for (const auto& stream_directory : stream_directories) { + std::string stream_path = io::JoinPath(streams_path, stream_directory); + + // `stream_directory` must have this format: "stream_". 
+ std::vector tokens = absl::StrSplit(stream_directory, '_'); + int64_t stream_index; + if (tokens.size() != 2 || !absl::SimpleAtoi(tokens[1], &stream_index) || + stream_index < 0) { + return InvalidArgument( + "can't parse the name of ", stream_path, + ": filename must have the format stream_"); + } + + TF_RETURN_IF_ERROR(ReadOnDiskStream(stream_index, global_split_indices)); + } + + for (int64_t i = 0; i < global_split_indices.size(); ++i) { + if (!global_split_indices.contains(i)) { + return InvalidArgument("found missing global split index, ", i, ", in ", + path_); + } + } + num_assigned_splits_ = global_split_indices.size(); + + return OkStatus(); +} + +Status SnapshotManager::ReadOnDiskStream( + int64_t stream_index, absl::flat_hash_set& global_split_indices) { + std::string splits_path = SplitsDirectory(path_, stream_index); + std::vector source_directories; + TF_RETURN_IF_ERROR(env_->GetChildren(splits_path, &source_directories)); + for (const auto& source_directory : source_directories) { + std::string source_path = io::JoinPath(splits_path, source_directory); + + // `source_directory` must have this format: "source_". + std::vector tokens = absl::StrSplit(source_directory, '_'); + int64_t source_index; + if (tokens.size() != 2 || !absl::SimpleAtoi(tokens[1], &source_index) || + source_index < 0) { + return InvalidArgument( + "can't parse the name of ", source_path, + ": filename must have the format source_"); + } + if (source_index >= num_sources()) { + return InvalidArgument("found conflict between the number of sources, ", + num_sources(), ", and the filename of ", + source_path); + } + TF_RETURN_IF_ERROR( + ReadOnDiskSource(stream_index, source_index, global_split_indices)); + } + + // TODO(mpcallanan): Handle unknowns. 
+ + return OkStatus(); +} + +Status SnapshotManager::ReadOnDiskSource( + int64_t stream_index, int64_t source_index, + absl::flat_hash_set& global_split_indices) { + std::string source_path = SourceDirectory(path_, stream_index, source_index); + + std::vector split_filenames; + TF_RETURN_IF_ERROR(env_->GetChildren(source_path, &split_filenames)); + + Tensor unused_tensor; + bool unused_end_of_splits; + for (const auto& split_filename : split_filenames) { + std::string split_path = io::JoinPath(source_path, split_filename); + + // `split_filename` must have this format: + // "split__". + std::vector tokens = absl::StrSplit(split_filename, '_'); + int64_t local_split_index; + int64_t global_split_index; + if (tokens.size() != 3 || + !absl::SimpleAtoi(tokens[1], &local_split_index) || + local_split_index < 0 || + !absl::SimpleAtoi(tokens[2], &global_split_index) || + global_split_index < 0) { + return InvalidArgument("can't parse the name of ", split_path); + } + if (local_split_index > global_split_index) { + return InvalidArgument( + "found conflict between local split index and global split index in ", + "name of ", split_path); + } + if (local_split_index > split_filenames.size() - 1) { + return InvalidArgument( + "found conflict between the number of splits and name of ", + split_path); + } + if (global_split_indices.contains(global_split_index)) { + return InvalidArgument("found duplicate global split index in name of ", + split_path); + } + + // To account for this split having been assigned, skip a split in the + // respective provider. 
+ TF_RETURN_IF_ERROR(split_providers_[source_index]->GetNext( + &unused_tensor, &unused_end_of_splits)); + global_split_indices.insert(global_split_index); + } + + streams_[stream_index].num_assigned_splits[source_index] = + split_filenames.size(); + + return OkStatus(); +} + +StatusOr SnapshotManager::CreateNewStream( + const std::string& worker_address) { + int64_t new_stream_index = streams_.size(); + + for (int64_t source_index = 0; source_index < num_sources(); ++source_index) { + TF_RETURN_IF_ERROR(env_->RecursivelyCreateDir( + SourceDirectory(path_, new_stream_index, source_index))); + } + + streams_.push_back(Stream(num_sources())); + assignments_.insert({worker_address, new_stream_index}); + VLOG(1) << "creating stream " << new_stream_index + << " and assigning it to worker " << worker_address; + + return new_stream_index; +} + +Status SnapshotManager::WorkerHeartbeat(const WorkerHeartbeatRequest& request, + WorkerHeartbeatResponse& response) { + SnapshotTaskDef* snapshot_task = response.add_snapshot_tasks(); + snapshot_task->set_base_path(path_); + snapshot_task->set_num_sources(num_sources()); + + if (auto it = assignments_.find(request.worker_address()); + it != assignments_.end()) { + snapshot_task->set_stream_index(it->second); + return OkStatus(); + } + + // TODO(mpcallanan): Handle orphans. + + TF_ASSIGN_OR_RETURN(int64_t new_stream_index, + CreateNewStream(request.worker_address())); + snapshot_task->set_stream_index(new_stream_index); + return OkStatus(); +} + +Status SnapshotManager::GetSnapshotSplit(const GetSnapshotSplitRequest& request, + GetSnapshotSplitResponse& response) { + // TODO(mpcallanan): Validate the request. + + Tensor split; + bool end_of_splits; + TF_RETURN_IF_ERROR(split_providers_[request.source_index()]->GetNext( + &split, &end_of_splits)); + + Stream& stream = streams_[request.stream_index()]; + if (end_of_splits) { + // TODO(mpcallanan): Handle doneness. 
+ response.set_end_of_splits(true); + return OkStatus(); + } + + std::string split_path = SplitPath( + path_, request.stream_index(), request.source_index(), + stream.num_assigned_splits[request.source_index()], num_assigned_splits_); + TF_RETURN_IF_ERROR(AtomicallyWriteTFRecord(split_path, split, env_)); + + ++stream.num_assigned_splits[request.source_index()]; + ++num_assigned_splits_; + + split.AsProtoTensorContent(response.mutable_split()); + + return OkStatus(); +} + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager.h b/tensorflow/core/data/service/snapshot/snapshot_manager.h new file mode 100644 index 00000000000000..56592cc501f943 --- /dev/null +++ b/tensorflow/core/data/service/snapshot/snapshot_manager.h @@ -0,0 +1,129 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_MANAGER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_MANAGER_H_ + +#include +#include +#include + +#include "tensorflow/core/data/service/dispatcher.pb.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace tensorflow { +namespace data { + +// A helper used by `DataServiceDispatcherImpl` to manage a call to `Snapshot`. +// +// Two mirrored states are maintained: +// - An in-memory state (objects in the `SnapshotManager` instance). +// - An on-disk state (files in the `SnapshotManager::path_`). +// +// The on-disk state has this structure: +// - snapshot_path +// - DONE +// - snapshot.metadata +// - dataset_def.proto +// - chunks +// - chunk__ +// - streams +// - stream_0 +// - DONE +// - splits +// - source_0 +// - DONE +// - split__ +// - uncommitted_chunks +// - chunk_ +// - checkpoints +// - checkpoint_ +// +class SnapshotManager { + public: + // Initiates a new snapshot process, creating a fresh in-memory state and + // writing an on-disk state to `path`. Returns an error if `path` already + // exists in the filesystem. + static tsl::StatusOr> Start( + const SnapshotRequest& request, Env* env); + // Resumes an existing snapshot process, reading from the on-disk state in + // `path` to derive an in-memory state. Returns an error if `path` is in a bad + // state. + static tsl::StatusOr> Resume( + absl::string_view path, Env* env); + + // Handles the work pertaining to this snapshot process for the respective + // `DispatcherService` API calls: + // - `WorkerHeartbeat`: Returns a stream assignment for the worker. + // - `GetSnapshotSplit`: Returns a split assignment for the worker.
+ tsl::Status WorkerHeartbeat(const WorkerHeartbeatRequest& request, + WorkerHeartbeatResponse& response); + tsl::Status GetSnapshotSplit(const GetSnapshotSplitRequest& request, + GetSnapshotSplitResponse& response); + + private: + // Private: instances are created through the static `Start` and `Resume` + // factories above. + SnapshotManager(absl::string_view path, Env* env) : path_(path), env_(env) {} + + // See `Start` above. + tsl::Status Start(const SnapshotRequest& request); + tsl::Status WriteOnDiskSkeleton(); + tsl::Status WriteOnDiskMetadata(const SnapshotRequest& request); + + // See `Resume` above. + tsl::Status Resume(); + tsl::Status ReadOnDiskMetadata(); + tsl::Status ReadOnDiskStreams(); + tsl::Status ReadOnDiskStream( + int64_t stream_index, absl::flat_hash_set& global_split_indices); + tsl::Status ReadOnDiskSource( + int64_t stream_index, int64_t source_index, + absl::flat_hash_set& global_split_indices); + + // Returns the id of a newly created stream assigned to the worker. + tsl::StatusOr CreateNewStream(const std::string& worker_address); + + // The filepath of the on-disk state. + std::string path_; + // A tensorflow environment interface used to write to and read from `path_`. + tsl::Env* env_; + + // A split provider for each input source of the dataset being snapshotted. + std::vector> split_providers_; + // Returns the number of input sources, i.e. the number of split providers. + int64_t num_sources() const { return split_providers_.size(); } + + // The in-memory state tracked for a single stream. + struct Stream { + explicit Stream(int64_t num_sources) : num_assigned_splits(num_sources) {} + + // A counter of assigned splits for each source. + std::vector num_assigned_splits; + }; + + // All streams for this snapshot. + std::vector streams_; + // Indices of all "assigned" streams, keyed by worker address. A stream is + // considered to be assigned if the dispatcher knows of a worker + // processing the stream and that worker is heartbeating. + absl::flat_hash_map assignments_; + + // A counter of assigned splits for this snapshot.
+ int64_t num_assigned_splits_ = 0; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_MANAGER_H_ diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc index 1bcbb962882b01..3a7c82716b2a1c 100644 --- a/tensorflow/core/data/service/worker_impl.cc +++ b/tensorflow/core/data/service/worker_impl.cc @@ -609,7 +609,7 @@ Status DataServiceWorkerImpl::UpdateSnapshotWriters( const WorkerHeartbeatResponse& response) { for (const SnapshotTaskDef& snapshot_task : response.snapshot_tasks()) { DatasetDef dataset_def; - TF_RETURN_IF_ERROR(ReadTextProto( + TF_RETURN_IF_ERROR(ReadBinaryProto( Env::Default(), DatasetDefFilePath(snapshot_task.base_path()), &dataset_def)); TF_ASSIGN_OR_RETURN(std::unique_ptr iterator, diff --git a/tensorflow/python/data/experimental/kernel_tests/distributed_save_test.py b/tensorflow/python/data/experimental/kernel_tests/distributed_save_test.py index f4d32d2fc1619c..f96fc4cd79c14f 100644 --- a/tensorflow/python/data/experimental/kernel_tests/distributed_save_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/distributed_save_test.py @@ -16,6 +16,7 @@ import os import shutil +import tempfile from absl.testing import parameterized @@ -31,14 +32,17 @@ class DistributedSaveTest(test_base.DatasetTestBase, parameterized.TestCase): def setUp(self): super().setUp() - tmpdir = self.get_temp_dir() - tmpdir = os.path.join(tmpdir, "distributed_save_test") - os.mkdir(tmpdir) - self._test_dir = tmpdir + self._test_dir = os.path.join( + tempfile.mkdtemp(dir=self.get_temp_dir()), + "distributed_save_test", + ) def tearDown(self): super().tearDown() - shutil.rmtree(self._test_dir) + try: + shutil.rmtree(self._test_dir) + except FileNotFoundError: + pass class DistributedSaveTfDataServiceTest(data_service_test_base.TestBase, diff --git a/tensorflow/python/data/experimental/kernel_tests/service/fault_tolerance_test.py
b/tensorflow/python/data/experimental/kernel_tests/service/fault_tolerance_test.py index 1e0c27a7b75948..5801981abeac1b 100644 --- a/tensorflow/python/data/experimental/kernel_tests/service/fault_tolerance_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/service/fault_tolerance_test.py @@ -14,6 +14,8 @@ # ============================================================================== """Tests for tf.data service ops where servers are started late or preempted.""" import multiprocessing +import os +import tempfile import threading import time @@ -21,6 +23,7 @@ from tensorflow.python.data.experimental.kernel_tests.service import test_base as data_service_test_base from tensorflow.python.data.experimental.ops import data_service_ops +from tensorflow.python.data.experimental.ops import distributed_save_op from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations @@ -32,6 +35,20 @@ NO_WORK_DIR = data_service_test_base.NO_WORK_DIR +def write_file(path): + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "w") as _: + pass + + +def splits_dir(path, stream_idx=0): + return os.path.join(path, "streams", f"stream_{stream_idx}", "splits") + + +def source_dir(path, stream_idx=0): + return os.path.join(splits_dir(path, stream_idx), "source_0") + + class FaultToleranceTest(data_service_test_base.TestBase, parameterized.TestCase): @@ -311,6 +328,111 @@ def testDistributeLargeGraphThenRegisterWorker(self, work_dir): cluster.add_worker() self.assertAllEqual(next(it), tensor) + def snapshot(self): + ds = dataset_ops.Dataset.range(10) + path = os.path.join(tempfile.mkdtemp(dir=self.get_temp_dir()), "snapshot") + cluster = data_service_test_base.TestCluster(num_workers=1) + distributed_save_op.distributed_save(ds, path, cluster.dispatcher_address()) + return cluster, path, ds + + @combinations.generate(test_base.eager_only_combinations()) + def 
testSnapshotRecoverySucceeds(self): + cluster, _, _ = self.snapshot() + cluster.restart_dispatcher() + + @combinations.generate(test_base.eager_only_combinations()) + def testSnapshotRecoveryBlocksOverwrite(self): + cluster, path, ds = self.snapshot() + cluster.restart_dispatcher() + with self.assertRaisesOpError("is already started or completed"): + distributed_save_op.distributed_save( + ds, path, cluster.dispatcher_address() + ) + + @combinations.generate( + combinations.times( + test_base.eager_only_combinations(), + combinations.combine( + bad_stream_dir_name=["stream_", "stream_x", "stream_-1"] + ), + ) + ) + def testSnapshotRecoveryFailsWithBadStreamName(self, bad_stream_dir_name): + cluster, path, _ = self.snapshot() + os.makedirs(os.path.join(path, "streams", bad_stream_dir_name)) + with self.assertRaisesRegex(ValueError, "can't parse"): + cluster.restart_dispatcher() + + @combinations.generate( + combinations.times( + test_base.eager_only_combinations(), + combinations.combine( + bad_source_dir_name=["source_", "source_x", "source_-1"] + ), + ) + ) + def testSnapshotRecoveryFailsWithBadSourceName(self, bad_source_dir_name): + cluster, path, _ = self.snapshot() + os.makedirs(os.path.join(splits_dir(path), bad_source_dir_name)) + with self.assertRaisesRegex(ValueError, "can't parse"): + cluster.restart_dispatcher() + + @combinations.generate(test_base.eager_only_combinations()) + def testSnapshotRecoveryFailsWithOutOfBoundsSourceName(self): + cluster, path, _ = self.snapshot() + os.makedirs(os.path.join(splits_dir(path), "source_1")) + with self.assertRaisesRegex(ValueError, "found conflict"): + cluster.restart_dispatcher() + + @combinations.generate( + combinations.times( + test_base.eager_only_combinations(), + combinations.combine( + bad_split_filename=[ + "split_", + "split_x_0", + "split_-1_0", + "split_0_x", + "split_0_-1", + ] + ), + ) + ) + def testSnapshotRecoveryFailsWithBadSplitNames(self, bad_split_filename): + cluster, path, _ = self.snapshot() 
+ write_file(os.path.join(source_dir(path), bad_split_filename)) + with self.assertRaisesRegex(ValueError, "can't parse"): + cluster.restart_dispatcher() + + @combinations.generate(test_base.eager_only_combinations()) + def testSnapshotRecoveryFailsWithOutOfOrderSplitName(self): + cluster, path, _ = self.snapshot() + write_file(os.path.join(source_dir(path), "split_1_0")) + with self.assertRaisesRegex(ValueError, "found conflict"): + cluster.restart_dispatcher() + + @combinations.generate(test_base.eager_only_combinations()) + def testSnapshotRecoveryFailsWithOutOfBoundsSplitName(self): + cluster, path, _ = self.snapshot() + write_file(os.path.join(source_dir(path), "split_1_1")) + with self.assertRaisesRegex(ValueError, "found conflict"): + cluster.restart_dispatcher() + + @combinations.generate(test_base.eager_only_combinations()) + def testSnapshotRecoveryFailsWithMissingGlobalIndexInSplitNames(self): + cluster, path, _ = self.snapshot() + write_file(os.path.join(source_dir(path), "split_0_1")) + with self.assertRaisesRegex(ValueError, "found missing global"): + cluster.restart_dispatcher() + + @combinations.generate(test_base.eager_only_combinations()) + def testSnapshotRecoveryFailsWithDuplicateGlobalIndexInSplitName(self): + cluster, path, _ = self.snapshot() + write_file(os.path.join(source_dir(path), "split_0_1")) + write_file(os.path.join(source_dir(path, stream_idx=1), "split_0_1")) + with self.assertRaisesRegex(ValueError, "found duplicate global"): + cluster.restart_dispatcher() + if __name__ == "__main__": test.main()