Skip to content

Commit fa874f3

Browse files
Kunchd and davik authored
[Core] Abstract reference counter behind interface for more defined API (#57177)
Signed-off-by: davik <davik@anyscale.com>
Co-authored-by: davik <davik@anyscale.com>
1 parent a6555d5 commit fa874f3

21 files changed

+846
-463
lines changed

src/mock/ray/core_worker/reference_counter.h

Lines changed: 166 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -14,22 +14,43 @@
1414

1515
#pragma once
1616
#include "gmock/gmock.h"
17-
#include "ray/core_worker/reference_counter.h"
17+
#include "ray/core_worker/reference_counter_interface.h"
1818
namespace ray {
1919
namespace core {
2020

2121
class MockReferenceCounter : public ReferenceCounterInterface {
2222
public:
2323
MockReferenceCounter() : ReferenceCounterInterface() {}
2424

25+
MOCK_METHOD1(DrainAndShutdown, void(std::function<void()> shutdown));
26+
27+
MOCK_CONST_METHOD0(Size, size_t());
28+
29+
MOCK_CONST_METHOD1(OwnedByUs, bool(const ObjectID &object_id));
30+
2531
MOCK_METHOD2(AddLocalReference,
26-
void(const ObjectID &object_id, const std::string &call_sit));
32+
void(const ObjectID &object_id, const std::string &call_site));
2733

28-
MOCK_METHOD4(AddBorrowedObject,
29-
bool(const ObjectID &object_id,
30-
const ObjectID &outer_id,
31-
const rpc::Address &owner_address,
32-
bool foreign_owner_already_monitoring));
34+
MOCK_METHOD2(RemoveLocalReference,
35+
void(const ObjectID &object_id, std::vector<ObjectID> *deleted));
36+
37+
MOCK_METHOD4(UpdateSubmittedTaskReferences,
38+
void(const std::vector<ObjectID> &return_ids,
39+
const std::vector<ObjectID> &argument_ids_to_add,
40+
const std::vector<ObjectID> &argument_ids_to_remove,
41+
std::vector<ObjectID> *deleted));
42+
43+
MOCK_METHOD1(UpdateResubmittedTaskReferences,
44+
void(const std::vector<ObjectID> &argument_ids));
45+
46+
MOCK_METHOD6(UpdateFinishedTaskReferences,
47+
void(const std::vector<ObjectID> &return_ids,
48+
const std::vector<ObjectID> &argument_ids,
49+
bool release_lineage,
50+
const rpc::Address &worker_addr,
51+
const ::google::protobuf::RepeatedPtrField<rpc::ObjectReferenceCount>
52+
&borrowed_refs,
53+
std::vector<ObjectID> *deleted));
3354

3455
MOCK_METHOD9(AddOwnedObject,
3556
void(const ObjectID &object_id,
@@ -42,6 +63,44 @@ class MockReferenceCounter : public ReferenceCounterInterface {
4263
const std::optional<NodeID> &pinned_at_node_id,
4364
rpc::TensorTransport tensor_transport));
4465

66+
MOCK_METHOD2(AddDynamicReturn,
67+
void(const ObjectID &object_id, const ObjectID &generator_id));
68+
69+
MOCK_METHOD2(OwnDynamicStreamingTaskReturnRef,
70+
void(const ObjectID &object_id, const ObjectID &generator_id));
71+
72+
MOCK_METHOD2(TryReleaseLocalRefs,
73+
void(const std::vector<ObjectID> &object_ids,
74+
std::vector<ObjectID> *deleted));
75+
76+
MOCK_METHOD2(CheckGeneratorRefsLineageOutOfScope,
77+
bool(const ObjectID &generator_id, int64_t num_objects_generated));
78+
79+
MOCK_METHOD2(UpdateObjectSize, void(const ObjectID &object_id, int64_t object_size));
80+
81+
MOCK_METHOD4(AddBorrowedObject,
82+
bool(const ObjectID &object_id,
83+
const ObjectID &outer_id,
84+
const rpc::Address &owner_address,
85+
bool foreign_owner_already_monitoring));
86+
87+
MOCK_CONST_METHOD2(GetOwner,
88+
bool(const ObjectID &object_id, rpc::Address *owner_address));
89+
90+
MOCK_CONST_METHOD1(HasOwner, bool(const ObjectID &object_id));
91+
92+
MOCK_CONST_METHOD1(
93+
HasOwner, StatusSet<StatusT::NotFound>(const std::vector<ObjectID> &object_ids));
94+
95+
MOCK_CONST_METHOD1(GetOwnerAddresses,
96+
std::vector<rpc::Address>(const std::vector<ObjectID> &object_ids));
97+
98+
MOCK_CONST_METHOD1(IsPlasmaObjectFreed, bool(const ObjectID &object_id));
99+
100+
MOCK_METHOD1(TryMarkFreedObjectInUseAgain, bool(const ObjectID &object_id));
101+
102+
MOCK_METHOD1(FreePlasmaObjects, void(const std::vector<ObjectID> &object_ids));
103+
45104
MOCK_METHOD2(AddObjectOutOfScopeOrFreedCallback,
46105
bool(const ObjectID &object_id,
47106
const std::function<void(const ObjectID &)> callback));
@@ -50,6 +109,106 @@ class MockReferenceCounter : public ReferenceCounterInterface {
50109
bool(const ObjectID &object_id,
51110
const std::function<void(const ObjectID &)> callback));
52111

112+
MOCK_METHOD3(SubscribeRefRemoved,
113+
void(const ObjectID &object_id,
114+
const ObjectID &contained_in_id,
115+
const rpc::Address &owner_address));
116+
117+
MOCK_METHOD1(SetReleaseLineageCallback, void(const LineageReleasedCallback &callback));
118+
119+
MOCK_METHOD1(PublishRefRemoved, void(const ObjectID &object_id));
120+
121+
MOCK_CONST_METHOD0(NumObjectIDsInScope, size_t());
122+
123+
MOCK_CONST_METHOD0(NumObjectsOwnedByUs, size_t());
124+
125+
MOCK_CONST_METHOD0(NumActorsOwnedByUs, size_t());
126+
127+
MOCK_CONST_METHOD0(GetAllInScopeObjectIDs, std::unordered_set<ObjectID>());
128+
129+
MOCK_CONST_METHOD0(GetAllReferenceCounts,
130+
std::unordered_map<ObjectID, std::pair<size_t, size_t>>());
131+
132+
MOCK_CONST_METHOD0(DebugString, std::string());
133+
134+
MOCK_METHOD3(
135+
PopAndClearLocalBorrowers,
136+
void(const std::vector<ObjectID> &borrowed_ids,
137+
::google::protobuf::RepeatedPtrField<rpc::ObjectReferenceCount> *proto,
138+
std::vector<ObjectID> *deleted));
139+
140+
MOCK_METHOD3(AddNestedObjectIds,
141+
void(const ObjectID &object_id,
142+
const std::vector<ObjectID> &inner_ids,
143+
const rpc::Address &owner_address));
144+
145+
MOCK_METHOD2(UpdateObjectPinnedAtRaylet,
146+
void(const ObjectID &object_id, const NodeID &node_id));
147+
148+
MOCK_CONST_METHOD4(IsPlasmaObjectPinnedOrSpilled,
149+
bool(const ObjectID &object_id,
150+
bool *owned_by_us,
151+
NodeID *pinned_at,
152+
bool *spilled));
153+
154+
MOCK_METHOD1(ResetObjectsOnRemovedNode, void(const NodeID &node_id));
155+
156+
MOCK_METHOD0(FlushObjectsToRecover, std::vector<ObjectID>());
157+
158+
MOCK_CONST_METHOD1(HasReference, bool(const ObjectID &object_id));
159+
160+
MOCK_CONST_METHOD3(
161+
AddObjectRefStats,
162+
void(const absl::flat_hash_map<ObjectID, std::pair<int64_t, std::string>>
163+
&pinned_objects,
164+
rpc::CoreWorkerStats *stats,
165+
const int64_t limit));
166+
167+
MOCK_METHOD2(AddObjectLocation, bool(const ObjectID &object_id, const NodeID &node_id));
168+
169+
MOCK_METHOD2(RemoveObjectLocation,
170+
bool(const ObjectID &object_id, const NodeID &node_id));
171+
172+
MOCK_METHOD1(GetObjectLocations,
173+
std::optional<absl::flat_hash_set<NodeID>>(const ObjectID &object_id));
174+
175+
MOCK_METHOD1(PublishObjectLocationSnapshot, void(const ObjectID &object_id));
176+
177+
MOCK_METHOD2(FillObjectInformation,
178+
void(const ObjectID &object_id,
179+
rpc::WorkerObjectLocationsPubMessage *object_info));
180+
181+
MOCK_METHOD3(HandleObjectSpilled,
182+
bool(const ObjectID &object_id,
183+
const std::string &spilled_url,
184+
const NodeID &spilled_node_id));
185+
186+
MOCK_CONST_METHOD1(GetLocalityData,
187+
std::optional<LocalityData>(const ObjectID &object_id));
188+
189+
MOCK_METHOD3(ReportLocalityData,
190+
bool(const ObjectID &object_id,
191+
const absl::flat_hash_set<NodeID> &locations,
192+
uint64_t object_size));
193+
194+
MOCK_METHOD2(AddBorrowerAddress,
195+
void(const ObjectID &object_id, const rpc::Address &borrower_address));
196+
197+
MOCK_CONST_METHOD2(IsObjectReconstructable,
198+
bool(const ObjectID &object_id, bool *lineage_evicted));
199+
200+
MOCK_METHOD1(EvictLineage, int64_t(int64_t min_bytes_to_evict));
201+
202+
MOCK_METHOD2(UpdateObjectPendingCreation,
203+
void(const ObjectID &object_id, bool pending_creation));
204+
205+
MOCK_CONST_METHOD1(IsObjectPendingCreation, bool(const ObjectID &object_id));
206+
207+
MOCK_METHOD0(ReleaseAllLocalReferences, void());
208+
209+
MOCK_CONST_METHOD1(GetTensorTransport,
210+
std::optional<rpc::TensorTransport>(const ObjectID &object_id));
211+
53212
virtual ~MockReferenceCounter() {}
54213
};
55214

src/ray/core_worker/BUILD.bazel

Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,7 @@ ray_cc_library(
180180
":actor_handle",
181181
":common",
182182
":core_worker_context",
183-
":reference_counter",
183+
":reference_counter_interface",
184184
"//src/ray/common:id",
185185
"//src/ray/common:protobuf_utils",
186186
"//src/ray/common:task_common",
@@ -192,11 +192,26 @@ ray_cc_library(
192192
],
193193
)
194194

195+
ray_cc_library(
196+
name = "reference_counter_interface",
197+
hdrs = ["reference_counter_interface.h"],
198+
deps = [
199+
"//src/ray/common:id",
200+
"//src/ray/core_worker:lease_policy",
201+
"//src/ray/pubsub:publisher_interface",
202+
"//src/ray/pubsub:subscriber_interface",
203+
"//src/ray/rpc:utils",
204+
"@com_google_absl//absl/base:core_headers",
205+
"@com_google_absl//absl/synchronization",
206+
],
207+
)
208+
195209
ray_cc_library(
196210
name = "reference_counter",
197211
srcs = ["reference_counter.cc"],
198212
hdrs = ["reference_counter.h"],
199213
deps = [
214+
":reference_counter_interface",
200215
"//src/ray/common:id",
201216
"//src/ray/core_worker:lease_policy",
202217
"//src/ray/protobuf:common_cc_proto",
@@ -257,7 +272,7 @@ ray_cc_library(
257272
hdrs = ["store_provider/memory_store/memory_store.h"],
258273
deps = [
259274
":core_worker_context",
260-
":reference_counter",
275+
":reference_counter_interface",
261276
"//src/ray/common:asio",
262277
"//src/ray/common:id",
263278
"//src/ray/common:ray_config",
@@ -366,7 +381,7 @@ ray_cc_library(
366381
hdrs = ["object_recovery_manager.h"],
367382
deps = [
368383
":memory_store",
369-
":reference_counter",
384+
":reference_counter_interface",
370385
":task_manager",
371386
"//src/ray/common:id",
372387
"//src/ray/raylet_rpc_client:raylet_client_pool",
@@ -393,7 +408,7 @@ ray_cc_library(
393408
deps = [
394409
":common",
395410
":core_worker_context",
396-
":reference_counter",
411+
":reference_counter_interface",
397412
"//src/ray/common:buffer",
398413
"//src/ray/common:id",
399414
"//src/ray/common:ray_config",

src/ray/core_worker/actor_manager.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
#include "absl/container/flat_hash_map.h"
2525
#include "ray/core_worker/actor_creator.h"
2626
#include "ray/core_worker/actor_handle.h"
27-
#include "ray/core_worker/reference_counter.h"
27+
#include "ray/core_worker/reference_counter_interface.h"
2828
#include "ray/core_worker/task_submission/actor_task_submitter.h"
2929
#include "ray/gcs_rpc_client/gcs_client.h"
3030
namespace ray {

src/ray/core_worker/core_worker.cc

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -111,7 +111,7 @@ ObjectLocation CreateObjectLocation(
111111
}
112112

113113
std::optional<ObjectLocation> TryGetLocalObjectLocation(
114-
ReferenceCounter &reference_counter, const ObjectID &object_id) {
114+
ReferenceCounterInterface &reference_counter, const ObjectID &object_id) {
115115
if (!reference_counter.HasReference(object_id)) {
116116
return std::nullopt;
117117
}
@@ -285,7 +285,7 @@ CoreWorker::CoreWorker(
285285
std::shared_ptr<ipc::RayletIpcClientInterface> raylet_ipc_client,
286286
std::shared_ptr<RayletClientInterface> local_raylet_rpc_client,
287287
boost::thread &io_thread,
288-
std::shared_ptr<ReferenceCounter> reference_counter,
288+
std::shared_ptr<ReferenceCounterInterface> reference_counter,
289289
std::shared_ptr<CoreWorkerMemoryStore> memory_store,
290290
std::shared_ptr<CoreWorkerPlasmaStoreProvider> plasma_store_provider,
291291
std::shared_ptr<experimental::MutableObjectProviderInterface>
@@ -1974,7 +1974,7 @@ std::vector<rpc::ObjectReference> CoreWorker::SubmitTask(
19741974
/*include_job_config=*/true,
19751975
/*generator_backpressure_num_objects=*/
19761976
task_options.generator_backpressure_num_objects,
1977-
/*enable_task_event=*/task_options.enable_task_events,
1977+
/*enable_task_events=*/task_options.enable_task_events,
19781978
task_options.labels,
19791979
task_options.label_selector);
19801980
ActorID root_detached_actor_id;
@@ -2704,7 +2704,7 @@ Status CoreWorker::ExecuteTask(
27042704
std::vector<std::pair<ObjectID, std::shared_ptr<RayObject>>> *return_objects,
27052705
std::vector<std::pair<ObjectID, std::shared_ptr<RayObject>>> *dynamic_return_objects,
27062706
std::vector<std::pair<ObjectID, bool>> *streaming_generator_returns,
2707-
ReferenceCounter::ReferenceTableProto *borrowed_refs,
2707+
ReferenceCounterInterface::ReferenceTableProto *borrowed_refs,
27082708
bool *is_retryable_error,
27092709
std::string *application_error) {
27102710
RAY_LOG(DEBUG) << "Executing task, task info = " << task_spec.DebugString();
@@ -3079,7 +3079,7 @@ Status CoreWorker::ReportGeneratorItemReturns(
30793079
// we borrow the object. When the object value is allocated, the
30803080
// memory store is updated. We should clear borrowers and memory store
30813081
// here.
3082-
ReferenceCounter::ReferenceTableProto borrowed_refs;
3082+
ReferenceCounterInterface::ReferenceTableProto borrowed_refs;
30833083
reference_counter_->PopAndClearLocalBorrowers(
30843084
{dynamic_return_object.first}, &borrowed_refs, &deleted);
30853085
memory_store_->Delete(deleted);
@@ -3149,7 +3149,7 @@ void CoreWorker::HandleReportGeneratorItemReturns(
31493149
std::vector<rpc::ObjectReference> CoreWorker::ExecuteTaskLocalMode(
31503150
const TaskSpecification &task_spec, const ActorID &actor_id) {
31513151
auto return_objects = std::vector<std::pair<ObjectID, std::shared_ptr<RayObject>>>();
3152-
auto borrowed_refs = ReferenceCounter::ReferenceTableProto();
3152+
auto borrowed_refs = ReferenceCounterInterface::ReferenceTableProto();
31533153

31543154
std::vector<rpc::ObjectReference> returned_refs;
31553155
size_t num_returns = task_spec.NumReturns();

src/ray/core_worker/core_worker.h

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@
4242
#include "ray/core_worker/object_recovery_manager.h"
4343
#include "ray/core_worker/profile_event.h"
4444
#include "ray/core_worker/reference_counter.h"
45+
#include "ray/core_worker/reference_counter_interface.h"
4546
#include "ray/core_worker/shutdown_coordinator.h"
4647
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
4748
#include "ray/core_worker/store_provider/plasma_store_provider.h"
@@ -181,7 +182,7 @@ class CoreWorker {
181182
std::shared_ptr<ipc::RayletIpcClientInterface> raylet_ipc_client,
182183
std::shared_ptr<ray::RayletClientInterface> local_raylet_rpc_client,
183184
boost::thread &io_thread,
184-
std::shared_ptr<ReferenceCounter> reference_counter,
185+
std::shared_ptr<ReferenceCounterInterface> reference_counter,
185186
std::shared_ptr<CoreWorkerMemoryStore> memory_store,
186187
std::shared_ptr<CoreWorkerPlasmaStoreProvider> plasma_store_provider,
187188
std::shared_ptr<experimental::MutableObjectProviderInterface>
@@ -1480,7 +1481,7 @@ class CoreWorker {
14801481
std::vector<std::pair<ObjectID, std::shared_ptr<RayObject>>>
14811482
*dynamic_return_objects,
14821483
std::vector<std::pair<ObjectID, bool>> *streaming_generator_returns,
1483-
ReferenceCounter::ReferenceTableProto *borrowed_refs,
1484+
ReferenceCounterInterface::ReferenceTableProto *borrowed_refs,
14841485
bool *is_retryable_error,
14851486
std::string *application_error);
14861487

@@ -1765,7 +1766,7 @@ class CoreWorker {
17651766
boost::thread &io_thread_;
17661767

17671768
// Keeps track of object ID reference counts.
1768-
std::shared_ptr<ReferenceCounter> reference_counter_;
1769+
std::shared_ptr<ReferenceCounterInterface> reference_counter_;
17691770

17701771
///
17711772
/// Fields related to storing and retrieving objects.

src/ray/core_worker/future_resolver.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ using ReportLocalityDataCallback =
3333
class FutureResolver {
3434
public:
3535
FutureResolver(std::shared_ptr<CoreWorkerMemoryStore> store,
36-
std::shared_ptr<ReferenceCounter> ref_counter,
36+
std::shared_ptr<ReferenceCounterInterface> ref_counter,
3737
ReportLocalityDataCallback report_locality_data_callback,
3838
std::shared_ptr<rpc::CoreWorkerClientPool> core_worker_client_pool,
3939
rpc::Address rpc_address)
@@ -69,7 +69,7 @@ class FutureResolver {
6969
std::shared_ptr<CoreWorkerMemoryStore> in_memory_store_;
7070

7171
/// Used to record nested ObjectRefs of resolved futures.
72-
std::shared_ptr<ReferenceCounter> reference_counter_;
72+
std::shared_ptr<ReferenceCounterInterface> reference_counter_;
7373

7474
/// Used to report locality data received during future resolution.
7575
ReportLocalityDataCallback report_locality_data_callback_;

0 commit comments

Comments
 (0)