Skip to content

Commit

Permalink
fix text_generator_main run error
Browse files Browse the repository at this point in the history
  • Loading branch information
Nigelwz committed Sep 24, 2024
1 parent c9973d2 commit 39a3b07
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions ai_edge_torch/generative/examples/cpp/text_generator_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ tflite::SignatureRunner* GetSignatureRunner(
std::map<std::string, std::vector<float>>& kv_cache) {
tflite::SignatureRunner* runner =
interpreter->GetSignatureRunner(signature_name.c_str());
int64_t flag = 0;
flag |= kTfLiteCustomAllocationFlagsNone;
flag |= kTfLiteCustomAllocationFlagsSkipAlignCheck;

for (auto& [name, cache] : kv_cache) {
TfLiteCustomAllocation allocation = {
.data = static_cast<void*>(cache.data()),
Expand All @@ -162,9 +166,9 @@ tflite::SignatureRunner* GetSignatureRunner(
// delegates support this in-place update. For those cases, we need to do
// a ping-pong buffer and update the pointers between inference calls.
TFLITE_MINIMAL_CHECK(runner->SetCustomAllocationForInputTensor(
name.c_str(), allocation) == kTfLiteOk);
name.c_str(), allocation, flag) == kTfLiteOk);
TFLITE_MINIMAL_CHECK(runner->SetCustomAllocationForOutputTensor(
name.c_str(), allocation) == kTfLiteOk);
name.c_str(), allocation, flag) == kTfLiteOk);
}
TFLITE_MINIMAL_CHECK(runner->AllocateTensors() == kTfLiteOk);
return runner;
Expand Down

0 comments on commit 39a3b07

Please sign in to comment.