Skip to content

Commit

Permalink
[webnn] Add float32 tests for WebNN API where op (#43558)
Browse files Browse the repository at this point in the history
  • Loading branch information
BruceDai authored Dec 13, 2023
1 parent b597a15 commit ba654bf
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 0 deletions.
10 changes: 10 additions & 0 deletions webnn/gpu/where.https.any.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// META: title=test WebNN API where operation
// META: global=window,dedicatedworker
// META: script=../resources/utils.js
// META: timeout=long

'use strict';

// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-where

testWebNNOperation('where', buildWhere, 'gpu');
10 changes: 10 additions & 0 deletions webnn/resources/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ const PrecisionMetrics = {
squeeze: {ULP: {float32: 0, float16: 0}},
tanh: {ATOL: {float32: 1/1024, float16: 1/512}},
transpose: {ULP: {float32: 0, float16: 0}},
where: {ULP: {float32: 0, float16: 0}},
};

/**
Expand Down Expand Up @@ -694,6 +695,15 @@ const buildSplit = (operationName, builder, resources) => {
return namedOutputOperand;
};

const buildWhere = (operationName, builder, resources) => {
// MLOperand where(MLOperand condition, MLOperand trueValues, MLOperand falseValues);
const namedOutputOperand = {};
const [conditionOperand, trueValuesOperand, falseValuesOperand] = createMultiInputOperands(builder, resources);
// invoke builder.where()
namedOutputOperand[resources.expected.name] = builder[operationName](conditionOperand, trueValuesOperand, falseValuesOperand);
return namedOutputOperand;
};

/**
* Build a graph.
* @param {String} operationName - An operation name
Expand Down
10 changes: 10 additions & 0 deletions webnn/where.https.any.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// META: title=test WebNN API where operation
// META: global=window,dedicatedworker
// META: script=./resources/utils.js
// META: timeout=long

'use strict';

// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-where

testWebNNOperation('where', buildWhere);

0 comments on commit ba654bf

Please sign in to comment.