Skip to content

Commit

Permalink
webnn: Replace MLActivation with MLRecurrentNetworkActivation
Browse files Browse the repository at this point in the history
Implements the changes from this spec PR:
webmachinelearning/webnn#718

Bug: 351866391
Cq-Include-Trybots: luci.chromium.try:mac14-blink-rel,mac14.arm64-blink-rel,win11-blink-rel
Change-Id: I163c85af63468a2ad4166509196ac3734d379df2
  • Loading branch information
a-sully authored and chromium-wpt-export-bot committed Jul 28, 2024
1 parent b32746b commit 11fd55e
Show file tree
Hide file tree
Showing 18 changed files with 48 additions and 173 deletions.
4 changes: 0 additions & 4 deletions webnn/idlharness.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ idl_test(
ML: ['navigator.ml'],
MLContext: ['context'],
MLOperand: ['input', 'constant', 'output'],
MLActivation: ['relu'],
MLGraphBuilder: ['builder'],
MLGraph: ['graph']
});
Expand All @@ -35,9 +34,6 @@ idl_test(
{dataType: 'float32', dimensions: [2, 3]},
new Float32Array(2 * 3).fill(1));

// Create an activation which won't be used in the graph.
self.relu = builder.relu();

self.output = builder.add(input, constant);

self.graph = await builder.build({output});
Expand Down
24 changes: 2 additions & 22 deletions webnn/resources/utils_validation.js
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,8 @@ function validateOptionsAxes(operationName) {
* @param {String} operationName - An operation name
* @param {Array} supportedDataTypes - Test building with these data types
* succeeds and test building with all other data types fails
* @param {Boolean} alsoBuildActivation - If test building this operation as an
* activation
*/
function validateUnaryOperation(
operationName, supportedDataTypes, alsoBuildActivation = false) {
function validateUnaryOperation(operationName, supportedDataTypes) {
promise_test(async t => {
const builder = new MLGraphBuilder(context);
for (let dataType of supportedDataTypes) {
Expand Down Expand Up @@ -451,23 +448,13 @@ function validateUnaryOperation(
}
}
}, `[${operationName}] Throw if the dataType is not supported for an unary operator.`);

if (alsoBuildActivation) {
promise_test(async t => {
const builder = new MLGraphBuilder(context);
builder[operationName]();
}, `[${operationName}] Test building an activation`);
}
}

/**
* Validate a single input operation
* @param {String} operationName - An operation name
* @param {Boolean} alsoBuildActivation - If test building this operation as an
* activation
*/
function validateSingleInputOperation(
operationName, alsoBuildActivation = false) {
function validateSingleInputOperation(operationName) {
promise_test(async t => {
const builder = new MLGraphBuilder(context);
const supportedDataTypes =
Expand Down Expand Up @@ -503,13 +490,6 @@ function validateSingleInputOperation(
}
}
}, `[${operationName}] Throw if the data type is not supported for the operator.`);

if (alsoBuildActivation) {
promise_test(async t => {
const builder = new MLGraphBuilder(context);
builder[operationName]();
}, `[${operationName}] Test building an activation.`);
}
}

/**
Expand Down
8 changes: 4 additions & 4 deletions webnn/validation_tests/clamp.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ promise_test(async t => {
const output = builder.clamp(input, options);
assert_equals(output.dataType(), 'uint32');
assert_array_equals(output.shape(), [1, 2, 3]);
}, '[clamp] Test building an operator with options');
}, '[clamp] Build with options');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
Expand All @@ -40,7 +40,7 @@ promise_test(async t => {
const output = builder.clamp(input, options);
assert_equals(output.dataType(), 'int32');
assert_array_equals(output.shape(), [1, 2, 3, 4]);
}, '[clamp] Test building an operator with options.minValue == options.maxValue');
}, '[clamp] Build with options.minValue == options.maxValue');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
Expand All @@ -55,7 +55,7 @@ promise_test(async t => {
const input =
builder.input('input', {dataType: 'uint8', dimensions: [1, 2, 3]});
assert_throws_js(TypeError, () => builder.clamp(input, options));
}, '[clamp] Throw if options.minValue > options.maxValue when building an operator');
}, '[clamp] Throw if options.minValue > options.maxValue');

// To be removed once infinite `minValue` is allowed. Tracked in
// https://github.com/webmachinelearning/webnn/pull/647.
Expand All @@ -64,4 +64,4 @@ promise_test(async t => {
const options = {minValue: -Infinity};
const input = builder.input('input', {dataType: 'float16', dimensions: []});
assert_throws_js(TypeError, () => builder.clamp(input, options));
}, '[clamp] Throw if options.minValue is -Infinity when building an operator');
}, '[clamp] Throw if options.minValue is -Infinity');
30 changes: 13 additions & 17 deletions webnn/validation_tests/elu.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

validateInputFromAnotherBuilder('elu');

validateSingleInputOperation('elu', /*alsoBuildActivation=*/ true);
validateSingleInputOperation('elu');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
Expand All @@ -16,37 +16,33 @@ promise_test(async t => {
const output = builder.elu(input, options);
assert_equals(output.dataType(), 'float32');
assert_array_equals(output.shape(), [1, 2, 3]);
}, '[elu] Test building an operator with options');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
const options = {alpha: 1.5};
builder.elu(options);
}, '[elu] Test building an activation with options');
}, '[elu] Build with options');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
const options = {alpha: -1.0};
const input =
builder.input('input', {dataType: 'float32', dimensions: [1, 2, 3]});
assert_throws_js(TypeError, () => builder.elu(input, options));
}, '[elu] Throw if options.alpha <= 0 when building an operator');
}, '[elu] Throw if options.alpha < 0');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
const options = {alpha: NaN};
const input = builder.input('input', {dataType: 'float16', dimensions: []});
const options = {alpha: 0};
const input = builder.input('input', {dataType: 'float32', dimensions: [1]});
assert_throws_js(TypeError, () => builder.elu(input, options));
}, '[elu] Throw if options.alpha is NaN when building an operator');
}, '[elu] Throw if options.alpha == 0');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
const options = {alpha: 0};
assert_throws_js(TypeError, () => builder.elu(options));
}, '[elu] Throw if options.alpha <= 0 when building an activation');
const options = {alpha: NaN};
const input = builder.input('input', {dataType: 'float16', dimensions: []});
assert_throws_js(TypeError, () => builder.elu(input, options));
}, '[elu] Throw if options.alpha is NaN');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
const options = {alpha: Infinity};
assert_throws_js(TypeError, () => builder.elu(options));
}, '[elu] Throw if options.alpha is Infinity when building an activation');
const input = builder.input('input', {dataType: 'float32', dimensions: [1]});
assert_throws_js(TypeError, () => builder.elu(input, options));
}, '[elu] Throw if options.alpha is Infinity');
2 changes: 1 addition & 1 deletion webnn/validation_tests/gelu.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@

