From 21e2024bb23bf253d4b480c217f6d651c3bb4e74 Mon Sep 17 00:00:00 2001 From: Justin Stoecker Date: Thu, 26 Jan 2023 14:47:16 -0800 Subject: [PATCH 1/6] pipe names to dml graphs --- .../src/DmlGraphFusionHelper.cpp | 3 +- .../src/GraphDescBuilder.cpp | 1 + .../src/GraphDescBuilder.h | 1 + .../src/MLOperatorAuthorImpl.cpp | 79 +++++++++++++++++++ .../src/MLOperatorAuthorImpl.h | 8 +- .../src/Operators/DmlOperator.cpp | 5 ++ .../MLOperatorAuthorHelper.h | 7 +- .../MLOperatorAuthorPrivate.h | 24 ++++++ 8 files changed, 125 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 9d0ba9dc7ea51..17aa197396ae0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -292,7 +292,8 @@ namespace DmlGraphFusionHelper { for (size_t i = 0; i < graphDesc.nodes.size(); ++i) { - dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{graphDesc.nodes[i].op.Get()}; + auto& nodeInfo = graphDesc.nodes[i]; + dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{nodeInfo.op.Get(), nodeInfo.name.data()}; dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 6b9230657bf24..636f46428ce99 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -361,6 +361,7 @@ namespace Dml::GraphDescBuilder NodeInfo nodeInfo = {}; nodeInfo.op = std::move(op); + nodeInfo.name = node.Name(); graphNodes.push_back(std::move(nodeInfo)); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index bf2ccae55ed90..5c04962e55557 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -28,6 +28,7 @@ namespace Dml struct NodeInfo { Microsoft::WRL::ComPtr op; + std::string name; }; struct GraphDesc diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 197c62283fba9..af8eba647c3d6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -982,6 +982,85 @@ namespace Windows::AI::MachineLearning::Adapter m_abiExecutionObject.CopyTo(executionInterface); } + uint32_t STDMETHODCALLTYPE OpKernelInfoWrapper::GetUtf8NameSizeInBytes() const noexcept + { + // Include null terminator. + const auto& name = m_impl->node().Name(); + return name.empty() ? 0 : name.size() + 1; + } + + HRESULT STDMETHODCALLTYPE OpKernelInfoWrapper::GetUtf8Name(uint32_t bufferSizeInBytes, char* outputName) const noexcept + { + if (bufferSizeInBytes == 0) + { + return E_INVALIDARG; + } + + // Copy as many characters as possible, leaving room for the null terminator. + const auto& nodeName = m_impl->node().Name(); + size_t charsCopied = nodeName.copy(outputName, bufferSizeInBytes - 1); + + // Write the null terminator. + assert(charsCopied >= 0 && charsCopied < bufferSizeInBytes); + outputName[charsCopied] = '\0'; + + return S_OK; + } + + uint32_t STDMETHODCALLTYPE OpKernelInfoWrapper::GetWideNameSizeInBytes() const noexcept + { + const auto& name = m_impl->node().Name(); + if (name.empty()) + { + return 0; + } + + int requiredSize = MultiByteToWideChar(CP_UTF8, 0, name.data(), name.size(), nullptr, 0); + assert(requiredSize > 0); + + // Include null terminator. + return static_cast(requiredSize) + 1; + } + + HRESULT STDMETHODCALLTYPE OpKernelInfoWrapper::GetWideName(uint32_t bufferSizeInBytes, wchar_t* outputName) const noexcept + { + // Buffer needs to be large enough to at least hold a null terminator. + if (bufferSizeInBytes < sizeof(wchar_t)) + { + return E_INVALIDARG; + } + + const auto& nodeName = m_impl->node().Name(); + if (nodeName.empty()) + { + outputName[0] = L'\0'; + return S_OK; + } + + uint32_t bufferSizeInChars = bufferSizeInBytes / sizeof(wchar_t); + int charsCopiedIfSucceeded = MultiByteToWideChar(CP_UTF8, 0, nodeName.data(), nodeName.size(), outputName, bufferSizeInChars); + + auto lastError = GetLastError(); + if (lastError == ERROR_INSUFFICIENT_BUFFER) + { + // Buffer was too small. Truncate and overwrite last char with null terminator. + outputName[bufferSizeInChars - 1] = '\0'; + return S_OK; + } + + if (charsCopiedIfSucceeded == 0) + { + assert(lastError == ERROR_INVALID_PARAMETER || lastError == ERROR_NO_UNICODE_TRANSLATION); + return E_INVALIDARG; + } + + // All characters copied successfully. Write null terminator at the end of copied chars. + assert(lastError == 0); + outputName[charsCopiedIfSucceeded] = '\0'; + + return S_OK; + } + template uint32_t STDMETHODCALLTYPE OpNodeInfoWrapper::GetInputCount() const noexcept { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index dd1b743587ab5..7ba3a3a7a6cba 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -331,7 +331,7 @@ class OnnxTensorWrapper : public WRL::Base, public Closable class OpKernelInfoWrapper : public OpNodeInfoWrapper< onnxruntime::ProtoHelperNodeContext, WRL::Base< - Microsoft::WRL::ChainInterfaces, + Microsoft::WRL::ChainInterfaces, IMLOperatorTensorShapeDescription, IMLOperatorAttributes1>, onnxruntime::null_type> { @@ -374,6 +374,12 @@ class OpKernelInfoWrapper : public OpNodeInfoWrapper< return E_NOTIMPL; } + uint32_t STDMETHODCALLTYPE GetUtf8NameSizeInBytes() const noexcept override; + HRESULT STDMETHODCALLTYPE GetUtf8Name(uint32_t bufferSizeInBytes, char* name) const noexcept override; + + uint32_t STDMETHODCALLTYPE GetWideNameSizeInBytes() const noexcept override; + HRESULT STDMETHODCALLTYPE GetWideName(uint32_t bufferSizeInBytes, wchar_t* name) const noexcept override; + private: // For shape info, in addition to the info const EdgeShapes* m_inferredOutputShapes = nullptr; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp index 3ae29629efbcd..911e043ff6462 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp @@ -89,6 +89,11 @@ namespace Dml { DML_EXECUTION_FLAGS executionFlags = GetExecutionFlags(); ORT_THROW_IF_FAILED(m_dmlDevice->CompileOperator(dmlOperator.Get(), executionFlags, IID_PPV_ARGS(&m_compiledOperator))); + + // Static buffer (might truncate name) to avoid excessive dynamic allocation only for debugging purposes. + wchar_t nodeName[512]; + ORT_THROW_IF_FAILED(kernelInfo.GetInterfacePrivate()->GetWideName(sizeof(nodeName), nodeName)); + ORT_THROW_IF_FAILED(m_compiledOperator->SetName(nodeName)); UINT64 persistentResourceSize = m_compiledOperator->GetBindingProperties().PersistentResourceSize; if (persistentResourceSize > 0) diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index d79b2fb4e7c2a..4f87871688323 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -494,6 +494,11 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes return m_impl; } + Microsoft::WRL::ComPtr GetInterfacePrivate() const noexcept + { + return m_implPrivate; + } + Microsoft::WRL::ComPtr GetExecutionInterface() const noexcept { Microsoft::WRL::ComPtr ret; @@ -556,7 +561,7 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes private: Microsoft::WRL::ComPtr m_impl; - Microsoft::WRL::ComPtr m_implPrivate; + Microsoft::WRL::ComPtr m_implPrivate; }; class MLShapeInferenceContext : public MLOperatorAttributes diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index afcfc587ecade..18a039b7a4c70 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -57,6 +57,30 @@ IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContex ) const noexcept PURE; }; +interface __declspec(uuid("1d2e1226-a918-4236-8775-175cf1f52c9a")) +IMLOperatorKernelCreationContextPrivate1 : public IMLOperatorKernelCreationContextPrivate +{ + //! Gets the minimum size of a char buffer to store the node name (including null terminator). + //! Returns 0 if the node has no name. + STDMETHOD_(uint32_t, GetUtf8NameSizeInBytes)() const noexcept PURE; + + //! Writes the node name and null terminator into a char buffer. + STDMETHOD(GetUtf8Name)( + uint32_t bufferSizeInBytes, + _Out_writes_(bufferSizeInBytes) char* name + ) const noexcept PURE; + + //! Gets the minimum size of a wchar buffer to store the node name (including null terminator). + //! Returns 0 if the node has no name. + STDMETHOD_(uint32_t, GetWideNameSizeInBytes)() const noexcept PURE; + + //! Writes the node name and null terminator into a wchar buffer. + STDMETHOD(GetWideName)( + uint32_t bufferSizeInBytes, + _Out_writes_(bufferSizeInBytes) wchar_t* name + ) const noexcept PURE; +}; + //! \interface IMLOperatorAttributes1 //! \brief Represents the values of an operator's attributes, as determined by a model using the operator. //! This interface is called by implementations of custom operator kernels, and by implementations From 0b3971c7c1fba52972dabd84a00343d00c2f333e Mon Sep 17 00:00:00 2001 From: Justin Stoecker Date: Mon, 30 Jan 2023 16:48:21 -0800 Subject: [PATCH 2/6] cr feedback --- .../src/MLOperatorAuthorImpl.cpp | 39 ++++++++++--------- .../src/MLOperatorAuthorImpl.h | 2 +- .../MLOperatorAuthorHelper.h | 6 +-- .../MLOperatorAuthorPrivate.h | 8 +--- 4 files changed, 27 insertions(+), 28 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index af8eba647c3d6..86d29f8106c79 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -985,8 +985,7 @@ namespace Windows::AI::MachineLearning::Adapter uint32_t STDMETHODCALLTYPE OpKernelInfoWrapper::GetUtf8NameSizeInBytes() const noexcept { // Include null terminator. - const auto& name = m_impl->node().Name(); - return name.empty() ? 0 : name.size() + 1; + return static_cast(m_impl->node().Name().size() + 1); } HRESULT STDMETHODCALLTYPE OpKernelInfoWrapper::GetUtf8Name(uint32_t bufferSizeInBytes, char* outputName) const noexcept @@ -1012,14 +1011,15 @@ namespace Windows::AI::MachineLearning::Adapter const auto& name = m_impl->node().Name(); if (name.empty()) { - return 0; + // Include null terminator. + return sizeof(wchar_t); } - int requiredSize = MultiByteToWideChar(CP_UTF8, 0, name.data(), name.size(), nullptr, 0); - assert(requiredSize > 0); + int requiredSizeInChars = MultiByteToWideChar(CP_UTF8, 0, name.data(), name.size(), nullptr, 0); + assert(requiredSizeInChars > 0); // Include null terminator. - return static_cast(requiredSize) + 1; + return static_cast((requiredSizeInChars + 1) * sizeof(wchar_t)); } HRESULT STDMETHODCALLTYPE OpKernelInfoWrapper::GetWideName(uint32_t bufferSizeInBytes, wchar_t* outputName) const noexcept @@ -1040,25 +1040,28 @@ namespace Windows::AI::MachineLearning::Adapter uint32_t bufferSizeInChars = bufferSizeInBytes / sizeof(wchar_t); int charsCopiedIfSucceeded = MultiByteToWideChar(CP_UTF8, 0, nodeName.data(), nodeName.size(), outputName, bufferSizeInChars); - auto lastError = GetLastError(); - if (lastError == ERROR_INSUFFICIENT_BUFFER) + if (charsCopiedIfSucceeded > 0) { - // Buffer was too small. Truncate and overwrite last char with null terminator. - outputName[bufferSizeInChars - 1] = '\0'; + // The return value is only > 0 if ALL characters copied successfully. + // Write null terminator at the end of copied chars, which may not be at the end of the buffer. + outputName[charsCopiedIfSucceeded] = L'\0'; return S_OK; } - if (charsCopiedIfSucceeded == 0) + // An error must have occurred in MultiByteToWideChar. + assert(charsCopiedIfSucceeded <= 0); + auto lastError = GetLastError(); + + if (lastError == ERROR_INSUFFICIENT_BUFFER) { - assert(lastError == ERROR_INVALID_PARAMETER || lastError == ERROR_NO_UNICODE_TRANSLATION); - return E_INVALIDARG; + // The buffer was too small, but MultiByteToWideChar will have copied as many chars as possible. + // Truncate and overwrite last char with null terminator. Don't treat this as an error. + outputName[bufferSizeInChars - 1] = L'\0'; + return S_OK; } - // All characters copied successfully. Write null terminator at the end of copied chars. - assert(lastError == 0); - outputName[charsCopiedIfSucceeded] = '\0'; - - return S_OK; + assert(lastError == ERROR_INVALID_PARAMETER || lastError == ERROR_NO_UNICODE_TRANSLATION); + return E_INVALIDARG; } template diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index 7ba3a3a7a6cba..a939855b44b39 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -331,7 +331,7 @@ class OnnxTensorWrapper : public WRL::Base, public Closable class OpKernelInfoWrapper : public OpNodeInfoWrapper< onnxruntime::ProtoHelperNodeContext, WRL::Base< - Microsoft::WRL::ChainInterfaces, + Microsoft::WRL::ChainInterfaces, IMLOperatorTensorShapeDescription, IMLOperatorAttributes1>, onnxruntime::null_type> { diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index 4f87871688323..6d8eae839d807 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -494,9 +494,9 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes return m_impl; } - Microsoft::WRL::ComPtr GetInterfacePrivate() const noexcept + IMLOperatorKernelCreationContextPrivate* GetInterfacePrivate() const noexcept { - return m_implPrivate; + return m_implPrivate.Get(); } Microsoft::WRL::ComPtr GetExecutionInterface() const noexcept @@ -561,7 +561,7 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes private: Microsoft::WRL::ComPtr m_impl; - Microsoft::WRL::ComPtr m_implPrivate; + Microsoft::WRL::ComPtr m_implPrivate; }; class MLShapeInferenceContext : public MLOperatorAttributes diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index 18a039b7a4c70..d47c4aed0ea98 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -55,13 +55,9 @@ IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContex STDMETHOD(SetDmlOperator)( _In_ const MLOperatorGraphDesc* operatorGraphDesc ) const noexcept PURE; -}; -interface __declspec(uuid("1d2e1226-a918-4236-8775-175cf1f52c9a")) -IMLOperatorKernelCreationContextPrivate1 : public IMLOperatorKernelCreationContextPrivate -{ //! Gets the minimum size of a char buffer to store the node name (including null terminator). - //! Returns 0 if the node has no name. + //! Returns 1 if the node has no name (calling GetUtf8Name will write a single null terminator). STDMETHOD_(uint32_t, GetUtf8NameSizeInBytes)() const noexcept PURE; //! Writes the node name and null terminator into a char buffer. @@ -71,7 +67,7 @@ IMLOperatorKernelCreationContextPrivate1 : public IMLOperatorKernelCreationConte ) const noexcept PURE; //! Gets the minimum size of a wchar buffer to store the node name (including null terminator). - //! Returns 0 if the node has no name. + //! Returns sizeof(wchar_t) if the node has no name (calling GetWideName will write a single null terminator). STDMETHOD_(uint32_t, GetWideNameSizeInBytes)() const noexcept PURE; //! Writes the node name and null terminator into a wchar buffer. From a1b20a6302ff06dbfa1c65cbe906053467e0a4d4 Mon Sep 17 00:00:00 2001 From: Justin Stoecker Date: Mon, 30 Jan 2023 17:30:55 -0800 Subject: [PATCH 3/6] refactor --- .../dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h | 4 +++- .../dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp | 2 +- .../dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h | 6 ++++-- .../dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h | 6 +++++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index a939855b44b39..c5f35eef776da 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -332,7 +332,7 @@ class OpKernelInfoWrapper : public OpNodeInfoWrapper< onnxruntime::ProtoHelperNodeContext, WRL::Base< Microsoft::WRL::ChainInterfaces, - IMLOperatorTensorShapeDescription, IMLOperatorAttributes1>, + IMLOperatorTensorShapeDescription, IMLOperatorAttributes1, IMLOperatorKernelCreationContextNodeWrapperPrivate>, onnxruntime::null_type> { public: @@ -374,6 +374,8 @@ class OpKernelInfoWrapper : public OpNodeInfoWrapper< return E_NOTIMPL; } + // IMLOperatorKernelCreationContextNonGraphNode methods. + uint32_t STDMETHODCALLTYPE GetUtf8NameSizeInBytes() const noexcept override; HRESULT STDMETHODCALLTYPE GetUtf8Name(uint32_t bufferSizeInBytes, char* name) const noexcept override; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp index 911e043ff6462..497207ab83c46 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp @@ -92,7 +92,7 @@ namespace Dml // Static buffer (might truncate name) to avoid excessive dynamic allocation only for debugging purposes. wchar_t nodeName[512]; - ORT_THROW_IF_FAILED(kernelInfo.GetInterfacePrivate()->GetWideName(sizeof(nodeName), nodeName)); + ORT_THROW_IF_FAILED(kernelInfo.GetNodeWrapperInterface()->GetWideName(sizeof(nodeName), nodeName)); ORT_THROW_IF_FAILED(m_compiledOperator->SetName(nodeName)); UINT64 persistentResourceSize = m_compiledOperator->GetBindingProperties().PersistentResourceSize; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index 6d8eae839d807..2e5de08f8744d 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -486,6 +486,7 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes MLOperatorKernelCreationContext(IMLOperatorKernelCreationContext* impl) : MLOperatorAttributes(impl), m_impl(impl) { m_impl.As(&m_implPrivate); + m_impl.As(&m_nodeWrapperImpl); } // For cases of interop where the caller needs to pass the unwrapped class across a boundary. @@ -494,9 +495,9 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes return m_impl; } - IMLOperatorKernelCreationContextPrivate* GetInterfacePrivate() const noexcept + IMLOperatorKernelCreationContextNodeWrapperPrivate* GetNodeWrapperInterface() const noexcept { - return m_implPrivate.Get(); + return m_nodeWrapperImpl.Get(); } Microsoft::WRL::ComPtr GetExecutionInterface() const noexcept @@ -562,6 +563,7 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes private: Microsoft::WRL::ComPtr m_impl; Microsoft::WRL::ComPtr m_implPrivate; + Microsoft::WRL::ComPtr m_nodeWrapperImpl; }; class MLShapeInferenceContext : public MLOperatorAttributes diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index d47c4aed0ea98..f48603a91b0c0 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -55,7 +55,11 @@ IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContex STDMETHOD(SetDmlOperator)( _In_ const MLOperatorGraphDesc* operatorGraphDesc ) const noexcept PURE; +}; +interface __declspec(uuid("1d2e1226-a918-4236-8775-175cf1f52c9a")) +IMLOperatorKernelCreationContextNodeWrapperPrivate : public IUnknown +{ //! Gets the minimum size of a char buffer to store the node name (including null terminator). //! Returns 1 if the node has no name (calling GetUtf8Name will write a single null terminator). STDMETHOD_(uint32_t, GetUtf8NameSizeInBytes)() const noexcept PURE; @@ -67,7 +71,7 @@ IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContex ) const noexcept PURE; //! Gets the minimum size of a wchar buffer to store the node name (including null terminator). - //! Returns sizeof(wchar_t) if the node has no name (calling GetWideName will write a single null terminator). + //! Returns sizeof(wchar_t) if the node has no name (calling GetWideName will write a null terminator). STDMETHOD_(uint32_t, GetWideNameSizeInBytes)() const noexcept PURE; //! Writes the node name and null terminator into a wchar buffer. From 8bbdf2a4d2657117e45db790c143d5ee5afef926 Mon Sep 17 00:00:00 2001 From: Justin Stoecker Date: Mon, 30 Jan 2023 17:52:10 -0800 Subject: [PATCH 4/6] asdf --- .../dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h | 4 ++-- .../dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index c5f35eef776da..6a3382cd2ac70 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -331,8 +331,8 @@ class OnnxTensorWrapper : public WRL::Base, public Closable class OpKernelInfoWrapper : public OpNodeInfoWrapper< onnxruntime::ProtoHelperNodeContext, WRL::Base< - Microsoft::WRL::ChainInterfaces, - IMLOperatorTensorShapeDescription, IMLOperatorAttributes1, IMLOperatorKernelCreationContextNodeWrapperPrivate>, + Microsoft::WRL::ChainInterfaces, + IMLOperatorTensorShapeDescription, IMLOperatorAttributes1>, onnxruntime::null_type> { public: diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index f48603a91b0c0..6dbe9f10bc3d4 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -58,7 +58,7 @@ IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContex }; interface __declspec(uuid("1d2e1226-a918-4236-8775-175cf1f52c9a")) -IMLOperatorKernelCreationContextNodeWrapperPrivate : public IUnknown +IMLOperatorKernelCreationContextNodeWrapperPrivate : public IMLOperatorKernelCreationContextPrivate { //! Gets the minimum size of a char buffer to store the node name (including null terminator). //! Returns 1 if the node has no name (calling GetUtf8Name will write a single null terminator). From 7cacd4997729e970532753dd93ca686a0519b6ba Mon Sep 17 00:00:00 2001 From: Justin Stoecker Date: Tue, 31 Jan 2023 13:29:49 -0800 Subject: [PATCH 5/6] cr feedback on naming --- .../dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp | 4 ++-- .../dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h | 6 +++--- .../dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 86d29f8106c79..7b3dd6a9b3216 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -982,7 +982,7 @@ namespace Windows::AI::MachineLearning::Adapter m_abiExecutionObject.CopyTo(executionInterface); } - uint32_t STDMETHODCALLTYPE OpKernelInfoWrapper::GetUtf8NameSizeInBytes() const noexcept + uint32_t STDMETHODCALLTYPE OpKernelInfoWrapper::GetUtf8NameBufferSizeInBytes() const noexcept { // Include null terminator. return static_cast(m_impl->node().Name().size() + 1); @@ -1006,7 +1006,7 @@ namespace Windows::AI::MachineLearning::Adapter return S_OK; } - uint32_t STDMETHODCALLTYPE OpKernelInfoWrapper::GetWideNameSizeInBytes() const noexcept + uint32_t STDMETHODCALLTYPE OpKernelInfoWrapper::GetWideNameBufferSizeInBytes() const noexcept { const auto& name = m_impl->node().Name(); if (name.empty()) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index 6a3382cd2ac70..ffb8bf1a38860 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -374,12 +374,12 @@ class OpKernelInfoWrapper : public OpNodeInfoWrapper< return E_NOTIMPL; } - // IMLOperatorKernelCreationContextNonGraphNode methods. + // IMLOperatorKernelCreationContextNodeWrapperPrivate methods. - uint32_t STDMETHODCALLTYPE GetUtf8NameSizeInBytes() const noexcept override; + uint32_t STDMETHODCALLTYPE GetUtf8NameBufferSizeInBytes() const noexcept override; HRESULT STDMETHODCALLTYPE GetUtf8Name(uint32_t bufferSizeInBytes, char* name) const noexcept override; - uint32_t STDMETHODCALLTYPE GetWideNameSizeInBytes() const noexcept override; + uint32_t STDMETHODCALLTYPE GetWideNameBufferSizeInBytes() const noexcept override; HRESULT STDMETHODCALLTYPE GetWideName(uint32_t bufferSizeInBytes, wchar_t* name) const noexcept override; private: diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index 6dbe9f10bc3d4..6d906c6056137 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -62,7 +62,7 @@ IMLOperatorKernelCreationContextNodeWrapperPrivate : public IMLOperatorKernelCre { //! Gets the minimum size of a char buffer to store the node name (including null terminator). //! Returns 1 if the node has no name (calling GetUtf8Name will write a single null terminator). - STDMETHOD_(uint32_t, GetUtf8NameSizeInBytes)() const noexcept PURE; + STDMETHOD_(uint32_t, GetUtf8NameBufferSizeInBytes)() const noexcept PURE; //! Writes the node name and null terminator into a char buffer. STDMETHOD(GetUtf8Name)( @@ -72,7 +72,7 @@ IMLOperatorKernelCreationContextNodeWrapperPrivate : public IMLOperatorKernelCre //! Gets the minimum size of a wchar buffer to store the node name (including null terminator). //! Returns sizeof(wchar_t) if the node has no name (calling GetWideName will write a null terminator). - STDMETHOD_(uint32_t, GetWideNameSizeInBytes)() const noexcept PURE; + STDMETHOD_(uint32_t, GetWideNameBufferSizeInBytes)() const noexcept PURE; //! Writes the node name and null terminator into a wchar buffer. STDMETHOD(GetWideName)( From 4a4bd1475b922624ce5854a922772d47b8b5d8ca Mon Sep 17 00:00:00 2001 From: Justin Stoecker Date: Tue, 31 Jan 2023 14:25:36 -0800 Subject: [PATCH 6/6] arg type warnings --- .../dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 7b3dd6a9b3216..af9755ac22031 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1015,7 +1015,7 @@ namespace Windows::AI::MachineLearning::Adapter return sizeof(wchar_t); } - int requiredSizeInChars = MultiByteToWideChar(CP_UTF8, 0, name.data(), name.size(), nullptr, 0); + int requiredSizeInChars = MultiByteToWideChar(CP_UTF8, 0, name.data(), static_cast(name.size()), nullptr, 0); assert(requiredSizeInChars > 0); // Include null terminator. @@ -1038,7 +1038,7 @@ namespace Windows::AI::MachineLearning::Adapter } uint32_t bufferSizeInChars = bufferSizeInBytes / sizeof(wchar_t); - int charsCopiedIfSucceeded = MultiByteToWideChar(CP_UTF8, 0, nodeName.data(), nodeName.size(), outputName, bufferSizeInChars); + int charsCopiedIfSucceeded = MultiByteToWideChar(CP_UTF8, 0, nodeName.data(), static_cast(nodeName.size()), outputName, bufferSizeInChars); if (charsCopiedIfSucceeded > 0) {