Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 166 additions & 7 deletions src/mock/ray/core_worker/reference_counter.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,43 @@

#pragma once
#include "gmock/gmock.h"
#include "ray/core_worker/reference_counter.h"
#include "ray/core_worker/reference_counter_interface.h"
namespace ray {
namespace core {

class MockReferenceCounter : public ReferenceCounterInterface {
public:
MockReferenceCounter() : ReferenceCounterInterface() {}

MOCK_METHOD1(DrainAndShutdown, void(std::function<void()> shutdown));

MOCK_CONST_METHOD0(Size, size_t());

MOCK_CONST_METHOD1(OwnedByUs, bool(const ObjectID &object_id));

MOCK_METHOD2(AddLocalReference,
void(const ObjectID &object_id, const std::string &call_sit));
void(const ObjectID &object_id, const std::string &call_site));

MOCK_METHOD4(AddBorrowedObject,
bool(const ObjectID &object_id,
const ObjectID &outer_id,
const rpc::Address &owner_address,
bool foreign_owner_already_monitoring));
MOCK_METHOD2(RemoveLocalReference,
void(const ObjectID &object_id, std::vector<ObjectID> *deleted));

MOCK_METHOD4(UpdateSubmittedTaskReferences,
void(const std::vector<ObjectID> &return_ids,
const std::vector<ObjectID> &argument_ids_to_add,
const std::vector<ObjectID> &argument_ids_to_remove,
std::vector<ObjectID> *deleted));

MOCK_METHOD1(UpdateResubmittedTaskReferences,
void(const std::vector<ObjectID> &argument_ids));

MOCK_METHOD6(UpdateFinishedTaskReferences,
void(const std::vector<ObjectID> &return_ids,
const std::vector<ObjectID> &argument_ids,
bool release_lineage,
const rpc::Address &worker_addr,
const ::google::protobuf::RepeatedPtrField<rpc::ObjectReferenceCount>
&borrowed_refs,
std::vector<ObjectID> *deleted));

MOCK_METHOD9(AddOwnedObject,
void(const ObjectID &object_id,
Expand All @@ -42,6 +63,44 @@ class MockReferenceCounter : public ReferenceCounterInterface {
const std::optional<NodeID> &pinned_at_node_id,
rpc::TensorTransport tensor_transport));

MOCK_METHOD2(AddDynamicReturn,
void(const ObjectID &object_id, const ObjectID &generator_id));

MOCK_METHOD2(OwnDynamicStreamingTaskReturnRef,
void(const ObjectID &object_id, const ObjectID &generator_id));

MOCK_METHOD2(TryReleaseLocalRefs,
void(const std::vector<ObjectID> &object_ids,
std::vector<ObjectID> *deleted));

MOCK_METHOD2(CheckGeneratorRefsLineageOutOfScope,
bool(const ObjectID &generator_id, int64_t num_objects_generated));

MOCK_METHOD2(UpdateObjectSize, void(const ObjectID &object_id, int64_t object_size));

MOCK_METHOD4(AddBorrowedObject,
bool(const ObjectID &object_id,
const ObjectID &outer_id,
const rpc::Address &owner_address,
bool foreign_owner_already_monitoring));

MOCK_CONST_METHOD2(GetOwner,
bool(const ObjectID &object_id, rpc::Address *owner_address));

MOCK_CONST_METHOD1(HasOwner, bool(const ObjectID &object_id));

MOCK_CONST_METHOD1(
HasOwner, StatusSet<StatusT::NotFound>(const std::vector<ObjectID> &object_ids));

MOCK_CONST_METHOD1(GetOwnerAddresses,
std::vector<rpc::Address>(const std::vector<ObjectID> &object_ids));

MOCK_CONST_METHOD1(IsPlasmaObjectFreed, bool(const ObjectID &object_id));

MOCK_METHOD1(TryMarkFreedObjectInUseAgain, bool(const ObjectID &object_id));

MOCK_METHOD1(FreePlasmaObjects, void(const std::vector<ObjectID> &object_ids));

MOCK_METHOD2(AddObjectOutOfScopeOrFreedCallback,
bool(const ObjectID &object_id,
const std::function<void(const ObjectID &)> callback));
Expand All @@ -50,6 +109,106 @@ class MockReferenceCounter : public ReferenceCounterInterface {
bool(const ObjectID &object_id,
const std::function<void(const ObjectID &)> callback));

MOCK_METHOD3(SubscribeRefRemoved,
void(const ObjectID &object_id,
const ObjectID &contained_in_id,
const rpc::Address &owner_address));

MOCK_METHOD1(SetReleaseLineageCallback, void(const LineageReleasedCallback &callback));

MOCK_METHOD1(PublishRefRemoved, void(const ObjectID &object_id));

MOCK_CONST_METHOD0(NumObjectIDsInScope, size_t());

MOCK_CONST_METHOD0(NumObjectsOwnedByUs, size_t());

MOCK_CONST_METHOD0(NumActorsOwnedByUs, size_t());

MOCK_CONST_METHOD0(GetAllInScopeObjectIDs, std::unordered_set<ObjectID>());

MOCK_CONST_METHOD0(GetAllReferenceCounts,
std::unordered_map<ObjectID, std::pair<size_t, size_t>>());

MOCK_CONST_METHOD0(DebugString, std::string());

MOCK_METHOD3(
PopAndClearLocalBorrowers,
void(const std::vector<ObjectID> &borrowed_ids,
::google::protobuf::RepeatedPtrField<rpc::ObjectReferenceCount> *proto,
std::vector<ObjectID> *deleted));

MOCK_METHOD3(AddNestedObjectIds,
void(const ObjectID &object_id,
const std::vector<ObjectID> &inner_ids,
const rpc::Address &owner_address));

MOCK_METHOD2(UpdateObjectPinnedAtRaylet,
void(const ObjectID &object_id, const NodeID &node_id));

MOCK_CONST_METHOD4(IsPlasmaObjectPinnedOrSpilled,
bool(const ObjectID &object_id,
bool *owned_by_us,
NodeID *pinned_at,
bool *spilled));

MOCK_METHOD1(ResetObjectsOnRemovedNode, void(const NodeID &node_id));

MOCK_METHOD0(FlushObjectsToRecover, std::vector<ObjectID>());

MOCK_CONST_METHOD1(HasReference, bool(const ObjectID &object_id));

MOCK_CONST_METHOD3(
AddObjectRefStats,
void(const absl::flat_hash_map<ObjectID, std::pair<int64_t, std::string>>
&pinned_objects,
rpc::CoreWorkerStats *stats,
const int64_t limit));

MOCK_METHOD2(AddObjectLocation, bool(const ObjectID &object_id, const NodeID &node_id));

MOCK_METHOD2(RemoveObjectLocation,
bool(const ObjectID &object_id, const NodeID &node_id));

MOCK_METHOD1(GetObjectLocations,
std::optional<absl::flat_hash_set<NodeID>>(const ObjectID &object_id));

MOCK_METHOD1(PublishObjectLocationSnapshot, void(const ObjectID &object_id));

MOCK_METHOD2(FillObjectInformation,
void(const ObjectID &object_id,
rpc::WorkerObjectLocationsPubMessage *object_info));

MOCK_METHOD3(HandleObjectSpilled,
bool(const ObjectID &object_id,
const std::string &spilled_url,
const NodeID &spilled_node_id));

MOCK_CONST_METHOD1(GetLocalityData,
std::optional<LocalityData>(const ObjectID &object_id));

MOCK_METHOD3(ReportLocalityData,
bool(const ObjectID &object_id,
const absl::flat_hash_set<NodeID> &locations,
uint64_t object_size));

MOCK_METHOD2(AddBorrowerAddress,
void(const ObjectID &object_id, const rpc::Address &borrower_address));

MOCK_CONST_METHOD2(IsObjectReconstructable,
bool(const ObjectID &object_id, bool *lineage_evicted));

MOCK_METHOD1(EvictLineage, int64_t(int64_t min_bytes_to_evict));

MOCK_METHOD2(UpdateObjectPendingCreation,
void(const ObjectID &object_id, bool pending_creation));

MOCK_CONST_METHOD1(IsObjectPendingCreation, bool(const ObjectID &object_id));

MOCK_METHOD0(ReleaseAllLocalReferences, void());

MOCK_CONST_METHOD1(GetTensorTransport,
std::optional<rpc::TensorTransport>(const ObjectID &object_id));

virtual ~MockReferenceCounter() {}
};

Expand Down
23 changes: 19 additions & 4 deletions src/ray/core_worker/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ ray_cc_library(
":actor_handle",
":common",
":core_worker_context",
":reference_counter",
":reference_counter_interface",
"//src/ray/common:id",
"//src/ray/common:protobuf_utils",
"//src/ray/common:task_common",
Expand All @@ -193,11 +193,26 @@ ray_cc_library(
],
)

ray_cc_library(
name = "reference_counter_interface",
hdrs = ["reference_counter_interface.h"],
deps = [
"//src/ray/common:id",
"//src/ray/core_worker:lease_policy",
"//src/ray/pubsub:publisher_interface",
"//src/ray/pubsub:subscriber_interface",
"//src/ray/rpc:utils",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
],
)

ray_cc_library(
name = "reference_counter",
srcs = ["reference_counter.cc"],
hdrs = ["reference_counter.h"],
deps = [
":reference_counter_interface",
"//src/ray/common:id",
"//src/ray/core_worker:lease_policy",
"//src/ray/protobuf:common_cc_proto",
Expand Down Expand Up @@ -258,7 +273,7 @@ ray_cc_library(
hdrs = ["store_provider/memory_store/memory_store.h"],
deps = [
":core_worker_context",
":reference_counter",
":reference_counter_interface",
"//src/ray/common:asio",
"//src/ray/common:id",
"//src/ray/common:ray_config",
Expand Down Expand Up @@ -367,7 +382,7 @@ ray_cc_library(
hdrs = ["object_recovery_manager.h"],
deps = [
":memory_store",
":reference_counter",
":reference_counter_interface",
":task_manager",
"//src/ray/common:id",
"//src/ray/raylet_rpc_client:raylet_client_pool",
Expand All @@ -394,7 +409,7 @@ ray_cc_library(
deps = [
":common",
":core_worker_context",
":reference_counter",
":reference_counter_interface",
"//src/ray/common:buffer",
"//src/ray/common:id",
"//src/ray/common:ray_config",
Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/actor_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include "absl/container/flat_hash_map.h"
#include "ray/core_worker/actor_creator.h"
#include "ray/core_worker/actor_handle.h"
#include "ray/core_worker/reference_counter.h"
#include "ray/core_worker/reference_counter_interface.h"
#include "ray/core_worker/task_submission/actor_task_submitter.h"
#include "ray/gcs_rpc_client/gcs_client.h"
namespace ray {
Expand Down
12 changes: 6 additions & 6 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ ObjectLocation CreateObjectLocation(
}

std::optional<ObjectLocation> TryGetLocalObjectLocation(
ReferenceCounter &reference_counter, const ObjectID &object_id) {
ReferenceCounterInterface &reference_counter, const ObjectID &object_id) {
if (!reference_counter.HasReference(object_id)) {
return std::nullopt;
}
Expand Down Expand Up @@ -287,7 +287,7 @@ CoreWorker::CoreWorker(
std::shared_ptr<ipc::RayletIpcClientInterface> raylet_ipc_client,
std::shared_ptr<RayletClientInterface> local_raylet_rpc_client,
boost::thread &io_thread,
std::shared_ptr<ReferenceCounter> reference_counter,
std::shared_ptr<ReferenceCounterInterface> reference_counter,
std::shared_ptr<CoreWorkerMemoryStore> memory_store,
std::shared_ptr<CoreWorkerPlasmaStoreProvider> plasma_store_provider,
std::shared_ptr<experimental::MutableObjectProviderInterface>
Expand Down Expand Up @@ -1977,7 +1977,7 @@ std::vector<rpc::ObjectReference> CoreWorker::SubmitTask(
/*include_job_config=*/true,
/*generator_backpressure_num_objects=*/
task_options.generator_backpressure_num_objects,
/*enable_task_event=*/task_options.enable_task_events,
/*enable_task_events=*/task_options.enable_task_events,
task_options.labels,
task_options.label_selector);
ActorID root_detached_actor_id;
Expand Down Expand Up @@ -2707,7 +2707,7 @@ Status CoreWorker::ExecuteTask(
std::vector<std::pair<ObjectID, std::shared_ptr<RayObject>>> *return_objects,
std::vector<std::pair<ObjectID, std::shared_ptr<RayObject>>> *dynamic_return_objects,
std::vector<std::pair<ObjectID, bool>> *streaming_generator_returns,
ReferenceCounter::ReferenceTableProto *borrowed_refs,
ReferenceCounterInterface::ReferenceTableProto *borrowed_refs,
bool *is_retryable_error,
std::string *application_error) {
RAY_LOG(DEBUG) << "Executing task, task info = " << task_spec.DebugString();
Expand Down Expand Up @@ -3082,7 +3082,7 @@ Status CoreWorker::ReportGeneratorItemReturns(
// we borrow the object. When the object value is allocatd, the
// memory store is updated. We should clear borrowers and memory store
// here.
ReferenceCounter::ReferenceTableProto borrowed_refs;
ReferenceCounterInterface::ReferenceTableProto borrowed_refs;
reference_counter_->PopAndClearLocalBorrowers(
{dynamic_return_object.first}, &borrowed_refs, &deleted);
memory_store_->Delete(deleted);
Expand Down Expand Up @@ -3152,7 +3152,7 @@ void CoreWorker::HandleReportGeneratorItemReturns(
std::vector<rpc::ObjectReference> CoreWorker::ExecuteTaskLocalMode(
const TaskSpecification &task_spec, const ActorID &actor_id) {
auto return_objects = std::vector<std::pair<ObjectID, std::shared_ptr<RayObject>>>();
auto borrowed_refs = ReferenceCounter::ReferenceTableProto();
auto borrowed_refs = ReferenceCounterInterface::ReferenceTableProto();

std::vector<rpc::ObjectReference> returned_refs;
size_t num_returns = task_spec.NumReturns();
Expand Down
7 changes: 4 additions & 3 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include "ray/core_worker/object_recovery_manager.h"
#include "ray/core_worker/profile_event.h"
#include "ray/core_worker/reference_counter.h"
#include "ray/core_worker/reference_counter_interface.h"
#include "ray/core_worker/shutdown_coordinator.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/core_worker/store_provider/plasma_store_provider.h"
Expand Down Expand Up @@ -183,7 +184,7 @@ class CoreWorker {
std::shared_ptr<ipc::RayletIpcClientInterface> raylet_ipc_client,
std::shared_ptr<ray::RayletClientInterface> local_raylet_rpc_client,
boost::thread &io_thread,
std::shared_ptr<ReferenceCounter> reference_counter,
std::shared_ptr<ReferenceCounterInterface> reference_counter,
std::shared_ptr<CoreWorkerMemoryStore> memory_store,
std::shared_ptr<CoreWorkerPlasmaStoreProvider> plasma_store_provider,
std::shared_ptr<experimental::MutableObjectProviderInterface>
Expand Down Expand Up @@ -1483,7 +1484,7 @@ class CoreWorker {
std::vector<std::pair<ObjectID, std::shared_ptr<RayObject>>>
*dynamic_return_objects,
std::vector<std::pair<ObjectID, bool>> *streaming_generator_returns,
ReferenceCounter::ReferenceTableProto *borrowed_refs,
ReferenceCounterInterface::ReferenceTableProto *borrowed_refs,
bool *is_retryable_error,
std::string *application_error);

Expand Down Expand Up @@ -1768,7 +1769,7 @@ class CoreWorker {
boost::thread &io_thread_;

// Keeps track of object ID reference counts.
std::shared_ptr<ReferenceCounter> reference_counter_;
std::shared_ptr<ReferenceCounterInterface> reference_counter_;

///
/// Fields related to storing and retrieving objects.
Expand Down
4 changes: 2 additions & 2 deletions src/ray/core_worker/future_resolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ using ReportLocalityDataCallback =
class FutureResolver {
public:
FutureResolver(std::shared_ptr<CoreWorkerMemoryStore> store,
std::shared_ptr<ReferenceCounter> ref_counter,
std::shared_ptr<ReferenceCounterInterface> ref_counter,
ReportLocalityDataCallback report_locality_data_callback,
std::shared_ptr<rpc::CoreWorkerClientPool> core_worker_client_pool,
rpc::Address rpc_address)
Expand Down Expand Up @@ -69,7 +69,7 @@ class FutureResolver {
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;

/// Used to record nested ObjectRefs of resolved futures.
std::shared_ptr<ReferenceCounter> reference_counter_;
std::shared_ptr<ReferenceCounterInterface> reference_counter_;

/// Used to report locality data received during future resolution.
ReportLocalityDataCallback report_locality_data_callback_;
Expand Down
Loading