Skip to content

Commit

Permalink
fix up lambda capture changes
Browse files Browse the repository at this point in the history
  • Loading branch information
edgchen1 committed Sep 23, 2024
1 parent 197957c commit b7ff4e0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
9 changes: 7 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,13 @@ Status EmbedLayerNorm<T>::Compute(OpKernelContext* context) const {
std::atomic_bool failed{false};

int n = batch_size * sequence_length;

// Put epsilon into local variable here to avoid the need to capture 'this' in the TryBatchParallelFor() lambda.
// Using the copy capture default (=) to implicitly capture 'this' is deprecated.
const float epsilon_value = epsilon();

concurrency::ThreadPool::TryBatchParallelFor(
context->GetOperatorThreadPool(), n, [=, &failed, this](ptrdiff_t index) {
context->GetOperatorThreadPool(), n, [=, &failed](ptrdiff_t index) {
int word_col_index = input_ids_data[index];
if (word_col_index < 0 || word_col_index >= word_embedding_length) {
failed.store(true, std::memory_order_release);
Expand Down Expand Up @@ -136,7 +141,7 @@ Status EmbedLayerNorm<T>::Compute(OpKernelContext* context) const {
y[i] = a;
sum += a * a;
}
T e = sqrt(sum / hidden_size + static_cast<T>(epsilon()));
T e = sqrt(sum / hidden_size + static_cast<T>(epsilon_value));
for (int i = 0; i < hidden_size; i++) {
y[i] = y[i] / e * gamma_data[i] + beta_data[i];
}
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2770,7 +2770,8 @@ common::Status InferenceSession::RunAsync(const RunOptions* run_options,
if (!tp || concurrency::ThreadPool::DegreeOfParallelism(tp) < 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "intra op thread pool must have at least one thread for RunAsync");
}
std::function<void()> run_fn = [=, this]() {
std::function<void()> run_fn = [run_options, feed_names, feeds, fetch_names, fetches, num_fetches,
callback, user_data, this]() {
Status status = Status::OK();
ORT_TRY {
if (run_options) {
Expand Down

0 comments on commit b7ff4e0

Please sign in to comment.