validateInputFromAnotherBuilder('gelu');

validateSingleInputOperation('gelu', /*alsoBuildActivation=*/ true);
validateSingleInputOperation('gelu');
19 changes: 1 addition & 18 deletions webnn/validation_tests/gru.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,7 @@ tests.forEach(
options.layout = test.options.layout;
}
if (test.options.activations) {
options.activations = [];
test.options.activations.forEach(
activation => options.activations.push(builder[activation]()));
options.activations = test.options.activations;
}
}

Expand Down Expand Up @@ -429,18 +427,3 @@ multi_builder_test(async (t, builder, otherBuilder) => {
() => builder.gru(
input, weight, recurrentWeight, steps, hiddenSize, options));
}, '[gru] throw if initialHiddenState option is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
const activation = builder.relu();
const activationFromOtherBuilder = otherBuilder.relu();
const options = {activations: [activation, activationFromOtherBuilder]};

const input = builder.input('input', kExampleInputDescriptor);
const weight = builder.input('weight', kExampleWeightDescriptor);
const recurrentWeight =
builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
assert_throws_js(
TypeError,
() => builder.gru(
input, weight, recurrentWeight, steps, hiddenSize, options));
}, '[gru] throw if any activation option is from another builder');
22 changes: 1 addition & 21 deletions webnn/validation_tests/gruCell.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,7 @@ tests.forEach(
options.layout = test.options.layout;
}
if (test.options.activations) {
options.activations = [];
test.options.activations.forEach(
activation =>
options.activations.push(builder[activation]()));
options.activations = test.options.activations;
}
}

