onnxruntime/contrib_ops/cpu/range.cc (4 changes: 3 additions & 1 deletion)

@@ -48,7 +48,9 @@ static Status ComputeRange(OpKernelContext* ctx) {
 }
 
 Status Range::Compute(OpKernelContext* ctx) const {
-  auto data_type = ctx->Input<Tensor>(0)->DataType();
+  auto input_tensor = ctx->Input<Tensor>(0);
+  if (input_tensor == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
+  auto data_type = input_tensor->DataType();
   if (data_type == DataTypeImpl::GetType<int32_t>()) {
     return ComputeRange<int32_t>(ctx);
   }
onnxruntime/contrib_ops/cpu/string_normalizer.cc (1 change: 1 addition & 0 deletions)

@@ -207,6 +207,7 @@ Status StringNormalizer::Compute(OpKernelContext* ctx) const {
   using namespace string_normalizer;
 
   auto X = ctx->Input<Tensor>(0);
+  if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
   auto& input_dims = X->Shape().GetDims();
 
   size_t N = 0;
onnxruntime/contrib_ops/cpu/tokenizer.cc (1 change: 1 addition & 0 deletions)

@@ -450,6 +450,7 @@ Status Tokenizer::SeparatorTokenize(OpKernelContext* ctx,
 Status Tokenizer::Compute(OpKernelContext* ctx) const {
   // Get input buffer ptr
   auto X = ctx->Input<Tensor>(0);
+  if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
   if (X->DataType() != DataTypeImpl::GetType<std::string>()) {
     return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                   "tensor(string) expected as input");
onnxruntime/core/providers/cpu/tensor/eye_like.cc (9 changes: 4 additions & 5 deletions)

@@ -33,19 +33,18 @@ Status EyeLike::Compute(OpKernelContext* context) const {
   auto output_tensor_dtype = has_dtype_ ? static_cast<onnx::TensorProto::DataType>(dtype_) : utils::GetTensorProtoType(*T1);
   switch (output_tensor_dtype) {
     case onnx::TensorProto_DataType_FLOAT:
-      return ComputeImpl<float>(context);
+      return ComputeImpl<float>(context, T1);
     case onnx::TensorProto_DataType_INT64:
-      return ComputeImpl<int64_t>(context);
+      return ComputeImpl<int64_t>(context, T1);
     case onnx::TensorProto_DataType_UINT64:
-      return ComputeImpl<uint64_t>(context);
+      return ComputeImpl<uint64_t>(context, T1);
     default:
       ONNXRUNTIME_THROW("Unsupported 'dtype' value: ", output_tensor_dtype);
   }
 }
 
 template <typename T>
-Status EyeLike::ComputeImpl(OpKernelContext* context) const {
-  const Tensor* T1 = context->Input<Tensor>(0);
+Status EyeLike::ComputeImpl(OpKernelContext* context, const Tensor* T1) const {
   const std::vector<int64_t>& input_dims = T1->Shape().GetDims();
   if (input_dims.size() != 2) {
     return Status(ONNXRUNTIME, INVALID_ARGUMENT, "EyeLike : Input tensor dimension is not 2");
onnxruntime/core/providers/cpu/tensor/eye_like.h (2 changes: 1 addition & 1 deletion)

@@ -22,7 +22,7 @@ class EyeLike final : public OpKernel {
 
  private:
   template <typename T>
-  Status ComputeImpl(OpKernelContext* context) const;
+  Status ComputeImpl(OpKernelContext* context, const Tensor* T1) const;
 
   bool has_dtype_;
   int64_t dtype_;
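
Taken together, the changes apply one defensive pattern: fetch ctx->Input<Tensor>(0) exactly once, fail fast with a FAIL status when the pointer is null (an input-count mismatch), and hand the validated pointer to any helper instead of re-fetching it there. Below is a minimal sketch of that pattern, assuming the usual onnxruntime kernel headers are available; ExampleKernel and its methods are hypothetical and not part of this PR.

// Sketch only: ExampleKernel is hypothetical. It mirrors the guard added to
// Range, StringNormalizer, and Tokenizer plus the parameter change made to EyeLike.
Status ExampleKernel::Compute(OpKernelContext* ctx) const {
  const Tensor* X = ctx->Input<Tensor>(0);  // may be nullptr when the input is absent
  if (X == nullptr) {
    return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
  }
  // Pass the already-validated tensor down; the helper never re-fetches it.
  return ComputeImpl<float>(ctx, X);
}

template <typename T>
Status ExampleKernel::ComputeImpl(OpKernelContext* ctx, const Tensor* X) const {
  const auto& input_dims = X->Shape().GetDims();  // safe: X was checked in Compute
  (void)input_dims;  // placeholder for the kernel's real work
  return Status::OK();
}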