Skip to content

Commit

Permalink
fix: Validate input type to make sure a TypedArray is passed
Browse files Browse the repository at this point in the history
  • Loading branch information
mrousavy committed Jan 24, 2024
1 parent cbb76b9 commit 0481955
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 18 deletions.
17 changes: 11 additions & 6 deletions cpp/TensorHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ size_t TensorHelpers::getTFLTensorDataTypeSize(TfLiteType dataType) {
return sizeof(uint16_t);
default:
[[unlikely]];
throw std::runtime_error("Unsupported output data type! " + dataTypeToString(dataType));
throw std::runtime_error("TFLite: Unsupported output data type! " +
dataTypeToString(dataType));
}
}

Expand Down Expand Up @@ -152,7 +153,8 @@ TypedArrayBase TensorHelpers::createJSBufferForTensor(jsi::Runtime& runtime,
return TypedArray<TypedArrayKind::Uint32Array>(runtime, size);
default:
[[unlikely]];
throw std::runtime_error("Unsupported tensor data type! " + dataTypeToString(dataType));
throw std::runtime_error("TFLite: Unsupported tensor data type! " +
dataTypeToString(dataType));
}
}

Expand All @@ -164,7 +166,7 @@ void TensorHelpers::updateJSBufferFromTensor(jsi::Runtime& runtime, TypedArrayBa
void* data = TfLiteTensorData(tensor);
if (data == nullptr) {
[[unlikely]];
throw std::runtime_error("Failed to get data from tensor \"" + name + "\"!");
throw std::runtime_error("TFLite: Failed to get data from tensor \"" + name + "\"!");
}

// count of bytes, may be larger than count of numbers (e.g. for float32)
Expand Down Expand Up @@ -213,19 +215,21 @@ void TensorHelpers::updateJSBufferFromTensor(jsi::Runtime& runtime, TypedArrayBa
break;
default:
[[unlikely]];
throw jsi::JSError(runtime, "Unsupported output data type! " + dataTypeToString(dataType));
throw jsi::JSError(runtime,
"TFLite: Unsupported output data type! " + dataTypeToString(dataType));
}
}

