Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring in Pipeline, old Executor + name lookup improvement in old OpGraph. #5495

Merged
merged 10 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void AsyncSeparatedPipelinedExecutor::Prefetch() {
}
}

int AsyncSeparatedPipelinedExecutor::InputFeedCount(const std::string &op_name) {
int AsyncSeparatedPipelinedExecutor::InputFeedCount(std::string_view op_name) {
(void)graph_->Node(op_name);
return queue_sizes_.cpu_size + queue_sizes_.gpu_size;
}
Expand Down
4 changes: 2 additions & 2 deletions dali/pipeline/executor/async_separated_pipelined_executor.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -91,7 +91,7 @@ class DLL_PUBLIC AsyncSeparatedPipelinedExecutor : public SeparatedPipelinedExec
}
}

DLL_PUBLIC int InputFeedCount(const std::string &op_name) override;
DLL_PUBLIC int InputFeedCount(std::string_view op_name) override;

protected:
DLL_PUBLIC void Prefetch() override;
Expand Down
8 changes: 6 additions & 2 deletions dali/pipeline/executor/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
#include "dali/core/common.h"
#include "dali/pipeline/workspace/workspace.h"
#include "dali/pipeline/operator/checkpointing/checkpoint.h"
#include "dali/pipeline/graph/op_graph2.h"

namespace dali {

class OperatorBase;

struct DLL_PUBLIC ExecutorMeta {
size_t real_size;
size_t max_real_size;
Expand All @@ -37,7 +40,7 @@ using ExecutorMetaMap = std::unordered_map<std::string, std::vector<ExecutorMeta
class DLL_PUBLIC ExecutorBase {
public:
DLL_PUBLIC virtual ~ExecutorBase() {}
DLL_PUBLIC virtual void Build(OpGraph *graph, vector<string> output_names) = 0;
DLL_PUBLIC virtual void Build(OpGraph *graph, vector<string> outputs) = 0;
stiepan marked this conversation as resolved.
Show resolved Hide resolved
DLL_PUBLIC virtual void Init() = 0;
DLL_PUBLIC virtual void Run() = 0;
DLL_PUBLIC virtual void Prefetch() = 0;
Expand All @@ -50,7 +53,8 @@ class DLL_PUBLIC ExecutorBase {
DLL_PUBLIC virtual void Shutdown() = 0;
DLL_PUBLIC virtual Checkpoint& GetCurrentCheckpoint() = 0;
DLL_PUBLIC virtual void RestoreStateFromCheckpoint(const Checkpoint &cpt) = 0;
DLL_PUBLIC virtual int InputFeedCount(const std::string &input_name) = 0;
DLL_PUBLIC virtual int InputFeedCount(std::string_view input_name) = 0;
DLL_PUBLIC virtual OperatorBase *GetOperator(std::string_view name) = 0;

protected:
// virtual to allow the TestPruneWholeGraph test in gcc
Expand Down
8 changes: 7 additions & 1 deletion dali/pipeline/executor/executor_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,11 +352,17 @@ void Executor<WorkspacePolicy, QueuePolicy>::Run() {
}

template <typename WorkspacePolicy, typename QueuePolicy>
int Executor<WorkspacePolicy, QueuePolicy>::InputFeedCount(const std::string &op_name) {
int Executor<WorkspacePolicy, QueuePolicy>::InputFeedCount(std::string_view op_name) {
(void)graph_->Node(op_name);
return queue_sizes_.cpu_size;
}

/**
 * @brief Looks up the graph node named `op_name` and returns its operator.
 *
 * The lookup goes through `graph_->Node(op_name)`; the returned pointer is the
 * raw pointer held by that node's `op` member, so the node keeps ownership.
 *
 * @param op_name instance name of the operator node to look up
 * @return non-owning pointer to the operator stored in the matching node
 */
template <typename WorkspacePolicy, typename QueuePolicy>
OperatorBase *Executor<WorkspacePolicy, QueuePolicy>::GetOperator(std::string_view op_name) {
  auto &node = graph_->Node(op_name);
  return node.op.get();
}


template <typename WorkspacePolicy, typename QueuePolicy>
void Executor<WorkspacePolicy, QueuePolicy>::Prefetch() {
int i;
Expand Down
11 changes: 8 additions & 3 deletions dali/pipeline/executor/executor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ class DLL_PUBLIC Executor : public ExecutorBase, public QueuePolicy {
DLL_PUBLIC void EnableCheckpointing(bool checkpointing = false) override {
checkpointing_ = checkpointing;
}
DLL_PUBLIC void Build(OpGraph *graph, vector<string> output_names) override;
DLL_PUBLIC void Run() override;
DLL_PUBLIC void Prefetch() override;
DLL_PUBLIC void Init() override {}
Expand Down Expand Up @@ -129,9 +128,15 @@ class DLL_PUBLIC Executor : public ExecutorBase, public QueuePolicy {
*/
DLL_PUBLIC void RestoreStateFromCheckpoint(const Checkpoint &cpt) override;

DLL_PUBLIC int InputFeedCount(const std::string &op_name) override;
DLL_PUBLIC int InputFeedCount(std::string_view op_name) override;

DLL_PUBLIC void Build(OpGraph *graph, vector<string> output_names) override;

DLL_PUBLIC OperatorBase *GetOperator(std::string_view instance_name) override;

protected:
OpGraph lowered_graph_;
stiepan marked this conversation as resolved.
Show resolved Hide resolved

DLL_PUBLIC virtual void RunCPU();
DLL_PUBLIC virtual void RunMixed();
DLL_PUBLIC virtual void RunGPU();
Expand Down Expand Up @@ -335,7 +340,7 @@ class DLL_PUBLIC Executor : public ExecutorBase, public QueuePolicy {
// true iff the graph that is executed contains if statements, set by DetectConditionals()
bool has_conditionals_ = false;

bool checkpointing_;
bool checkpointing_ = false;

private:
void RunHelper(OpNode &op_node, Workspace &ws, size_t iteration_id);
Expand Down
19 changes: 16 additions & 3 deletions dali/pipeline/executor/lowered_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ OpNode& OpGraph::PlaceNewOp(OpType op_type, const OpSpec &op_spec, std::string i
auto new_partition_id = NumOp(op_type);
node.partition_index = new_partition_id;
op_partitions_[static_cast<int>(op_type)].push_back(node.id);
op_name_to_id_[instance_name] = node.id;
return node;
}

Expand All @@ -113,7 +114,7 @@ TensorNode& OpGraph::PlaceNewTensor() {
}


void OpGraph::AddOp(const OpSpec &spec, const std::string &op_name) {
OpNode &OpGraph::AddOp(const OpSpec &spec, const std::string &op_name) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is necessary in a follow-up PR.

// Validate the op specification
CheckOpConstraints(spec);

Expand Down Expand Up @@ -202,6 +203,8 @@ void OpGraph::AddOp(const OpSpec &spec, const std::string &op_name) {
", but output with this name already exists as output of op '",
this->Node(TensorSourceID(name)).instance_name, "'"));
}

return new_node;
}

void OpGraph::InstantiateOperators() {
Expand Down Expand Up @@ -269,6 +272,9 @@ void OpGraph::RemoveTensorNode(TensorNodeId id) {
void OpGraph::SwapOpNodes(OpNodeId left_id, OpNodeId right_id) {
auto &left = op_nodes_[left_id];
auto &right = op_nodes_[right_id];

op_name_to_id_[left.instance_name] = right_id;
op_name_to_id_[right.instance_name] = left_id;
// Swap all references in tensor edges
// Produced tensors (children)
{
Expand Down Expand Up @@ -351,6 +357,7 @@ void OpGraph::RemoveOpNode(OpNodeId id) {
Node(parent_id).children.erase(op_nodes_.back().id);
}
// assume that we removed one element
szkarpinski marked this conversation as resolved.
Show resolved Hide resolved
op_name_to_id_.erase(target_op.instance_name);
op_nodes_.pop_back();
}

Expand Down Expand Up @@ -434,13 +441,19 @@ std::vector<std::vector<TensorNodeId>> OpGraph::PartitionTensorByOpType() const
}

// TODO(klecki): get rid of string indexing
OpNode& OpGraph::Node(const std::string& name) {
OpNode& OpGraph::Node(std::string_view name) {
auto it = op_name_to_id_.find(name);
if (it != op_name_to_id_.end()) {
OpNodeId id = it->second;
assert(id >= 0 && id < OpNodeId(op_nodes_.size()));
return op_nodes_[id];
}
for (auto &node : op_nodes_) {
if (node.instance_name == name) {
return node;
}
}
DALI_FAIL("Operator node with name " + name + " not found.");
DALI_FAIL(make_string("Operator node with name ", name, " not found."));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For very dubious reasons, C++ doesn't define operator+ for string_view.

}

namespace {
Expand Down
19 changes: 15 additions & 4 deletions dali/pipeline/executor/lowered_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef DALI_PIPELINE_EXECUTOR_LOWERED_GRAPH_H_
#define DALI_PIPELINE_EXECUTOR_LOWERED_GRAPH_H_

#include <functional>
#include <map>
#include <unordered_set>
#include <utility>
Expand All @@ -27,6 +28,7 @@
#include "dali/core/common.h"
#include "dali/core/error_handling.h"
#include "dali/pipeline/operator/operator.h"
#include "dali/pipeline/graph/op_graph2.h"

namespace dali {

Expand Down Expand Up @@ -59,6 +61,8 @@ struct OpNode {
virtual ~OpNode() = default;
OpNode& operator=(const OpNode&) = delete;

const graph::OpNode *definition = nullptr;

OpNode(OpNode &&) = default;
OpNode& operator=(OpNode &&) = default;

Expand All @@ -69,6 +73,7 @@ struct OpNode {

std::unique_ptr<OperatorBase> op;
OpNodeId id = -1;
// TODO(michalz): Consider removing the (now) redundant fields and use the definition
OpSpec spec;
std::set<OpNodeId> parents, children;

Expand All @@ -77,6 +82,7 @@ struct OpNode {
// To reduce number of allocation of shapes in Setup
std::vector<OutputDesc> output_desc;

// TODO(michalz): Consider removing the (now) redundant fields and use the definition
std::string instance_name;
OpType op_type = OpType::COUNT;
OpPartitionId partition_index = -1;
Expand All @@ -95,8 +101,12 @@ using consumer_edge_t = TensorMeta;

// Second type of graph nodes.
struct TensorNode {
// NOTE: TensorNode doesn't define the storage device, but TensorNode is taken from OpSpec
// where it's unambiguously associated with a storage device.
const graph::DataNode *definition = nullptr;

TensorNodeId id;
std::string name; // TODO(klecki): not happy about all the strings
std::string name;
producer_edge_t producer;
// order of consumers is arbitrary
std::vector<consumer_edge_t> consumers;
Expand Down Expand Up @@ -134,7 +144,7 @@ class DLL_PUBLIC OpGraph {
/**
* @brief Adds an op with the input specification to the graph.
*/
DLL_PUBLIC void AddOp(const OpSpec &spec, const std::string& name);
DLL_PUBLIC OpNode &AddOp(const OpSpec &spec, const std::string& name);

/**
* @brief Removes the node with the specified OpNodeId from
Expand Down Expand Up @@ -188,7 +198,7 @@ class DLL_PUBLIC OpGraph {
* index as argument so should not be used in performance
* critical section of the code.
*/
DLL_PUBLIC OpNode& Node(const std::string& name);
DLL_PUBLIC OpNode& Node(std::string_view name);

/**
* @brief Returns the graph node with the given index in the graph.
Expand Down Expand Up @@ -467,8 +477,9 @@ class DLL_PUBLIC OpGraph {
*/
void RemoveOpNode(OpNodeId id);

std::map<std::string, TensorNodeId> tensor_name_to_id_;
std::map<std::string, TensorNodeId, std::less<>> tensor_name_to_id_;
Copy link
Contributor Author

@mzient mzient Jun 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

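Passing std::less<> makes the comparator transparent, which allows lookup functions such as find to accept any comparable key - in this case, a string_view - without constructing a temporary std::string. Note that in C++17 operator[] still requires the map's key_type (std::string).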
See example - look for "transparent comparison".

mutable std::map<TensorNodeId, bool> has_consumers_in_other_stage_;
std::map<std::string, OpNodeId, std::less<>> op_name_to_id_;

bool pass_through_computed_ = false;
};
Expand Down
2 changes: 1 addition & 1 deletion dali/pipeline/executor/pipelined_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ size_t PipelinedExecutorImpl<WorkspacePolicy, QueuePolicy>::CalcIterationDataSiz
this->queue_sizes_.gpu_size /* mixed_queue_size */ + 1;
}

int SeparatedPipelinedExecutor::InputFeedCount(const std::string &op_name) {
int SeparatedPipelinedExecutor::InputFeedCount(std::string_view op_name) {
(void)graph_->Node(op_name);
return queue_sizes_.cpu_size + queue_sizes_.gpu_size;
}
Expand Down
10 changes: 1 addition & 9 deletions dali/pipeline/executor/pipelined_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ class DLL_PUBLIC PipelinedExecutorImpl : public Executor<WorkspacePolicy, QueueP

DLL_PUBLIC ~PipelinedExecutorImpl() override = default;

DLL_PUBLIC void Build(OpGraph *graph, vector<string> output_names) override;

DISABLE_COPY_MOVE_ASSIGN(PipelinedExecutorImpl);

protected:
Expand Down Expand Up @@ -83,12 +81,6 @@ class DLL_PUBLIC PipelinedExecutorImpl : public Executor<WorkspacePolicy, QueueP
size_t CalcIterationDataSize() const override;
};

template <typename WorkspacePolicy, typename QueuePolicy>
void PipelinedExecutorImpl<WorkspacePolicy, QueuePolicy>::Build(OpGraph *graph,
vector<string> output_names) {
Executor<WorkspacePolicy, QueuePolicy>::Build(graph, output_names);
}
Comment on lines -86 to -90
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Useless override removed.


template <typename WorkspacePolicy, typename QueuePolicy>
void PipelinedExecutorImpl<WorkspacePolicy, QueuePolicy>::SetupOutputInfo(OpGraph &graph) {
DeviceGuard g(device_id_);
Expand Down Expand Up @@ -123,7 +115,7 @@ class DLL_PUBLIC SeparatedPipelinedExecutor
using ImplBase = PipelinedExecutorImpl<AOT_WS_Policy<SeparateQueuePolicy>, SeparateQueuePolicy>;
using ImplBase::ImplBase;
public:
int InputFeedCount(const std::string &name) override;
int InputFeedCount(std::string_view name) override;
};

} // namespace dali
Expand Down
45 changes: 30 additions & 15 deletions dali/pipeline/graph/op_graph2.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,25 +130,24 @@ class DLL_PUBLIC OpGraph {
return data_nodes_;
}

/** Returns an OpNode with a matching instance name or nullptr. */
OpNode *GetOp(std::string_view instance_name) {
auto it = name2op_.find(instance_name);
if (it != name2op_.end())
return &*it->second;
else
return nullptr;
return GetOpImpl(instance_name);
}

/** Returns a DataNode with a matching name or nullptr.
*
* @param data_node_name
* @return DataNode*
*/
/** Const overload: returns the OpNode registered under `instance_name`, or nullptr if none. */
const OpNode *GetOp(std::string_view instance_name) const {
  const OpNode *node = GetOpImpl(instance_name);
  return node;
}

/** Returns a DataNode with a matching name or nullptr. */
DataNode *GetData(std::string_view data_node_name) {
auto it = name2data_.find(data_node_name);
if (it != name2data_.end())
return &*it->second;
else
return nullptr;
return GetDataImpl(data_node_name);
}

/** Const overload: returns the DataNode registered under `data_node_name`, or nullptr if none. */
const DataNode *GetData(std::string_view data_node_name) const {
  const DataNode *node = GetDataImpl(data_node_name);
  return node;
}

/** Sorts the graph topologically and removes entries that do not contribute to essential nodes.
Expand Down Expand Up @@ -200,6 +199,22 @@ class DLL_PUBLIC OpGraph {
}

private:
/**
 * Common lookup used by both GetOp overloads.
 *
 * Searches `name2op_` for `instance_name` and returns the address of the node the
 * entry refers to, or nullptr when there is no such entry.
 */
OpNode *GetOpImpl(std::string_view instance_name) const {
  const auto found = name2op_.find(instance_name);
  if (found == name2op_.end())
    return nullptr;  // no operator registered under this name
  return &*found->second;
}

/**
 * Common lookup used by both GetData overloads.
 *
 * Searches `name2data_` for `data_node_name` and returns the address of the node the
 * entry refers to, or nullptr when there is no such entry.
 */
DataNode *GetDataImpl(std::string_view data_node_name) const {
  const auto found = name2data_.find(data_node_name);
  if (found == name2data_.end())
    return nullptr;  // no data node registered under this name
  return &*found->second;
}

void RemoveDataNodeReferences(OpNode &op);

OpNodeList op_nodes_;
Expand Down
Loading
Loading