Skip to content

Commit

Permalink
[DML EP] Fix GetInputTensor crash when accessing null tensor (#14811)
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola authored Feb 24, 2023
1 parent 3e4a471 commit 8f90066
Showing 1 changed file with 4 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1827,9 +1827,8 @@ namespace Windows::AI::MachineLearning::Adapter
ML_CHECK_BOOL(inputIndex < m_inputTensors.size());

auto opKernelContextWrapper = const_cast<OpKernelContextWrapper*>(this);
if (m_inputTensors[inputIndex][0]->GetInterface() == nullptr)
if (m_inputTensors[inputIndex][0] == nullptr)
{
assert(m_impl->InputType(gsl::narrow_cast<int>(inputIndex))->IsTensorType());
auto inputTensor = m_impl->Input<onnxruntime::Tensor>(gsl::narrow_cast<int>(inputIndex));
if (inputTensor != nullptr)
{
Expand Down Expand Up @@ -1867,9 +1866,8 @@ namespace Windows::AI::MachineLearning::Adapter
opKernelContextWrapper->m_inputTensors[inputIndex].resize(sequenceIndex+1);
}

if (m_inputTensors[inputIndex][sequenceIndex]->GetInterface() == nullptr)
if (m_inputTensors[inputIndex][sequenceIndex] == nullptr)
{
assert(m_impl->InputType(gsl::narrow_cast<int>(inputIndex))->IsTensorSequenceType());
auto inputTensorSeq = m_impl->Input<onnxruntime::TensorSeq>(gsl::narrow_cast<int>(inputIndex));
ML_CHECK_BOOL(inputTensorSeq != nullptr);

Expand Down Expand Up @@ -1918,7 +1916,7 @@ namespace Windows::AI::MachineLearning::Adapter
}

// Verify that the provided shape matches the shape determined using the kernel's shape inference function.
if (m_outputTensors[outputIndex][sequenceIndex]->GetInterface() == nullptr)
if (m_outputTensors[outputIndex][sequenceIndex] == nullptr)
{
auto outputTensorSeq = m_impl->Output<onnxruntime::TensorSeq>(gsl::narrow_cast<int>(outputIndex));
ML_CHECK_BOOL(outputTensorSeq != nullptr);
Expand Down Expand Up @@ -2026,7 +2024,7 @@ namespace Windows::AI::MachineLearning::Adapter
ML_CHECK_BOOL(outputIndex < m_outputTensors.size());

// Verify that the provided shape matches the shape determined using the kernel's shape inference function.
if (m_outputTensors[outputIndex][0]->GetInterface() == nullptr)
if (m_outputTensors[outputIndex][0] == nullptr)
{
if (m_outputShapes)
{
Expand Down

0 comments on commit 8f90066

Please sign in to comment.