void TensorHelpers::updateTensorFromJSBuffer(jsi::Runtime& runtime, TfLiteTensor* tensor,
TypedArrayBase& jsBuffer) {
#if DEBUG
// Validate data-type
TypedArrayKind kind = jsBuffer.getKind(runtime);
TfLiteType receivedType = getTFLDataTypeForTypedArrayKind(kind);
TfLiteType expectedType = TfLiteTensorType(tensor);
if (receivedType != expectedType) {
[[unlikely]];
throw std::runtime_error("Invalid input type! Model expected " +
throw std::runtime_error("TFLite: Invalid input type! Model expected " +
dataTypeToString(expectedType) + ", but received " +
dataTypeToString(receivedType) + "!");
}
Expand All @@ -235,11 +239,12 @@ void TensorHelpers::updateTensorFromJSBuffer(jsi::Runtime& runtime, TfLiteTensor
jsi::ArrayBuffer buffer = jsBuffer.getBuffer(runtime);

#if DEBUG
// Validate size
int inputBufferSize = buffer.size(runtime);
int tensorSize = getTensorTotalLength(tensor) * getTFLTensorDataTypeSize(tensor->type);
if (tensorSize != inputBufferSize) {
[[unlikely]];
throw std::runtime_error("Input Buffer size (" + std::to_string(inputBufferSize) +
throw std::runtime_error("TFLite: Input Buffer size (" + std::to_string(inputBufferSize) +
") does not match the Input Tensor's expected size (" +
std::to_string(tensorSize) +
")! Make sure to resize the input values accordingly.");
Expand Down
36 changes: 25 additions & 11 deletions cpp/TensorflowPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,9 @@ TensorflowPlugin::TensorflowPlugin(TfLiteInterpreter* interpreter, Buffer model,
TfLiteStatus status = TfLiteInterpreterAllocateTensors(_interpreter);
if (status != kTfLiteOk) {
[[unlikely]];
throw std::runtime_error("Failed to allocate memory for input/output tensors! Status: " +
tfLiteStatusToString(status));
throw std::runtime_error(
"TFLite: Failed to allocate memory for input/output tensors! Status: " +
tfLiteStatusToString(status));
}

log("Successfully created Tensorflow Plugin!");
Expand Down Expand Up @@ -205,23 +206,33 @@ void TensorflowPlugin::copyInputBuffers(jsi::Runtime& runtime, jsi::Object input
#if DEBUG
if (!inputValues.isArray(runtime)) {
[[unlikely]];
throw std::runtime_error(
"TFLite: Input Values must be an array, one item for each input tensor!");
throw jsi::JSError(runtime,
"TFLite: Input Values must be an array, one item for each input tensor!");
}
#endif

jsi::Array array = inputValues.asArray(runtime);
size_t count = array.size(runtime);
if (count != TfLiteInterpreterGetInputTensorCount(_interpreter)) {
[[unlikely]];
throw std::runtime_error(
"TFLite: Input Values have different size than there are input tensors!");
throw jsi::JSError(runtime,
"TFLite: Input Values have different size than there are input tensors!");
}

for (size_t i = 0; i < count; i++) {
TfLiteTensor* tensor = TfLiteInterpreterGetInputTensor(_interpreter, i);
jsi::Value value = array.getValueAtIndex(runtime, i);
TypedArrayBase inputBuffer = getTypedArray(runtime, value.asObject(runtime));
jsi::Object object = array.getValueAtIndex(runtime, i).asObject(runtime);

#if DEBUG
if (!isTypedArray(runtime, object)) {
[[unlikely]];
throw jsi::JSError(
runtime,
"TFLite: Input value is not a TypedArray! (Uint8Array, Uint16Array, Float32Array, etc.)");
}
#endif

TypedArrayBase inputBuffer = getTypedArray(runtime, std::move(object));
TensorHelpers::updateTensorFromJSBuffer(runtime, tensor, inputBuffer);
}
}
Expand All @@ -244,7 +255,8 @@ void TensorflowPlugin::run() {
TfLiteStatus status = TfLiteInterpreterInvoke(_interpreter);
if (status != kTfLiteOk) {
[[unlikely]];
throw std::runtime_error("Failed to run TFLite Model! Status: " + tfLiteStatusToString(status));
throw std::runtime_error("TFLite: Failed to run TFLite Model! Status: " +
tfLiteStatusToString(status));
}
}

Expand Down Expand Up @@ -296,7 +308,8 @@ jsi::Value TensorflowPlugin::get(jsi::Runtime& runtime, const jsi::PropNameID& p
TfLiteTensor* tensor = TfLiteInterpreterGetInputTensor(_interpreter, i);
if (tensor == nullptr) {
[[unlikely]];
throw jsi::JSError(runtime, "Failed to get input tensor " + std::to_string(i) + "!");
throw jsi::JSError(runtime,
"TFLite: Failed to get input tensor " + std::to_string(i) + "!");
}

jsi::Object object = TensorHelpers::tensorToJSObject(runtime, tensor);
Expand All @@ -310,7 +323,8 @@ jsi::Value TensorflowPlugin::get(jsi::Runtime& runtime, const jsi::PropNameID& p
const TfLiteTensor* tensor = TfLiteInterpreterGetOutputTensor(_interpreter, i);
if (tensor == nullptr) {
[[unlikely]];
throw jsi::JSError(runtime, "Failed to get output tensor " + std::to_string(i) + "!");
throw jsi::JSError(runtime,
"TFLite: Failed to get output tensor " + std::to_string(i) + "!");
}

jsi::Object object = TensorHelpers::tensorToJSObject(runtime, tensor);
Expand Down
2 changes: 1 addition & 1 deletion src/TensorflowLite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ export function loadTensorflowModel(
uri = source.url
} else {
throw new Error(
'Invalid source passed! Source should be either a React Native require(..) or a `{ url: string }` object!'
'TFLite: Invalid source passed! Source should be either a React Native require(..) or a `{ url: string }` object!'
)
}
return global.__loadTensorflowModel(uri, delegate)
Expand Down

0 comments on commit 0481955

Please sign in to comment.