Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DML EP Use ORT node names in DML graphs/ops #14461

Merged
merged 6 commits into from
Feb 2, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ namespace DmlGraphFusionHelper
{
for (size_t i = 0; i < graphDesc.nodes.size(); ++i)
{
dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{graphDesc.nodes[i].op.Get()};
auto& nodeInfo = graphDesc.nodes[i];
dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{nodeInfo.op.Get(), nodeInfo.name.data()};
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ namespace Dml::GraphDescBuilder

NodeInfo nodeInfo = {};
nodeInfo.op = std::move(op);
nodeInfo.name = node.Name();
graphNodes.push_back(std::move(nodeInfo));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace Dml
struct NodeInfo
{
Microsoft::WRL::ComPtr<IDMLOperator> op;
std::string name;
};

struct GraphDesc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,85 @@ namespace Windows::AI::MachineLearning::Adapter
m_abiExecutionObject.CopyTo(executionInterface);
}

uint32_t STDMETHODCALLTYPE OpKernelInfoWrapper::GetUtf8NameSizeInBytes() const noexcept
{
// Include null terminator.
const auto& name = m_impl->node().Name();
return name.empty() ? 0 : name.size() + 1;
jstoecker marked this conversation as resolved.
Show resolved Hide resolved
}

HRESULT STDMETHODCALLTYPE OpKernelInfoWrapper::GetUtf8Name(uint32_t bufferSizeInBytes, char* outputName) const noexcept
{
if (bufferSizeInBytes == 0)
{
return E_INVALIDARG;
}

// Copy as many characters as possible, leaving room for the null terminator.
const auto& nodeName = m_impl->node().Name();
size_t charsCopied = nodeName.copy(outputName, bufferSizeInBytes - 1);

// Write the null terminator.
assert(charsCopied >= 0 && charsCopied < bufferSizeInBytes);
outputName[charsCopied] = '\0';

return S_OK;
}

uint32_t STDMETHODCALLTYPE OpKernelInfoWrapper::GetWideNameSizeInBytes() const noexcept
{
const auto& name = m_impl->node().Name();
if (name.empty())
{
return 0;
}

int requiredSize = MultiByteToWideChar(CP_UTF8, 0, name.data(), name.size(), nullptr, 0);
assert(requiredSize > 0);

// Include null terminator.
return static_cast<uint32_t>(requiredSize) + 1;
}

HRESULT STDMETHODCALLTYPE OpKernelInfoWrapper::GetWideName(uint32_t bufferSizeInBytes, wchar_t* outputName) const noexcept
{
// Buffer needs to be large enough to at least hold a null terminator.
if (bufferSizeInBytes < sizeof(wchar_t))
{
return E_INVALIDARG;
}

const auto& nodeName = m_impl->node().Name();
if (nodeName.empty())
{
outputName[0] = L'\0';
return S_OK;
}

uint32_t bufferSizeInChars = bufferSizeInBytes / sizeof(wchar_t);
int charsCopiedIfSucceeded = MultiByteToWideChar(CP_UTF8, 0, nodeName.data(), nodeName.size(), outputName, bufferSizeInChars);

auto lastError = GetLastError();
jstoecker marked this conversation as resolved.
Show resolved Hide resolved
if (lastError == ERROR_INSUFFICIENT_BUFFER)
{
// Buffer was too small. Truncate and overwrite last char with null terminator.
outputName[bufferSizeInChars - 1] = '\0';
return S_OK;
}

if (charsCopiedIfSucceeded == 0)
jstoecker marked this conversation as resolved.
Show resolved Hide resolved
{
assert(lastError == ERROR_INVALID_PARAMETER || lastError == ERROR_NO_UNICODE_TRANSLATION);
return E_INVALIDARG;
}

// All characters copied successfully. Write null terminator at the end of copied chars.
assert(lastError == 0);
outputName[charsCopiedIfSucceeded] = '\0';

return S_OK;
}

template <class NodeInfoImpl_t, class Base1_t, class Base2_t>
uint32_t STDMETHODCALLTYPE OpNodeInfoWrapper<NodeInfoImpl_t, Base1_t, Base2_t>::GetInputCount() const noexcept
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ class OnnxTensorWrapper : public WRL::Base<IMLOperatorTensor>, public Closable
class OpKernelInfoWrapper : public OpNodeInfoWrapper<
onnxruntime::ProtoHelperNodeContext,
WRL::Base<
Microsoft::WRL::ChainInterfaces<IMLOperatorKernelCreationContextPrivate, IMLOperatorKernelCreationContext>,
Microsoft::WRL::ChainInterfaces<IMLOperatorKernelCreationContextPrivate1, IMLOperatorKernelCreationContextPrivate, IMLOperatorKernelCreationContext>,
IMLOperatorTensorShapeDescription, IMLOperatorAttributes1>,
onnxruntime::null_type>
{
Expand Down Expand Up @@ -374,6 +374,12 @@ class OpKernelInfoWrapper : public OpNodeInfoWrapper<
return E_NOTIMPL;
}

uint32_t STDMETHODCALLTYPE GetUtf8NameSizeInBytes() const noexcept override;
HRESULT STDMETHODCALLTYPE GetUtf8Name(uint32_t bufferSizeInBytes, char* name) const noexcept override;

uint32_t STDMETHODCALLTYPE GetWideNameSizeInBytes() const noexcept override;
HRESULT STDMETHODCALLTYPE GetWideName(uint32_t bufferSizeInBytes, wchar_t* name) const noexcept override;

private:
// For shape info, in addition to the info
const EdgeShapes* m_inferredOutputShapes = nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ namespace Dml
{
DML_EXECUTION_FLAGS executionFlags = GetExecutionFlags();
ORT_THROW_IF_FAILED(m_dmlDevice->CompileOperator(dmlOperator.Get(), executionFlags, IID_PPV_ARGS(&m_compiledOperator)));

// Static buffer (might truncate name) to avoid excessive dynamic allocation only for debugging purposes.
wchar_t nodeName[512];
ORT_THROW_IF_FAILED(kernelInfo.GetInterfacePrivate()->GetWideName(sizeof(nodeName), nodeName));
ORT_THROW_IF_FAILED(m_compiledOperator->SetName(nodeName));

UINT64 persistentResourceSize = m_compiledOperator->GetBindingProperties().PersistentResourceSize;
if (persistentResourceSize > 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,11 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes
return m_impl;
}

Microsoft::WRL::ComPtr<IMLOperatorKernelCreationContextPrivate1> GetInterfacePrivate() const noexcept
{
return m_implPrivate;
}

Microsoft::WRL::ComPtr<IUnknown> GetExecutionInterface() const noexcept
{
Microsoft::WRL::ComPtr<IUnknown> ret;
Expand Down Expand Up @@ -556,7 +561,7 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes

private:
Microsoft::WRL::ComPtr<IMLOperatorKernelCreationContext> m_impl;
Microsoft::WRL::ComPtr<IMLOperatorKernelCreationContextPrivate> m_implPrivate;
Microsoft::WRL::ComPtr<IMLOperatorKernelCreationContextPrivate1> m_implPrivate;
};

class MLShapeInferenceContext : public MLOperatorAttributes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,30 @@ IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContex
) const noexcept PURE;
};

