Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 44 additions & 13 deletions sycl/include/sycl/ext/oneapi/experimental/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,18 @@ class depends_on : public ::sycl::detail::PropertyWithData<
} // namespace node
} // namespace property

/// Graph in the modifiable state.
template <graph_state State = graph_state::modifiable>
class __SYCL_EXPORT command_graph {
template <graph_state State> class command_graph;

namespace detail {
// Templateless modifiable command-graph base class.
class __SYCL_EXPORT modifiable_command_graph {
public:
/// Constructor.
/// @param SyclContext Context to use for graph.
/// @param SyclDevice Device all nodes will be associated with.
/// @param PropList Optional list of properties to pass.
command_graph(const context &SyclContext, const device &SyclDevice,
const property_list &PropList = {});
modifiable_command_graph(const context &SyclContext, const device &SyclDevice,
const property_list &PropList = {});

/// Add an empty node to the graph.
/// @param PropList Property list used to pass [0..n] predecessor nodes.
Expand Down Expand Up @@ -166,10 +168,11 @@ class __SYCL_EXPORT command_graph {
/// executing.
bool end_recording(const std::vector<queue> &RecordingQueues);

private:
protected:
/// Constructor used internally by the runtime.
/// @param Impl Detail implementation class to construct object with.
command_graph(const std::shared_ptr<detail::graph_impl> &Impl) : impl(Impl) {}
modifiable_command_graph(const std::shared_ptr<detail::graph_impl> &Impl)
: impl(Impl) {}

/// Template-less implementation of add() for CGF nodes.
/// @param CGF Command-group function to add.
Expand All @@ -192,21 +195,22 @@ class __SYCL_EXPORT command_graph {
std::shared_ptr<detail::graph_impl> impl;
};

template <> class __SYCL_EXPORT command_graph<graph_state::executable> {
// Templateless executable command-graph base class.
class __SYCL_EXPORT executable_command_graph {
public:
/// An executable command-graph is not user constructable.
command_graph() = delete;
executable_command_graph() = delete;

/// Update the inputs & output of the graph.
/// @param Graph Graph to use the inputs and outputs of.
void update(const command_graph<graph_state::modifiable> &Graph);

private:
protected:
/// Constructor used by internal runtime.
/// @param Graph Detail implementation class to construct with.
/// @param Ctx Context to use for graph.
command_graph(const std::shared_ptr<detail::graph_impl> &Graph,
const sycl::context &Ctx);
executable_command_graph(const std::shared_ptr<detail::graph_impl> &Graph,
const sycl::context &Ctx);

template <class Obj>
friend decltype(Obj::impl)
Expand All @@ -218,7 +222,34 @@ template <> class __SYCL_EXPORT command_graph<graph_state::executable> {
int MTag;
std::shared_ptr<detail::exec_graph_impl> impl;

friend class command_graph<graph_state::modifiable>;
friend class modifiable_command_graph;
};
} // namespace detail

/// Graph in the modifiable state.
template <graph_state State = graph_state::modifiable>
class command_graph : public detail::modifiable_command_graph {
public:
/// Constructor.
/// @param SyclContext Context to use for graph.
/// @param SyclDevice Device all nodes will be associated with.
/// @param PropList Optional list of properties to pass.
command_graph(const context &SyclContext, const device &SyclDevice,
const property_list &PropList = {})
: modifiable_command_graph(SyclContext, SyclDevice, PropList) {}

private:
/// Constructor used internally by the runtime.
/// @param Impl Detail implementation class to construct object with.
command_graph(const std::shared_ptr<detail::graph_impl> &Impl)
: modifiable_command_graph(Impl) {}
};

template <>
class command_graph<graph_state::executable>
: public detail::executable_command_graph {
private:
using detail::executable_command_graph::executable_command_graph;
};

/// Additional CTAD deduction guide.
Expand Down
45 changes: 15 additions & 30 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,17 +275,13 @@ sycl::event exec_graph_impl::enqueue(
sycl::detail::createSyclObjFromImpl<sycl::event>(NewEvent);
return QueueEvent;
}
} // namespace detail

template <>
command_graph<graph_state::modifiable>::command_graph(
modifiable_command_graph::modifiable_command_graph(
const sycl::context &SyclContext, const sycl::device &SyclDevice,
const sycl::property_list &)
: impl(std::make_shared<detail::graph_impl>(SyclContext, SyclDevice)) {}

template <>
node command_graph<graph_state::modifiable>::addImpl(
const std::vector<node> &Deps) {
node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
for (auto &D : Deps) {
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
Expand All @@ -295,9 +291,8 @@ node command_graph<graph_state::modifiable>::addImpl(
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
}

template <>
node command_graph<graph_state::modifiable>::addImpl(
std::function<void(handler &)> CGF, const std::vector<node> &Deps) {
node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
const std::vector<node> &Deps) {
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
for (auto &D : Deps) {
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
Expand All @@ -308,8 +303,7 @@ node command_graph<graph_state::modifiable>::addImpl(
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
}

template <>
void command_graph<graph_state::modifiable>::make_edge(node &Src, node &Dest) {
void modifiable_command_graph::make_edge(node &Src, node &Dest) {
std::shared_ptr<detail::node_impl> SenderImpl =
sycl::detail::getSyclObjImpl(Src);
std::shared_ptr<detail::node_impl> ReceiverImpl =
Expand All @@ -320,17 +314,13 @@ void command_graph<graph_state::modifiable>::make_edge(node &Src, node &Dest) {
impl->removeRoot(ReceiverImpl); // remove receiver from root node list
}

template <>
command_graph<graph_state::executable>
command_graph<graph_state::modifiable>::finalize(
const sycl::property_list &) const {
modifiable_command_graph::finalize(const sycl::property_list &) const {
return command_graph<graph_state::executable>{this->impl,
this->impl->getContext()};
}

template <>
bool command_graph<graph_state::modifiable>::begin_recording(
queue &RecordingQueue) {
bool modifiable_command_graph::begin_recording(queue &RecordingQueue) {
auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue);
if (QueueImpl->getCommandGraph() == nullptr) {
QueueImpl->setCommandGraph(impl);
Expand All @@ -347,8 +337,7 @@ bool command_graph<graph_state::modifiable>::begin_recording(
return false;
}

template <>
bool command_graph<graph_state::modifiable>::begin_recording(
bool modifiable_command_graph::begin_recording(
const std::vector<queue> &RecordingQueues) {
bool QueueStateChanged = false;
for (queue Queue : RecordingQueues) {
Expand All @@ -357,13 +346,9 @@ bool command_graph<graph_state::modifiable>::begin_recording(
return QueueStateChanged;
}

template <> bool command_graph<graph_state::modifiable>::end_recording() {
return impl->clearQueues();
}
bool modifiable_command_graph::end_recording() { return impl->clearQueues(); }

template <>
bool command_graph<graph_state::modifiable>::end_recording(
queue &RecordingQueue) {
bool modifiable_command_graph::end_recording(queue &RecordingQueue) {
auto QueueImpl = sycl::detail::getSyclObjImpl(RecordingQueue);
if (QueueImpl->getCommandGraph() == impl) {
QueueImpl->setCommandGraph(nullptr);
Expand All @@ -380,8 +365,7 @@ bool command_graph<graph_state::modifiable>::end_recording(
return false;
}

template <>
bool command_graph<graph_state::modifiable>::end_recording(
bool modifiable_command_graph::end_recording(
const std::vector<queue> &RecordingQueues) {
bool QueueStateChanged = false;
for (queue Queue : RecordingQueues) {
Expand All @@ -390,25 +374,26 @@ bool command_graph<graph_state::modifiable>::end_recording(
return QueueStateChanged;
}

command_graph<graph_state::executable>::command_graph(
executable_command_graph::executable_command_graph(
const std::shared_ptr<detail::graph_impl> &Graph, const sycl::context &Ctx)
: MTag(rand()),
impl(std::make_shared<detail::exec_graph_impl>(Ctx, Graph)) {
finalizeImpl(); // Create backend representation for executable graph
}

void command_graph<graph_state::executable>::finalizeImpl() {
void executable_command_graph::finalizeImpl() {
// Create PI command-buffers for each device in the finalized context
impl->schedule();
}

void command_graph<graph_state::executable>::update(
void executable_command_graph::update(
const command_graph<graph_state::modifiable> &Graph) {
(void)Graph;
throw sycl::exception(sycl::make_error_code(errc::invalid),
"Method not yet implemented");
}

} // namespace detail
} // namespace experimental
} // namespace oneapi
} // namespace ext
Expand Down
30 changes: 15 additions & 15 deletions sycl/test/abi/sycl_symbols_linux.dump
Original file line number Diff line number Diff line change
Expand Up @@ -3661,20 +3661,20 @@ _ZN4sycl3_V13ext6oneapi10level_zero10make_queueERKNS0_7contextERKNS0_6deviceEmbb
_ZN4sycl3_V13ext6oneapi10level_zero11make_deviceERKNS0_8platformEm
_ZN4sycl3_V13ext6oneapi10level_zero12make_contextERKSt6vectorINS0_6deviceESaIS5_EEmb
_ZN4sycl3_V13ext6oneapi10level_zero13make_platformEm
_ZN4sycl3_V13ext6oneapi12experimental13command_graphILNS3_11graph_stateE0EE13end_recordingERKSt6vectorINS0_5queueESaIS8_EE
_ZN4sycl3_V13ext6oneapi12experimental13command_graphILNS3_11graph_stateE0EE13end_recordingERNS0_5queueE
_ZN4sycl3_V13ext6oneapi12experimental13command_graphILNS3_11graph_stateE0EE13end_recordingEv
_ZN4sycl3_V13ext6oneapi12experimental13command_graphILNS3_11graph_stateE0EE15begin_recordingERKSt6vectorINS0_5queueESaIS8_EE
_ZN4sycl3_V13ext6oneapi12experimental13command_graphILNS3_11graph_stateE0EE15begin_recordingERNS0_5queueE
_ZN4sycl3_V13ext6oneapi12experimental13command_graphILNS3_11graph_stateE0EE7addImplERKSt6vectorINS3_4nodeESaIS8_EE
_ZN4sycl3_V13ext6oneapi12experimental13command_graphILNS3_11graph_stateE0EE7addImplESt8functionIFvRNS0_7handlerEEERKSt6vectorINS3_4nodeESaISD_EE
_ZN4sycl3_V13ext6oneapi12experimental13command_graphILNS3_11graph_stateE0EE9make_edgeERNS3_4nodeES8_
_ZN4sycl3_V13ext6oneapi12experimental13command_graphILNS3_11graph_stateE0EEC1ERKNS0_7contextERKNS0_6deviceERKNS0_13property_listE
_ZN4sycl3_V13ext6oneapi12experimental13command_graphILNS3_11graph_stateE0EEC2ERKNS0_7contextERKNS0_6deviceERKNS0_13property_listE
_ZN4sycl3_V13ext6oneapi12experimental13command_graphILNS3_11graph_stateE1EE12finalizeImplEv
_ZN4sycl3_V13ext6oneapi12experimental13command_graphILNS3_11graph_stateE1EE6updateERKNS4_ILS5_0EEE
_ZN4sycl3_V13ext6oneapi12experimental13command_graphILNS3_11graph_stateE1EEC1ERKSt10shared_ptrINS3_6detail10graph_implEERKNS0_7contextE
_ZN4sycl3_V13ext6oneapi12experimental13command_graphILNS3_11graph_stateE1EEC2ERKSt10shared_ptrINS3_6detail10graph_implEERKNS0_7contextE
_ZN4sycl3_V13ext6oneapi12experimental6detail24executable_command_graph12finalizeImplEv
_ZN4sycl3_V13ext6oneapi12experimental6detail24executable_command_graph6updateERKNS3_13command_graphILNS3_11graph_stateE0EEE
_ZN4sycl3_V13ext6oneapi12experimental6detail24executable_command_graphC1ERKSt10shared_ptrINS4_10graph_implEERKNS0_7contextE
_ZN4sycl3_V13ext6oneapi12experimental6detail24executable_command_graphC2ERKSt10shared_ptrINS4_10graph_implEERKNS0_7contextE
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph13end_recordingERKSt6vectorINS0_5queueESaIS7_EE
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph13end_recordingERNS0_5queueE
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph13end_recordingEv
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph15begin_recordingERKSt6vectorINS0_5queueESaIS7_EE
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph15begin_recordingERNS0_5queueE
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph7addImplERKSt6vectorINS3_4nodeESaIS7_EE
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph7addImplESt8functionIFvRNS0_7handlerEEERKSt6vectorINS3_4nodeESaISC_EE
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph9make_edgeERNS3_4nodeES7_
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graphC1ERKNS0_7contextERKNS0_6deviceERKNS0_13property_listE
_ZN4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graphC2ERKNS0_7contextERKNS0_6deviceERKNS0_13property_listE
_ZN4sycl3_V13ext6oneapi15filter_selectorC1ERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
_ZN4sycl3_V13ext6oneapi15filter_selectorC2ERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE
_ZN4sycl3_V13ext8codeplay12experimental14fusion_wrapper12start_fusionEv
Expand Down Expand Up @@ -4098,7 +4098,7 @@ _ZNK4sycl3_V115interop_handler12GetNativeMemEPNS0_6detail16AccessorImplHostE
_ZNK4sycl3_V115interop_handler14GetNativeQueueERi
_ZNK4sycl3_V116default_selectorclERKNS0_6deviceE
_ZNK4sycl3_V120accelerator_selectorclERKNS0_6deviceE
_ZNK4sycl3_V13ext6oneapi12experimental13command_graphILNS3_11graph_stateE0EE8finalizeERKNS0_13property_listE
_ZNK4sycl3_V13ext6oneapi12experimental6detail24modifiable_command_graph8finalizeERKNS0_13property_listE
_ZNK4sycl3_V13ext6oneapi15filter_selector13select_deviceEv
_ZNK4sycl3_V13ext6oneapi15filter_selector5resetEv
_ZNK4sycl3_V13ext6oneapi15filter_selectorclERKNS0_6deviceE
Expand Down
Loading