Skip to content

Commit e0b3004

Browse files
authored
Unify inference chains (#1399)
Used for tensor op registry and model inference
1 parent 28852a2 commit e0b3004

File tree

3 files changed

+30
-18
lines changed

3 files changed

+30
-18
lines changed

src/backends/onnx.js

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,32 @@ export async function createInferenceSession(buffer_or_path, session_options, se
160160
return session;
161161
}
162162

/**
 * Currently, Transformers.js doesn't support simultaneous execution of sessions in WASM/WebGPU.
 * For this reason, we need to chain the inference calls (otherwise we get "Error: Session already started").
 * @type {Promise<any>}
 */
let webInferenceChain = Promise.resolve();

// Serialization is only required when running in a browser context
// (main thread or web worker); elsewhere sessions may run immediately.
const IS_WEB_ENV = apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV;

/**
 * Run an inference session.
 *
 * In web environments, the call is queued on `webInferenceChain` so that at most
 * one `session.run` is in flight at a time. A rejected run is propagated to its
 * caller but deliberately NOT kept on the chain, so a single failed inference
 * cannot poison every subsequent call.
 * @param {import('onnxruntime-common').InferenceSession} session The ONNX inference session.
 * @param {Record<string, import('onnxruntime-common').Tensor>} ortFeed The input tensors.
 * @returns {Promise<Record<string, import('onnxruntime-common').Tensor>>} The output tensors.
 */
export async function runInferenceSession(session, ortFeed) {
    const run = () => session.run(ortFeed);
    if (!IS_WEB_ENV) {
        return run();
    }
    // Start this run only after the previous one has settled.
    const result = webInferenceChain.then(run);
    // Keep the chain itself always-fulfilled: the caller receives any rejection
    // via `result`, while the swallowed copy stored on the chain prevents both
    // (a) the stale error re-rejecting every later call and
    // (b) an unhandled-rejection warning from the retained promise.
    webInferenceChain = result.catch(() => { });
    return result;
}
188+
163189
/**
164190
* Check if an object is an ONNX tensor.
165191
* @param {any} x The object to check

src/models.js

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ import {
4848
createInferenceSession,
4949
isONNXTensor,
5050
isONNXProxy,
51+
runInferenceSession,
5152
} from './backends/onnx.js';
5253
import {
5354
DATA_TYPES,
@@ -419,10 +420,6 @@ function validateInputs(session, inputs) {
419420
return checkedInputs;
420421
}
421422

422-
// Currently, Transformers.js doesn't support simultaneous execution of sessions in WASM/WebGPU.
423-
// For this reason, we need to chain the inference calls (otherwise we get "Error: Session already started").
424-
let webInferenceChain = Promise.resolve();
425-
426423
/**
427424
* Executes an InferenceSession using the specified inputs.
428425
* NOTE: `inputs` must contain at least the input names of the model.
@@ -439,10 +436,7 @@ async function sessionRun(session, inputs) {
439436
try {
440437
// pass the original ort tensor
441438
const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
442-
const run = () => session.run(ortFeed);
443-
const output = await ((apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV)
444-
? (webInferenceChain = webInferenceChain.then(run))
445-
: run());
439+
const output = await runInferenceSession(session, ortFeed);
446440
return replaceTensors(output);
447441
} catch (e) {
448442
// Error messages can be long (nested) and uninformative. For this reason,

src/ops/registry.js

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import { createInferenceSession, isONNXProxy } from "../backends/onnx.js";
1+
import { createInferenceSession, runInferenceSession, isONNXProxy } from "../backends/onnx.js";
22
import { Tensor } from "../utils/tensor.js";
3-
import { apis } from "../env.js";
43

5-
const IS_WEB_ENV = apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV;
64
/**
75
* Asynchronously creates a wrapper function for running an ONNX inference session.
86
*
@@ -19,16 +17,10 @@ const wrap = async (session_bytes, session_options, names) => {
1917
new Uint8Array(session_bytes), session_options,
2018
);
2119

22-
/** @type {Promise<any>} */
23-
let chain = Promise.resolve();
24-
2520
return /** @type {any} */(async (/** @type {Record<string, Tensor>} */ inputs) => {
2621
const proxied = isONNXProxy();
2722
const ortFeed = Object.fromEntries(Object.entries(inputs).map(([k, v]) => [k, (proxied ? v.clone() : v).ort_tensor]));
28-
29-
// When running in-browser via WASM, we need to chain calls to session.run to avoid "Error: Session already started"
30-
const outputs = await (chain = IS_WEB_ENV ? chain.then(() => session.run(ortFeed)) : session.run(ortFeed));
31-
23+
const outputs = await runInferenceSession(session, ortFeed);
3224
if (Array.isArray(names)) {
3325
return names.map((n) => new Tensor(outputs[n]));
3426
} else {

0 commit comments

Comments
 (0)