interface __declspec(uuid("1d2e1226-a918-4236-8775-175cf1f52c9a"))
IMLOperatorKernelCreationContextPrivate1 : public IMLOperatorKernelCreationContextPrivate
jstoecker marked this conversation as resolved.
Show resolved Hide resolved
jstoecker marked this conversation as resolved.
Show resolved Hide resolved
{
//! Gets the minimum size of a char buffer to store the node name (including null terminator).
//! Returns 0 if the node has no name.
STDMETHOD_(uint32_t, GetUtf8NameSizeInBytes)() const noexcept PURE;

jstoecker marked this conversation as resolved.
Show resolved Hide resolved
//! Writes the node name and null terminator into a char buffer.
STDMETHOD(GetUtf8Name)(
uint32_t bufferSizeInBytes,
_Out_writes_(bufferSizeInBytes) char* name
) const noexcept PURE;

//! Gets the minimum size of a wchar buffer to store the node name (including null terminator).
//! Returns 0 if the node has no name.
STDMETHOD_(uint32_t, GetWideNameSizeInBytes)() const noexcept PURE;

//! Writes the node name and null terminator into a wchar buffer.
STDMETHOD(GetWideName)(
uint32_t bufferSizeInBytes,
_Out_writes_(bufferSizeInBytes) wchar_t* name
) const noexcept PURE;
};

//! \interface IMLOperatorAttributes1
//! \brief Represents the values of an operator's attributes, as determined by a model using the operator.
//! This interface is called by implementations of custom operator kernels, and by implementations
Expand Down