Expand Down Expand Up @@ -457,20 +454,3 @@ multi_builder_test(async (t, builder, otherBuilder) => {
() => builder.gruCell(
input, weight, recurrentWeight, hiddenState, hiddenSize, options));
}, '[gruCell] throw if recurrentBias option is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
const activation = builder.relu();
const activationFromOtherBuilder = otherBuilder.relu();
const options = {activations: [activation, activationFromOtherBuilder]};

const input = builder.input('input', kExampleInputDescriptor);
const weight = builder.input('weight', kExampleWeightDescriptor);
const recurrentWeight =
builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
const hiddenState =
builder.input('hiddenState', kExampleHiddenStateDescriptor);
assert_throws_js(
TypeError,
() => builder.gruCell(
input, weight, recurrentWeight, hiddenState, hiddenSize, options));
}, '[gruCell] throw if any activation option is from another builder');
18 changes: 6 additions & 12 deletions webnn/validation_tests/hardSigmoid.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

validateInputFromAnotherBuilder('hardSigmoid');

validateUnaryOperation(
'hardSigmoid', floatingPointTypes, /*alsoBuildActivation=*/ true);
validateUnaryOperation('hardSigmoid', floatingPointTypes);

promise_test(async t => {
const builder = new MLGraphBuilder(context);
Expand All @@ -17,23 +16,18 @@ promise_test(async t => {
const output = builder.hardSigmoid(input, options);
assert_equals(output.dataType(), 'float16');
assert_array_equals(output.shape(), [1, 2, 3]);
}, '[hardSigmoid] Test building an operator with options');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
const options = {alpha: 0.2};
builder.hardSigmoid(options);
}, '[hardSigmoid] Test building an activation with options');
}, '[hardSigmoid] Test building with options');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
const options = {beta: NaN};
const input = builder.input('input', {dataType: 'float32', dimensions: []});
assert_throws_js(TypeError, () => builder.hardSigmoid(input, options));
}, '[hardSigmoid] Throw if options.beta is NaN when building an operator');
}, '[hardSigmoid] Throw if options.beta is NaN');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
const options = {alpha: Infinity};
assert_throws_js(TypeError, () => builder.hardSigmoid(options));
}, '[hardSigmoid] Throw if options.alpha is Infinity when building an activation');
const input = builder.input('input', {dataType: 'float32', dimensions: [1]});
assert_throws_js(TypeError, () => builder.hardSigmoid(input, options));
}, '[hardSigmoid] Throw if options.alpha is Infinity');
3 changes: 1 addition & 2 deletions webnn/validation_tests/hardSwish.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,4 @@

validateInputFromAnotherBuilder('hardSwish');

validateUnaryOperation(
'hardSwish', floatingPointTypes, /*alsoBuildActivation=*/ true);
validateUnaryOperation('hardSwish', floatingPointTypes);
17 changes: 6 additions & 11 deletions webnn/validation_tests/leakyRelu.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

validateInputFromAnotherBuilder('leakyRelu');

