Skip to content

Commit

Permalink
[WIP]
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
  • Loading branch information
mzient committed Jun 3, 2024
1 parent 1c30e3c commit cadf3bc
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
6 changes: 4 additions & 2 deletions dali/pipeline/executor/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "dali/core/common.h"
#include "dali/pipeline/workspace/workspace.h"
#include "dali/pipeline/operator/checkpointing/checkpoint.h"
#include "dali/pipeline/graph/op_graph2.h"

namespace dali {

Expand All @@ -37,7 +38,7 @@ using ExecutorMetaMap = std::unordered_map<std::string, std::vector<ExecutorMeta
class DLL_PUBLIC ExecutorBase {
public:
DLL_PUBLIC virtual ~ExecutorBase() {}
DLL_PUBLIC virtual void Build(OpGraph *graph, vector<string> output_names) = 0;
DLL_PUBLIC virtual void Build(const graph::OpGraph &graph) = 0;
DLL_PUBLIC virtual void Init() = 0;
DLL_PUBLIC virtual void Run() = 0;
DLL_PUBLIC virtual void Prefetch() = 0;
Expand All @@ -50,7 +51,8 @@ class DLL_PUBLIC ExecutorBase {
DLL_PUBLIC virtual void Shutdown() = 0;
DLL_PUBLIC virtual Checkpoint& GetCurrentCheckpoint() = 0;
DLL_PUBLIC virtual void RestoreStateFromCheckpoint(const Checkpoint &cpt) = 0;
DLL_PUBLIC virtual int InputFeedCount(const std::string &input_name) = 0;
DLL_PUBLIC virtual int InputFeedCount(std::string_view input_name) = 0;
DLL_PUBLIC virtual OperatorBase *GetOperator(std::string_view name) = 0;

protected:
// virtual to allow the TestPruneWholeGraph test in gcc
Expand Down
9 changes: 4 additions & 5 deletions dali/pipeline/executor/lowered_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void CheckOpConstraints(const OpSpec &spec) {
" outputs, but was passed ", spec.NumOutput(), "."));
}

OpType ParseOpType(const std::string &device) {
OpType ParseOpType(std::string_view device) {
if (device == "gpu") {
return OpType::GPU;
} else if (device == "cpu") {
Expand All @@ -84,7 +84,7 @@ OpType ParseOpType(const std::string &device) {
DALI_FAIL("Unsupported device type: " + device + ".");
}

StorageDevice ParseStorageDevice(const std::string &io_device) {
StorageDevice ParseStorageDevice(std::string_view io_device) {
if (io_device == "cpu") {
return StorageDevice::CPU;
}
Expand Down Expand Up @@ -126,7 +126,7 @@ TensorNode& OpGraph::PlaceNewTensor() {
}


OpNode &OpGraph::AddOp(const OpSpec &spec, const std::string &op_name) {
OpNode &OpGraph::AddOp(const OpSpec &spec, std::string_view op_name) {
// Validate the op specification
CheckOpConstraints(spec);

Expand Down Expand Up @@ -448,8 +448,7 @@ std::vector<std::vector<TensorNodeId>> OpGraph::PartitionTensorByOpType() const
return out;
}

// TODO(klecki): get rid of string indexing
OpNode& OpGraph::Node(const std::string& name) {
OpNode& OpGraph::Node(std::string_view name) {
for (auto &node : op_nodes_) {
if (node.instance_name == name) {
return node;
Expand Down
4 changes: 2 additions & 2 deletions dali/pipeline/executor/lowered_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class DLL_PUBLIC OpGraph {
/**
* @brief Adds an op with the input specification to the graph.
*/
DLL_PUBLIC OpNode &AddOp(const OpSpec &spec, const std::string& name);
DLL_PUBLIC OpNode &AddOp(const OpSpec &spec, std::string_view name);

/**
* @brief Removes the node with the specified OpNodeId from
Expand Down Expand Up @@ -199,7 +199,7 @@ class DLL_PUBLIC OpGraph {
* index as argument so should not be used in performance
* critical section of the code.
*/
DLL_PUBLIC OpNode& Node(const std::string& name);
DLL_PUBLIC OpNode& Node(std::string_view name);

/**
* @brief Returns the graph node with the given index in the graph.
Expand Down

0 comments on commit cadf3bc

Please sign in to comment.