Skip to content

Commit

Permalink
WebNN: Support AllowSharedBufferSource for constant
Browse files Browse the repository at this point in the history
This CL implements the WebNN spec change proposal [1] that uses
`AllowSharedBufferSource` for `MLGraphBuilder.constant()`.

[1]: webmachinelearning/webnn#790

Bug: 380896836
Change-Id: Ib8fc58daaabf7493b08f9634daa1eeb08a50ad35
Cq-Include-Trybots: luci.chromium.try:win11-blink-rel, mac14.arm64-blink-rel, mac15.arm64-blink-rel, linux-blink-rel
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/6050722
Reviewed-by: Austin Sullivan <asully@chromium.org>
Reviewed-by: Weizhong Xia <weizhong@google.com>
Commit-Queue: ningxin hu <ningxin.hu@intel.com>
Cr-Commit-Position: refs/heads/main@{#1390750}
  • Loading branch information
huningxin authored and chromium-wpt-export-bot committed Dec 3, 2024
1 parent 6a83407 commit 2643d92
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 30 deletions.
71 changes: 71 additions & 0 deletions webnn/conformance_tests/shared_arraybuffer_constant.https.any.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// META: title=test WebNN API constant with shared array buffer
// META: global=window,dedicatedworker
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js
// META: script=../resources/utils.js
// META: timeout=long

'use strict';

// Skip tests if WebNN is unimplemented.
promise_setup(async () => {
assert_implements(navigator.ml, 'missing navigator.ml');
});

// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-constant-buffer

const testContents = Int32Array.from([0, 1, 2, 3, 4, 5, 6, 7]);
const sharedArrayBuffer = new SharedArrayBuffer(testContents.byteLength);
const typedArray = new Int32Array(sharedArrayBuffer);
typedArray.set(testContents);

let mlContext;
let mlGraph;
let outputTensor;
promise_setup(async () => {
try {
mlContext = await navigator.ml.createContext(contextOptions);
} catch (e) {
throw new AssertionError(
`Unable to create mlContext for ${variant} variant. ${e}`);
}

try {
outputTensor = await mlContext.createTensor({
dataType: 'int32',
shape: [8],
readable: true,
});
} catch (e) {
throw new AssertionError(
`Unable to create tensor for ${variant} variant. ${e}`);
}
});

promise_test(async () => {
const builder = new MLGraphBuilder(mlContext);
const constant =
builder.constant({dataType: 'int32', shape: [8]}, sharedArrayBuffer);
const output = builder.identity(constant);
const mlGraph = await builder.build({output});

mlContext.dispatch(mlGraph, {}, {output: outputTensor});
const results = new Int32Array(await mlContext.readTensor(outputTensor));

assert_array_equals(results, testContents);
}, `constant() with a SharedArrayBuffer`);

promise_test(async () => {
const builder = new MLGraphBuilder(mlContext);
const constant =
builder.constant({dataType: 'int32', shape: [8]}, typedArray);
const output = builder.identity(constant);
const mlGraph = await builder.build({output});

mlContext.dispatch(mlGraph, {}, {output: outputTensor});
const results = new Int32Array(await mlContext.readTensor(outputTensor));

assert_array_equals(results, testContents);
}, `constant() with a typeArray from a SharedArrayBuffer`);
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Cross-Origin-Embedder-Policy: require-corp
Cross-Origin-Opener-Policy: same-origin
98 changes: 68 additions & 30 deletions webnn/validation_tests/constant.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,125 +8,152 @@
'use strict';

const tests = [
// Tests for constant(descriptor, bufferView)
// Tests for constant(descriptor, buffer)
{
name:
'[constant] Test building a 0-D scalar constant with empty dimensions',
descriptor: {dataType: 'float32', shape: []},
bufferView: {type: Float32Array, byteLength: 1 * 4},
buffer: {type: Float32Array, byteLength: 1 * 4},
output: {dataType: 'float32', shape: []}
},
{
name: '[constant] Test building a constant with float32 data type',
descriptor: {dataType: 'float32', shape: [2, 3]},
bufferView: {type: Float32Array, byteLength: 6 * 4},
buffer: {type: Float32Array, byteLength: 6 * 4},
output: {dataType: 'float32', shape: [2, 3]}
},
{
name:
'[constant] Throw if byte length of bufferView for float32 doesn\'t match the given dimensions',
'[constant] Throw if byte length of float32 buffer doesn\'t match the given dimensions',
descriptor: {dataType: 'float32', shape: [2, 3]},
bufferView: {
buffer: {
type: Float32Array,
byteLength: 6 * 4 - 4 // The bufferView's byte length is less than the
byteLength: 6 * 4 - 4 // The buffer's byte length is less than the
// one by given dimensions
}
},
// TODO (crbug.com/329702838): Test building a constant with float16 data type
{
name: '[constant] Test building a constant with int32 data type',
descriptor: {dataType: 'int32', shape: [2, 3]},
bufferView: {type: Int32Array, byteLength: 6 * 4},
buffer: {type: Int32Array, byteLength: 6 * 4},
output: {dataType: 'int32', shape: [2, 3]}
},
{
name:
'[constant] Throw if byte length of bufferView for int32 doesn\'t match the given dimensions',
'[constant] Throw if byte length of int32 buffer doesn\'t match the given dimensions',
descriptor: {dataType: 'int32', shape: [2, 3]},
bufferView: {
buffer: {
type: Int32Array,
byteLength: 6 * 4 + 4 // The bufferView's byte length is greater than the
byteLength: 6 * 4 + 4 // The buffer's byte length is greater than the
// one by given dimensions
}
},
{
name: '[constant] Test building a constant with uint32 data type',
descriptor: {dataType: 'uint32', shape: [2, 3]},
bufferView: {type: Uint32Array, byteLength: 6 * 4},
buffer: {type: Uint32Array, byteLength: 6 * 4},
output: {dataType: 'uint32', shape: [2, 3]}
},
{
name:
'[constant] Throw if byte length of bufferView for uint32 doesn\'t match the given dimensions',
'[constant] Throw if byte length of uint32 buffer doesn\'t match the given dimensions',
descriptor: {dataType: 'uint32', shape: [2, 3]},
bufferView: {type: Uint32Array, byteLength: 6 * 4 + 4}
buffer: {type: Uint32Array, byteLength: 6 * 4 + 4}
},
{
name: '[constant] Test building a constant with int64 data type',
descriptor: {dataType: 'int64', shape: [2, 3]},
bufferView: {type: BigInt64Array, byteLength: 6 * 8},
buffer: {type: BigInt64Array, byteLength: 6 * 8},
output: {dataType: 'int64', shape: [2, 3]}
},
{
name:
'[constant] Throw if byte length of bufferView for int64 doesn\'t match the given dimensions',
'[constant] Throw if byte length of int64 buffer doesn\'t match the given dimensions',
descriptor: {dataType: 'int64', shape: [2, 3]},
bufferView: {type: BigInt64Array, byteLength: 6 * 8 + 8}
buffer: {type: BigInt64Array, byteLength: 6 * 8 + 8}
},
{
name: '[constant] Test building a constant with uint64 data type',
descriptor: {dataType: 'uint64', shape: [2, 3]},
bufferView: {type: BigUint64Array, byteLength: 6 * 8},
buffer: {type: BigUint64Array, byteLength: 6 * 8},
output: {dataType: 'uint64', shape: [2, 3]}
},
{
name:
'[constant] Throw if byte length of bufferView for uint64 doesn\'t match the given dimensions',
'[constant] Throw if byte length of uint64 buffer doesn\'t match the given dimensions',
descriptor: {dataType: 'uint64', shape: [2, 3]},
bufferView: {type: BigUint64Array, byteLength: 6 * 8 + 8}
buffer: {type: BigUint64Array, byteLength: 6 * 8 + 8}
},
{
name: '[constant] Test building a constant with int8 data type',
descriptor: {dataType: 'int8', shape: [2, 3]},
bufferView: {type: Int8Array, byteLength: 6 * 1},
buffer: {type: Int8Array, byteLength: 6 * 1},
output: {dataType: 'int8', shape: [2, 3]}
},
{
name:
'[constant] Throw if byte length of bufferView for int8 doesn\'t match the given dimensions',
'[constant] Throw if byte length of int8 buffer doesn\'t match the given dimensions',
descriptor: {dataType: 'int8', shape: [2, 3]},
bufferView: {type: Int8Array, byteLength: 6 * 4 - 4}
buffer: {type: Int8Array, byteLength: 6 * 4 - 4}
},
{
name: '[constant] Test building a constant with uint8 data type',
descriptor: {dataType: 'uint8', shape: [2, 3]},
bufferView: {type: Uint8Array, byteLength: 6 * 1},
buffer: {type: Uint8Array, byteLength: 6 * 1},
output: {dataType: 'uint8', shape: [2, 3]}
},
{
name:
'[constant] Throw if byte length of bufferView for uint8 doesn\'t match the given dimensions',
'[constant] Throw if byte length of uint8 buffer doesn\'t match the given dimensions',
descriptor: {dataType: 'uint8', shape: [2, 3]},
bufferView: {type: Uint8Array, byteLength: 6 * 4 - 4}
buffer: {type: Uint8Array, byteLength: 6 * 4 - 4}
},
{
name: '[constant] Throw if a dimension is 0',
descriptor: {dataType: 'float32', shape: [2, 0]},
bufferView: {type: Float32Array, byteLength: 2 * 4}
buffer: {type: Float32Array, byteLength: 2 * 4}
},
{
name:
'[constant] Throw if bufferView type doesn\'t match the operand data type',
'[constant] Throw if buffer view\'s type doesn\'t match the operand data type',
descriptor: {dataType: 'float32', shape: [2, 3]},
bufferView: {type: Int32Array, byteLength: 6 * 4}
buffer: {type: Int32Array, byteLength: 6 * 4},
viewTestOnly: true
}
];

tests.forEach(
test => promise_test(async t => {
const builder = new MLGraphBuilder(context);
const buffer = new ArrayBuffer(test.bufferView.byteLength);
const bufferView = new test.bufferView.type(buffer);
const buffer = new ArrayBuffer(test.buffer.byteLength);
const bufferView = new test.buffer.type(buffer);
const sharedBuffer = new SharedArrayBuffer(test.buffer.byteLength);
const sharedBufferView = new test.buffer.type(sharedBuffer);

if (test.viewTestOnly === undefined || test.viewTestOnly === false) {
// Test building constant from ArrayBuffer.
if (test.output) {
const constantOperand = builder.constant(test.descriptor, buffer);
assert_equals(constantOperand.dataType, test.output.dataType);
assert_array_equals(constantOperand.shape, test.output.shape);
} else {
assert_throws_js(
TypeError, () => builder.constant(test.descriptor, buffer));
}
// Test building constant from SharedArrayBuffer.
if (test.output) {
const constantOperand =
builder.constant(test.descriptor, sharedBuffer);
assert_equals(constantOperand.dataType, test.output.dataType);
assert_array_equals(constantOperand.shape, test.output.shape);
} else {
assert_throws_js(
TypeError, () => builder.constant(test.descriptor, sharedBuffer));
}
}

// Test building constant from ArrayBufferView.
if (test.output) {
const constantOperand = builder.constant(test.descriptor, bufferView);
assert_equals(constantOperand.dataType, test.output.dataType);
Expand All @@ -135,4 +162,15 @@ tests.forEach(
assert_throws_js(
TypeError, () => builder.constant(test.descriptor, bufferView));
}
// Test building constant from shared ArrayBufferView.
if (test.output) {
const constantOperand =
builder.constant(test.descriptor, sharedBufferView);
assert_equals(constantOperand.dataType, test.output.dataType);
assert_array_equals(constantOperand.shape, test.output.shape);
} else {
assert_throws_js(
TypeError,
() => builder.constant(test.descriptor, sharedBufferView));
}
}, test.name));
2 changes: 2 additions & 0 deletions webnn/validation_tests/constant.https.any.js.headers
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Cross-Origin-Embedder-Policy: require-corp
Cross-Origin-Opener-Policy: same-origin

0 comments on commit 2643d92

Please sign in to comment.