validateSingleInputOperation('leakyRelu', /*alsoBuildActivation=*/ true);
validateSingleInputOperation('leakyRelu');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
Expand All @@ -16,23 +16,18 @@ promise_test(async t => {
const output = builder.leakyRelu(input, options);
assert_equals(output.dataType(), 'float32');
assert_array_equals(output.shape(), [1, 2, 3]);
}, '[leakyRelu] Test building an operator with options');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
const options = {alpha: 0.03};
builder.leakyRelu(options);
}, '[leakyRelu] Test building an activation with options');
}, '[leakyRelu] Build with options');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
const options = {alpha: Infinity};
const input = builder.input('input', {dataType: 'float16', dimensions: []});
assert_throws_js(TypeError, () => builder.leakyRelu(input, options));
}, '[leakyRelu] Throw if options.alpha is Infinity when building an operator');
}, '[leakyRelu] Throw if options.alpha is Infinity');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
const options = {alpha: -NaN};
assert_throws_js(TypeError, () => builder.leakyRelu(options));
}, '[leakyRelu] Throw if options.alpha is -NaN when building an activation');
const input = builder.input('input', {dataType: 'float32', dimensions: [1]});
assert_throws_js(TypeError, () => builder.leakyRelu(input, options));
}, '[leakyRelu] Throw if options.alpha is -NaN');
18 changes: 6 additions & 12 deletions webnn/validation_tests/linear.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

validateInputFromAnotherBuilder('linear');

validateUnaryOperation(
'linear', floatingPointTypes, /*alsoBuildActivation=*/ true);
validateUnaryOperation('linear', floatingPointTypes);

promise_test(async t => {
const builder = new MLGraphBuilder(context);
Expand All @@ -17,23 +16,18 @@ promise_test(async t => {
const output = builder.linear(input, options);
assert_equals(output.dataType(), 'float32');
assert_array_equals(output.shape(), [1, 2, 3]);
}, '[linear] Test building an operator with options');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
const options = {beta: 1.5};
builder.linear(options);
}, '[linear] Test building an activation with options');
}, '[linear] Build with options');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
const options = {beta: -Infinity};
const input = builder.input('input', {dataType: 'float16', dimensions: []});
assert_throws_js(TypeError, () => builder.linear(input, options));
}, '[linear] Throw if options.beta is -Infinity when building an operator');
}, '[linear] Throw if options.beta is -Infinity');

promise_test(async t => {
const builder = new MLGraphBuilder(context);
const options = {alpha: NaN};
assert_throws_js(TypeError, () => builder.linear(options));
}, '[linear] Throw if options.alpha is NaN when building an activation');
const input = builder.input('input', {dataType: 'float32', dimensions: [1]});
assert_throws_js(TypeError, () => builder.linear(input, options));
}, '[linear] Throw if options.alpha is NaN');
19 changes: 1 addition & 18 deletions webnn/validation_tests/lstm.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,7 @@ tests.forEach(
options.layout = test.options.layout;
}
if (test.options.activations) {
options.activations = [];
test.options.activations.forEach(
activation => options.activations.push(builder[activation]()));
options.activations = test.options.activations;
}
}

Expand Down Expand Up @@ -455,18 +453,3 @@ multi_builder_test(async (t, builder, otherBuilder) => {
() => builder.lstm(
input, weight, recurrentWeight, steps, hiddenSize, options));
}, '[lstm] throw if initialCellState option is from another builder');

multi_builder_test(async (t, builder, otherBuilder) => {
const activation = builder.relu();
const activationFromOtherBuilder = otherBuilder.relu();
const options = {activations: [activation, activationFromOtherBuilder]};

const input = builder.input('input', kExampleInputDescriptor);
const weight = builder.input('weight', kExampleWeightDescriptor);
const recurrentWeight =
builder.input('recurrentWeight', kExampleRecurrentWeightDescriptor);
assert_throws_js(
TypeError,
() => builder.lstm(
input, weight, recurrentWeight, steps, hiddenSize, options));
}, '[lstm] throw if any activation option is from another builder');
Loading

0 comments on commit 11fd55e

Please sign in to comment.