-
Notifications
You must be signed in to change notification settings - Fork 630
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactoring in Pipeline, old Executor + name lookup improvement in old OpGraph. #5495
Changes from all commits
639daf3
e0ed7d2
1c30e3c
111cf63
c903053
7bbde44
b8228c3
088440b
3a0283a
e371e38
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
#ifndef DALI_PIPELINE_EXECUTOR_LOWERED_GRAPH_H_ | ||
#define DALI_PIPELINE_EXECUTOR_LOWERED_GRAPH_H_ | ||
|
||
#include <functional> | ||
#include <map> | ||
#include <unordered_set> | ||
#include <utility> | ||
|
@@ -27,6 +28,7 @@ | |
#include "dali/core/common.h" | ||
#include "dali/core/error_handling.h" | ||
#include "dali/pipeline/operator/operator.h" | ||
#include "dali/pipeline/graph/op_graph2.h" | ||
|
||
namespace dali { | ||
|
||
|
@@ -96,7 +98,7 @@ using consumer_edge_t = TensorMeta; | |
// Second type of graph nodes. | ||
struct TensorNode { | ||
TensorNodeId id; | ||
std::string name; // TODO(klecki): not happy about all the strings | ||
std::string name; | ||
producer_edge_t producer; | ||
// order of consumers is arbitrary | ||
std::vector<consumer_edge_t> consumers; | ||
|
@@ -134,7 +136,7 @@ class DLL_PUBLIC OpGraph { | |
/** | ||
* @brief Adds an op with the input specification to the graph. | ||
*/ | ||
DLL_PUBLIC void AddOp(const OpSpec &spec, const std::string& name); | ||
DLL_PUBLIC OpNode &AddOp(const OpSpec &spec, const std::string& name); | ||
|
||
/** | ||
* @brief Removes the node with the specified OpNodeId from | ||
|
@@ -183,12 +185,21 @@ class DLL_PUBLIC OpGraph { | |
} | ||
|
||
/** | ||
* @brief Returns the graph node with the given name. | ||
* This function is much slower than the version taking | ||
* index as argument so should not be used in performance | ||
* critical section of the code. | ||
* @brief Returns the graph node with the given name or nullptr, if not found. | ||
*/ | ||
DLL_PUBLIC OpNode& Node(const std::string& name); | ||
DLL_PUBLIC OpNode *NodePtr(std::string_view instance_name); | ||
|
||
DLL_PUBLIC OpNode &Node(std::string_view instance_name) { | ||
OpNode *node = NodePtr(instance_name); | ||
if (!node) | ||
DALI_FAIL(make_string("Operator node with name \"", instance_name, "\" not found.")); | ||
return *node; | ||
} | ||
|
||
/** | ||
* @brief Returns the id of the data node with given name or nullopt, if not found. | ||
*/ | ||
DLL_PUBLIC std::optional<OpNodeId> NodeId(std::string_view instance_name); | ||
|
||
/** | ||
* @brief Returns the graph node with the given index in the graph. | ||
|
@@ -216,20 +227,31 @@ class DLL_PUBLIC OpGraph { | |
return tensor_nodes_[id]; | ||
} | ||
|
||
DLL_PUBLIC TensorNodeId TensorId(const std::string& name) const { | ||
DLL_PUBLIC std::optional<TensorNodeId> TensorId(std::string_view name) const { | ||
auto it = tensor_name_to_id_.find(name); | ||
DALI_ENFORCE(it != tensor_name_to_id_.end(), | ||
"Tensor with name " + name + " does not exist in graph."); | ||
if (it == tensor_name_to_id_.end()) | ||
return std::nullopt; | ||
return it->second; | ||
} | ||
|
||
/** | ||
* @brief Returns the Tensor node with the given name. | ||
* @brief Returns the Tensor node with the given name or nullptr, if not found. | ||
*/ | ||
DLL_PUBLIC const TensorNode& Tensor(const std::string& name) const { | ||
return tensor_nodes_[TensorId(name)]; | ||
DLL_PUBLIC const TensorNode *TensorPtr(std::string_view name) const { | ||
auto id = TensorId(name); | ||
if (!id) | ||
return nullptr; | ||
return &Tensor(*id); | ||
} | ||
|
||
DLL_PUBLIC const TensorNode &Tensor(std::string_view name) const { | ||
auto *t = TensorPtr(name); | ||
if (!t) | ||
DALI_FAIL(make_string("Tensor with name \"", name, "\" not found")); | ||
return *t; | ||
} | ||
|
||
|
||
DLL_PUBLIC std::vector<std::vector<TensorNodeId>> PartitionTensorByOpType() const; | ||
|
||
/** | ||
|
@@ -467,8 +489,9 @@ class DLL_PUBLIC OpGraph { | |
*/ | ||
void RemoveOpNode(OpNodeId id); | ||
|
||
std::map<std::string, TensorNodeId> tensor_name_to_id_; | ||
std::map<std::string, TensorNodeId, std::less<>> tensor_name_to_id_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Passing |
||
mutable std::map<TensorNodeId, bool> has_consumers_in_other_stage_; | ||
std::map<std::string, OpNodeId, std::less<>> op_name_to_id_; | ||
|
||
bool pass_through_computed_ = false; | ||
}; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -50,8 +50,6 @@ class DLL_PUBLIC PipelinedExecutorImpl : public Executor<WorkspacePolicy, QueueP | |
|
||
DLL_PUBLIC ~PipelinedExecutorImpl() override = default; | ||
|
||
DLL_PUBLIC void Build(OpGraph *graph, vector<string> output_names) override; | ||
|
||
DISABLE_COPY_MOVE_ASSIGN(PipelinedExecutorImpl); | ||
|
||
protected: | ||
|
@@ -83,12 +81,6 @@ class DLL_PUBLIC PipelinedExecutorImpl : public Executor<WorkspacePolicy, QueueP | |
size_t CalcIterationDataSize() const override; | ||
}; | ||
|
||
template <typename WorkspacePolicy, typename QueuePolicy> | ||
void PipelinedExecutorImpl<WorkspacePolicy, QueuePolicy>::Build(OpGraph *graph, | ||
vector<string> output_names) { | ||
Executor<WorkspacePolicy, QueuePolicy>::Build(graph, output_names); | ||
} | ||
Comment on lines
-86
to
-90
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Useless override removed. |
||
|
||
template <typename WorkspacePolicy, typename QueuePolicy> | ||
void PipelinedExecutorImpl<WorkspacePolicy, QueuePolicy>::SetupOutputInfo(OpGraph &graph) { | ||
DeviceGuard g(device_id_); | ||
|
@@ -123,7 +115,7 @@ class DLL_PUBLIC SeparatedPipelinedExecutor | |
using ImplBase = PipelinedExecutorImpl<AOT_WS_Policy<SeparateQueuePolicy>, SeparateQueuePolicy>; | ||
using ImplBase::ImplBase; | ||
public: | ||
int InputFeedCount(const std::string &name) override; | ||
int InputFeedCount(std::string_view name) override; | ||
}; | ||
|
||
} // namespace dali | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is necessary in a follow-up PR.