Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 676301454
  • Loading branch information
MediaPipe Team authored and copybara-github committed Sep 19, 2024
1 parent 7885a0a commit f253df1
Show file tree
Hide file tree
Showing 18 changed files with 245 additions and 114 deletions.
5 changes: 2 additions & 3 deletions mediapipe/framework/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,8 @@ cc_library(
deps = [
":graph_service",
":packet",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/synchronization",
],
)

Expand Down Expand Up @@ -864,6 +863,7 @@ cc_test(
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:sink",
"//mediapipe/gpu:gpu_service",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -1209,7 +1209,6 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)

Expand Down
6 changes: 4 additions & 2 deletions mediapipe/framework/calculator_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,10 @@ class CalculatorContext {

// Returns the graph-level service manager for sharing its services with
// calculator-nested MP graphs.
std::shared_ptr<GraphServiceManager> GetSharedGraphServiceManager() const {
return calculator_state_->GetSharedGraphServiceManager();
// Note: For accessing MP services from a calculator, use the
// ServiceBinding<T> Service(kService) method above.
const GraphServiceManager* GetGraphServiceManager() const {
return calculator_state_->GetGraphServiceManager();
}

// Gets interface to access resources (file system, assets, etc.) from
Expand Down
50 changes: 26 additions & 24 deletions mediapipe/framework/calculator_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,19 +131,21 @@ void CalculatorGraph::GraphInputStream::Close() {
manager_->Close();
}

CalculatorGraph::CalculatorGraph(
std::shared_ptr<GraphServiceManager> service_manager)
: counter_factory_(std::make_unique<BasicCounterFactory>()),
service_manager_(std::move(service_manager)),
profiler_(std::make_shared<ProfilingContext>()),
scheduler_(this) {}

CalculatorGraph::CalculatorGraph()
: CalculatorGraph(std::make_shared<GraphServiceManager>()) {}
CalculatorGraph::CalculatorGraph() : CalculatorGraph(/*cc=*/nullptr) {}

// Adopt all services from the CalculatorContext / parent graph.
CalculatorGraph::CalculatorGraph(CalculatorContext* cc)
: CalculatorGraph(cc->GetSharedGraphServiceManager()) {}
: counter_factory_(std::make_unique<BasicCounterFactory>()),
service_manager_(cc != nullptr ? cc->GetGraphServiceManager() : nullptr),
profiler_(std::make_shared<ProfilingContext>()),
scheduler_(this) {
if (cc != nullptr) {
// Nested graphs should not create default initialized services to avoid
// collisions between newly created and inherited graphs.
// TODO b/368015341- Use factory method to avoid CHECK in constructor.
ABSL_CHECK_OK(DisallowServiceDefaultInitialization());
}
}

CalculatorGraph::CalculatorGraph(CalculatorGraphConfig config)
: CalculatorGraph() {
Expand Down Expand Up @@ -280,7 +282,7 @@ absl::Status CalculatorGraph::InitializeCalculatorNodes() {
const absl::Status result = nodes_.back()->Initialize(
validated_graph_.get(), node_ref, input_stream_managers_.get(),
output_stream_managers_.get(), output_side_packets_.get(),
&buffer_size_hint, profiler_, service_manager_);
&buffer_size_hint, profiler_, &service_manager_);
MaybeFixupLegacyGpuNodeContract(*nodes_.back());
if (buffer_size_hint > 0) {
max_queue_size_ = std::max(max_queue_size_, buffer_size_hint);
Expand Down Expand Up @@ -318,7 +320,7 @@ absl::Status CalculatorGraph::InitializePacketGeneratorNodes(
const absl::Status result = nodes_.back()->Initialize(
validated_graph_.get(), node_ref, input_stream_managers_.get(),
output_stream_managers_.get(), output_side_packets_.get(),
&buffer_size_hint, profiler_, service_manager_);
&buffer_size_hint, profiler_, &service_manager_);
MaybeFixupLegacyGpuNodeContract(*nodes_.back());
if (!result.ok()) {
// Collect as many errors as we can before failing.
Expand Down Expand Up @@ -467,7 +469,7 @@ absl::Status CalculatorGraph::Initialize(
auto validated_graph = std::make_unique<ValidatedGraphConfig>();
MP_RETURN_IF_ERROR(validated_graph->Initialize(
std::move(input_config), /*graph_registry=*/nullptr,
/*graph_options=*/nullptr, service_manager_));
/*graph_options=*/nullptr, &service_manager_));
return Initialize(std::move(validated_graph), side_packets);
}

Expand All @@ -478,7 +480,7 @@ absl::Status CalculatorGraph::Initialize(
const std::string& graph_type, const Subgraph::SubgraphOptions* options) {
auto validated_graph = std::make_unique<ValidatedGraphConfig>();
MP_RETURN_IF_ERROR(validated_graph->Initialize(
input_configs, input_templates, graph_type, options, service_manager_));
input_configs, input_templates, graph_type, options, &service_manager_));
return Initialize(std::move(validated_graph), side_packets);
}

Expand Down Expand Up @@ -596,15 +598,15 @@ absl::Status CalculatorGraph::StartRun(
absl::Status CalculatorGraph::SetGpuResources(
std::shared_ptr<::mediapipe::GpuResources> resources) {
RET_CHECK_NE(resources, nullptr);
auto gpu_service = service_manager_->GetServiceObject(kGpuService);
auto gpu_service = service_manager_.GetServiceObject(kGpuService);
RET_CHECK_EQ(gpu_service, nullptr)
<< "The GPU resources have already been configured.";
return service_manager_->SetServiceObject(kGpuService, std::move(resources));
return service_manager_.SetServiceObject(kGpuService, std::move(resources));
}

std::shared_ptr<::mediapipe::GpuResources> CalculatorGraph::GetGpuResources()
const {
return service_manager_->GetServiceObject(kGpuService);
return service_manager_.GetServiceObject(kGpuService);
}

static Packet GetLegacyGpuSharedSidePacket(
Expand All @@ -620,21 +622,21 @@ static Packet GetLegacyGpuSharedSidePacket(
absl::Status CalculatorGraph::MaybeSetUpGpuServiceFromLegacySidePacket(
Packet legacy_sp) {
if (legacy_sp.IsEmpty()) return absl::OkStatus();
auto gpu_resources = service_manager_->GetServiceObject(kGpuService);
auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
if (gpu_resources) {
ABSL_LOG(WARNING)
<< "::mediapipe::GpuSharedData provided as a side packet while the "
<< "graph already had one; ignoring side packet";
return absl::OkStatus();
}
gpu_resources = legacy_sp.Get<::mediapipe::GpuSharedData*>()->gpu_resources;
return service_manager_->SetServiceObject(kGpuService, gpu_resources);
return service_manager_.SetServiceObject(kGpuService, gpu_resources);
}

std::map<std::string, Packet> CalculatorGraph::MaybeCreateLegacyGpuSidePacket(
Packet legacy_sp) {
std::map<std::string, Packet> additional_side_packets;
auto gpu_resources = service_manager_->GetServiceObject(kGpuService);
auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
if (gpu_resources &&
(legacy_sp.IsEmpty() ||
legacy_sp.Get<::mediapipe::GpuSharedData*>()->gpu_resources !=
Expand All @@ -652,7 +654,7 @@ static bool UsesGpu(const CalculatorNode& node) {
}

absl::Status CalculatorGraph::PrepareGpu() {
auto gpu_resources = service_manager_->GetServiceObject(kGpuService);
auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
if (!gpu_resources) return absl::OkStatus();
// Set up executors.
for (auto& node : nodes_) {
Expand All @@ -671,7 +673,7 @@ absl::Status CalculatorGraph::PrepareGpu() {
absl::Status CalculatorGraph::PrepareServices() {
for (const auto& node : nodes_) {
for (const auto& [key, request] : node->Contract().ServiceRequests()) {
auto packet = service_manager_->GetServicePacket(request.Service());
auto packet = service_manager_.GetServicePacket(request.Service());
if (!packet.IsEmpty()) continue;
absl::StatusOr<Packet> packet_or;
if (allow_service_default_initialization_) {
Expand All @@ -681,7 +683,7 @@ absl::Status CalculatorGraph::PrepareServices() {
"Service default initialization is disallowed.");
}
if (packet_or.ok()) {
MP_RETURN_IF_ERROR(service_manager_->SetServicePacket(
MP_RETURN_IF_ERROR(service_manager_.SetServicePacket(
request.Service(), std::move(packet_or).value()));
} else if (request.IsOptional()) {
continue;
Expand Down Expand Up @@ -802,7 +804,7 @@ absl::Status CalculatorGraph::PrepareForRun(
// TODO: update calculator node to use GraphServiceManager
// instead of service packets?
const absl::Status result = node->PrepareForRun(
current_run_side_packets_, service_manager_->ServicePackets(),
current_run_side_packets_, service_manager_.ServicePackets(),
std::bind(&internal::Scheduler::ScheduleNodeForOpen, &scheduler_,
node.get()),
std::bind(&internal::Scheduler::AddNodeToSourcesQueue, &scheduler_,
Expand Down
11 changes: 4 additions & 7 deletions mediapipe/framework/calculator_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,12 +437,12 @@ class CalculatorGraph {
}
}

return service_manager_->SetServiceObject(service, object);
return service_manager_.SetServiceObject(service, object);
}

template <typename T>
std::shared_ptr<T> GetServiceObject(const GraphService<T>& service) {
return service_manager_->GetServiceObject(service);
return service_manager_.GetServiceObject(service);
}

// Disallows/disables default initialization of MediaPipe graph services.
Expand Down Expand Up @@ -483,13 +483,10 @@ class CalculatorGraph {
// Only the Java API should call this directly.
absl::Status SetServicePacket(const GraphServiceBase& service, Packet p) {
// TODO: check that the graph has not been started!
return service_manager_->SetServicePacket(service, p);
return service_manager_.SetServicePacket(service, p);
}

private:
explicit CalculatorGraph(
std::shared_ptr<GraphServiceManager> service_manager);

// GraphRunState is used as a parameter in the function CallStatusHandlers.
enum class GraphRunState {
// State of the graph before the run; see status_handler.h for details.
Expand Down Expand Up @@ -722,7 +719,7 @@ class CalculatorGraph {
std::map<std::string, Packet> current_run_side_packets_;

// Object to manage graph services.
std::shared_ptr<GraphServiceManager> service_manager_;
GraphServiceManager service_manager_;

// Indicates whether service default initialization is allowed.
bool allow_service_default_initialization_ = true;
Expand Down
5 changes: 3 additions & 2 deletions mediapipe/framework/calculator_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_state.h"
#include "mediapipe/framework/counter_factory.h"
#include "mediapipe/framework/graph_service_manager.h"
#include "mediapipe/framework/input_stream_manager.h"
Expand Down Expand Up @@ -128,7 +129,7 @@ absl::Status CalculatorNode::Initialize(
OutputStreamManager* output_stream_managers,
OutputSidePacketImpl* output_side_packets, int* buffer_size_hint,
std::shared_ptr<ProfilingContext> profiling_context,
std::shared_ptr<GraphServiceManager> graph_service_manager) {
const GraphServiceManager* graph_service_manager) {
RET_CHECK(buffer_size_hint) << "buffer_size_hint is NULL";
validated_graph_ = validated_graph;
profiling_context_ = profiling_context;
Expand Down Expand Up @@ -171,7 +172,7 @@ absl::Status CalculatorNode::Initialize(
node_type_info_->OutputStreamTypes()));
MP_RETURN_IF_ERROR(InitializeOutputStreams(output_stream_managers));

calculator_state_ = absl::make_unique<CalculatorState>(
calculator_state_ = std::make_unique<CalculatorState>(
name_, node_ref.index, node_config->calculator(), *node_config,
profiling_context_, graph_service_manager);

Expand Down
15 changes: 8 additions & 7 deletions mediapipe/framework/calculator_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,14 @@ class CalculatorNode {
// output_side_packets is expected to point to a contiguous flat array with
// OutputSidePacketImpls corresponding to the output side packet indexes in
// validated_graph.
absl::Status Initialize(
const ValidatedGraphConfig* validated_graph,
NodeTypeInfo::NodeRef node_ref, InputStreamManager* input_stream_managers,
OutputStreamManager* output_stream_managers,
OutputSidePacketImpl* output_side_packets, int* buffer_size_hint,
std::shared_ptr<ProfilingContext> profiling_context,
std::shared_ptr<GraphServiceManager> graph_service_manager);
absl::Status Initialize(const ValidatedGraphConfig* validated_graph,
NodeTypeInfo::NodeRef node_ref,
InputStreamManager* input_stream_managers,
OutputStreamManager* output_stream_managers,
OutputSidePacketImpl* output_side_packets,
int* buffer_size_hint,
std::shared_ptr<ProfilingContext> profiling_context,
const GraphServiceManager* graph_service_manager);

// Sets up the node at the beginning of CalculatorGraph::Run(). This
// method is executed before any OpenNode() calls to the nodes
Expand Down
2 changes: 1 addition & 1 deletion mediapipe/framework/calculator_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ CalculatorState::CalculatorState(
const std::string& calculator_type,
const CalculatorGraphConfig::Node& node_config,
std::shared_ptr<ProfilingContext> profiling_context,
std::shared_ptr<GraphServiceManager> graph_service_manager)
const GraphServiceManager* graph_service_manager)
: node_name_(node_name),
node_id_(node_id),
calculator_type_(calculator_type),
Expand Down
9 changes: 4 additions & 5 deletions mediapipe/framework/calculator_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class CalculatorState {
const std::string& calculator_type,
const CalculatorGraphConfig::Node& node_config,
std::shared_ptr<ProfilingContext> profiling_context,
std::shared_ptr<GraphServiceManager> graph_service_manager);
const GraphServiceManager* graph_service_manager);
CalculatorState(const CalculatorState&) = delete;
CalculatorState& operator=(const CalculatorState&) = delete;
~CalculatorState();
Expand Down Expand Up @@ -97,8 +97,7 @@ class CalculatorState {

// Returns the graph-level service manager for sharing its services with
// calculator-nested MP graphs.
const std::shared_ptr<GraphServiceManager>& GetSharedGraphServiceManager()
const {
const GraphServiceManager* GetGraphServiceManager() const {
return graph_service_manager_;
}

Expand Down Expand Up @@ -144,8 +143,8 @@ class CalculatorState {
// The graph tracing and profiling interface.
std::shared_ptr<ProfilingContext> profiling_context_;

// Shared pointer to the graph-level service manager.
std::shared_ptr<GraphServiceManager> graph_service_manager_;
// Const pointer to the graph-level service manager.
const GraphServiceManager* graph_service_manager_ = nullptr;

// calculator_service_manager_ contains only the services that are requested
// by the calculator in UpdateContract() via cc->UseService(...).
Expand Down
8 changes: 5 additions & 3 deletions mediapipe/framework/graph_service_manager.cc
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
#include "mediapipe/framework/graph_service_manager.h"

#include "absl/synchronization/mutex.h"
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/framework/graph_service.h"
#include "mediapipe/framework/packet.h"

namespace mediapipe {

absl::Status GraphServiceManager::SetServicePacket(
const GraphServiceBase& service, Packet p) {
// TODO: check service is already set?
absl::MutexLock lock(&service_packets_mutex_);
service_packets_[service.key] = std::move(p);
return absl::OkStatus();
}

Packet GraphServiceManager::GetServicePacket(
const GraphServiceBase& service) const {
absl::MutexLock lock(&service_packets_mutex_);
auto it = service_packets_.find(service.key);
if (it == service_packets_.end()) {
return {};
Expand Down
25 changes: 16 additions & 9 deletions mediapipe/framework/graph_service_manager.h
Original file line number Diff line number Diff line change
@@ -1,22 +1,31 @@
#ifndef MEDIAPIPE_FRAMEWORK_GRAPH_SERVICE_MANAGER_H_
#define MEDIAPIPE_FRAMEWORK_GRAPH_SERVICE_MANAGER_H_

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/graph_service.h"
#include "mediapipe/framework/packet.h"

namespace mediapipe {

class GraphServiceManager {
public:
GraphServiceManager() = default;

explicit GraphServiceManager(
const GraphServiceManager* external_graph_manager) {
if (external_graph_manager != nullptr) {
// Nested graphs inherit all graph services from their parent graph and
// disable the registration of new services in the nested graph. This
// ensures that all services are created during the initialization of
// parent graph.
service_packets_ = external_graph_manager->ServicePackets();
}
}

using ServiceMap = std::map<std::string, Packet>;

template <typename T>
Expand All @@ -34,14 +43,12 @@ class GraphServiceManager {
if (p.IsEmpty()) return nullptr;
return p.Get<std::shared_ptr<T>>();
}
const ServiceMap& ServicePackets() { return service_packets_; }
const ServiceMap& ServicePackets() const { return service_packets_; }

private:
Packet GetServicePacket(const GraphServiceBase& service) const;
// Mutex protection since the GraphServiceManager instance can be shared among
// multiple (nested) MP graphs.
mutable absl::Mutex service_packets_mutex_;
ServiceMap service_packets_ ABSL_GUARDED_BY(service_packets_mutex_);

ServiceMap service_packets_;
friend class CalculatorGraph;
};

Expand Down
Loading

0 comments on commit f253df1

Please sign in to comment.