This repository was archived by the owner on Oct 22, 2024. It is now read-only.

Update interface MLOperator to MLActivation #203

Merged 1 commit on Feb 8, 2023
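This PR renames the `MLOperator` interface to `MLActivation`, following the WebNN spec's rename of the type returned by the nullary builder overloads. A minimal sketch of the two calling conventions the diff touches; the import paths, the `conv2d` signature, and the operand setup are illustrative assumptions, not part of this PR:

```ts
import {MLGraphBuilder} from './src/nn/graph_builder';
import {MLOperand} from './src/nn/operand';
import {MLActivation} from './src/nn/operation';

declare const builder: MLGraphBuilder;  // constructed from an MLContext elsewhere
declare const input: MLOperand;         // e.g. via builder.input(...)
declare const filter: MLOperand;        // e.g. via builder.constant(...)

// Called with an operand, relu() applies immediately and returns an MLOperand.
const y: MLOperand = builder.relu(input);

// Called with no argument, relu() returns an MLActivation (the renamed type),
// which the options bags below now accept, e.g. MLConv2dOptions.activation.
const relu: MLActivation = builder.relu();
const conv: MLOperand = builder.conv2d(input, filter, {activation: relu});
```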
36 changes: 18 additions & 18 deletions src/nn/graph_builder.ts
@@ -1,7 +1,7 @@
import {MLContext} from './context';
import {MLGraph} from './graph';
import {ConstantOperand, InputOperand, MLOperand, MLOperandDescriptor, MLOperandType} from './operand';
-import {MLOperator} from './operation';
+import {MLActivation} from './operation';
import {BatchNormalization} from './ops/batch_norm';
import {Add, Div, MatMul, Max, Min, Mul, Pow, Sub} from './ops/binary';
import {Clamp} from './ops/clamp';
@@ -42,7 +42,7 @@ export interface MLBatchNormalizationOptions {
bias?: MLOperand;
axis?: number;
epsilon?: number;
-activation?: MLOperator;
+activation?: MLActivation;
}

/**
@@ -84,7 +84,7 @@ export interface MLConv2dOptions {
inputLayout?: MLInputOperandLayout;
filterLayout?: MLConv2dFilterOperandLayout;
bias?: MLOperand;
-activation?: MLOperator;
+activation?: MLActivation;
}

/**
@@ -110,7 +110,7 @@ export interface MLConvTranspose2dOptions {
inputLayout?: MLInputOperandLayout;
filterLayout?: MLConvTranspose2dFilterOperandLayout;
bias?: MLOperand;
-activation?: MLOperator;
+activation?: MLActivation;
}

/**
@@ -152,7 +152,7 @@ export interface MLGruOptions {
returnSequence?: boolean;
direction?: MLRecurrentNetworkDirection;
layout?: MLRecurrentNetworkWeightLayout;
-activations?: MLOperator[];
+activations?: MLActivation[];
}

/**
@@ -163,7 +163,7 @@ export interface MLGruCellOptions {
recurrentBias?: MLOperand;
resetAfter?: boolean;
layout?: MLRecurrentNetworkWeightLayout;
-activations?: MLOperator[];
+activations?: MLActivation[];
}

/**
@@ -378,10 +378,10 @@ export class MLGraphBuilder {
* [spec](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-clamp)
*/
clamp(x: MLOperand, options: MLClampOptions): MLOperand;
-clamp(options: MLClampOptions): MLOperator;
+clamp(options: MLClampOptions): MLActivation;
clamp(
operandOrOptions: MLOperand|MLClampOptions = {},
-options: MLClampOptions = {}): MLOperand|MLOperator {
+options: MLClampOptions = {}): MLOperand|MLActivation {
if (operandOrOptions instanceof MLOperand) {
const x = operandOrOptions;
this.validateOperandBuilder([x]);
@@ -565,8 +565,8 @@ export class MLGraphBuilder {
* [spec](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-hard-swish)
*/
hardSwish(input: MLOperand): MLOperand;
-hardSwish(): MLOperator;
-hardSwish(input: MLOperand = undefined): MLOperand|MLOperator {
+hardSwish(): MLActivation;
+hardSwish(input: MLOperand = undefined): MLOperand|MLActivation {
if (input === undefined) {
return new HardSwish(undefined);
} else {
@@ -579,8 +579,8 @@ export class MLGraphBuilder {
* [spec](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-relu)
*/
relu(input: MLOperand): MLOperand;
-relu(): MLOperator;
-relu(input: MLOperand = undefined): MLOperand|MLOperator {
+relu(): MLActivation;
+relu(input: MLOperand = undefined): MLOperand|MLActivation {
if (input === undefined) {
return new Relu(undefined);
} else {
@@ -593,8 +593,8 @@ export class MLGraphBuilder {
* [spec](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-sigmoid)
*/
sigmoid(input: MLOperand): MLOperand;
-sigmoid(): MLOperator;
-sigmoid(input: MLOperand = undefined): MLOperand|MLOperator {
+sigmoid(): MLActivation;
+sigmoid(input: MLOperand = undefined): MLOperand|MLActivation {
if (input === undefined) {
return new Sigmoid(undefined);
} else {
@@ -607,8 +607,8 @@ export class MLGraphBuilder {
* [spec](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-tanh)
*/
tanh(input: MLOperand): MLOperand;
-tanh(): MLOperator;
-tanh(input: MLOperand = undefined): MLOperand|MLOperator {
+tanh(): MLActivation;
+tanh(input: MLOperand = undefined): MLOperand|MLActivation {
if (input === undefined) {
return new Tanh(undefined);
} else {
@@ -671,10 +671,10 @@ export class MLGraphBuilder {
* [spec](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-leakyrelu)
*/
leakyRelu(x: MLOperand, options: MLLeakyReluOptions): MLOperand;
-leakyRelu(options: MLLeakyReluOptions): MLOperator;
+leakyRelu(options: MLLeakyReluOptions): MLActivation;
leakyRelu(
operandOrOptions: MLOperand|MLLeakyReluOptions = {},
-options: MLLeakyReluOptions = {}): MLOperand|MLOperator {
+options: MLLeakyReluOptions = {}): MLOperand|MLActivation {
if (operandOrOptions instanceof MLOperand) {
const x = operandOrOptions;
this.validateOperandBuilder([x]);
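The options-taking overloads (`clamp`, `leakyRelu`) follow the same pattern as the nullary ones above. Continuing the declarations from the earlier sketch, and assuming the spec's `minValue`/`maxValue` fields on `MLClampOptions`:

```ts
// With an operand: clamps immediately and returns an MLOperand.
const clamped: MLOperand = builder.clamp(input, {minValue: 0, maxValue: 6});

// Without an operand: returns an MLActivation that captures the options
// for later fusion or application.
const relu6: MLActivation = builder.clamp({minValue: 0, maxValue: 6});
```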
4 changes: 2 additions & 2 deletions src/nn/operation.ts
@@ -5,9 +5,9 @@ import {MLGraphBuilder} from './graph_builder';
import {MLOperand, OutputOperand} from './operand';

/**
- * [spec](https://webmachinelearning.github.io/webnn/#api-mloperator)
+ * [spec](https://webmachinelearning.github.io/webnn/#api-mlactivation)
*/
-export interface MLOperator {
+export interface MLActivation {
/** @internal */
apply(input: MLOperand): OutputOperand;

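Only the name changes here; the interface shape is untouched. Its internal `apply` hook is what lets a backend run an activation as a standalone graph node when it cannot fuse it. A hypothetical helper illustrating the contract (not code from this PR):

```ts
import {MLOperand} from './src/nn/operand';
import {MLActivation} from './src/nn/operation';

// If an activation was requested but cannot be fused into the preceding op,
// insert it as its own node; otherwise pass the output through unchanged.
function applyOrPassThrough(out: MLOperand, activation?: MLActivation): MLOperand {
  return activation ? activation.apply(out) : out;
}
```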
4 changes: 2 additions & 2 deletions src/nn/ops/batch_norm.ts
@@ -2,7 +2,7 @@ import * as tf from '@tensorflow/tfjs-core';

import {MLBatchNormalizationOptions} from '../graph_builder';
import {MLOperand, OutputOperand} from '../operand';
-import {FusedOperation, MLOperator, SingleOutputOperation} from '../operation';
+import {FusedOperation, MLActivation, SingleOutputOperation} from '../operation';
import * as utils from '../utils';

export class BatchNormalization extends SingleOutputOperation implements
@@ -14,7 +14,7 @@ export class BatchNormalization extends SingleOutputOperation implements
private bias_?: MLOperand;
private axis_?: number;
private epsilon_?: number;
-private activation_?: MLOperator;
+private activation_?: MLActivation;
private needCheckOutputShape_ = true;

constructor(
4 changes: 2 additions & 2 deletions src/nn/ops/clamp.ts
@@ -2,10 +2,10 @@ import * as tf from '@tensorflow/tfjs-core';

import {MLClampOptions} from '../graph_builder';
import {MLOperand} from '../operand';
-import {UnaryMLOperator} from './unary';
+import {UnaryMLActivation} from './unary';
import * as utils from '../utils';

-export class Clamp extends UnaryMLOperator {
+export class Clamp extends UnaryMLActivation {
private minValue_?: number;
private maxValue_?: number;

8 changes: 4 additions & 4 deletions src/nn/ops/conv2d.ts
@@ -3,7 +3,7 @@ import {ExplicitPadding} from '@tensorflow/tfjs-core/dist/ops/conv_util';

import {MLAutoPad, MLConv2dOptions, MLConv2dFilterOperandLayout, MLInputOperandLayout} from '../graph_builder';
import {ConstantOperand, MLOperand, OutputOperand} from '../operand';
-import {FusedOperation, MLOperator, SingleOutputOperation} from '../operation';
+import {FusedOperation, MLActivation, SingleOutputOperation} from '../operation';
import * as utils from '../utils';

import {Clamp} from './clamp';
@@ -21,7 +21,7 @@ export class Conv2d extends SingleOutputOperation implements FusedOperation {
private inputLayout_?: MLInputOperandLayout;
private filterLayout_?: MLConv2dFilterOperandLayout;
private autoPad_?: MLAutoPad;
-private activation_?: MLOperator;
+private activation_?: MLActivation;
private fusedActivation_?: tf.fused.Activation;
private leakyreluAlpha_?: number;
private filterTensor_?: tf.Tensor4D;
@@ -48,7 +48,7 @@ export class Conv2d extends SingleOutputOperation implements FusedOperation {
filterLayout:
MLConv2dFilterOperandLayout = MLConv2dFilterOperandLayout.oihw,
autoPad: MLAutoPad = MLAutoPad.explicit, bias: MLOperand = undefined,
-activation: MLOperator = undefined) {
+activation: MLActivation = undefined) {
utils.assert(
utils.isIntegerArray(padding) && padding.length === 4,
'The padding parameter is invalid.');
@@ -104,7 +104,7 @@ export class Conv2d extends SingleOutputOperation implements FusedOperation {
}
}

-isRelu6(activation: MLOperator): boolean {
+isRelu6(activation: MLActivation): boolean {
if (activation instanceof Clamp) {
const clamp = activation;
if (Math.abs(clamp.minValue - 0.0) < 1e-5 &&
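`isRelu6` is why `Clamp` matters to this file: a `clamp` activation whose bounds are 0 and 6 (within 1e-5) maps onto TensorFlow.js's fused `'relu6'` kernel instead of running as a separate node. Continuing the earlier sketch:

```ts
// Recognized by isRelu6: fuses into the conv2d kernel as 'relu6'.
const fusedConv = builder.conv2d(input, filter, {
  activation: builder.clamp({minValue: 0, maxValue: 6}),
});

// Bounds differ, so isRelu6 rejects it and the clamp runs unfused.
const unfusedConv = builder.conv2d(input, filter, {
  activation: builder.clamp({minValue: -1, maxValue: 1}),
});
```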
8 changes: 4 additions & 4 deletions src/nn/ops/conv_transpose2d.ts
@@ -3,7 +3,7 @@ import {ExplicitPadding} from '@tensorflow/tfjs-core/dist/ops/conv_util';

import {MLAutoPad, MLConvTranspose2dOptions, MLConvTranspose2dFilterOperandLayout, MLInputOperandLayout} from '../graph_builder';
import {ConstantOperand, MLOperand, OutputOperand} from '../operand';
-import {FusedOperation, MLOperator, SingleOutputOperation} from '../operation';
+import {FusedOperation, MLActivation, SingleOutputOperation} from '../operation';
import * as utils from '../utils';

import {Clamp} from './clamp';
@@ -24,7 +24,7 @@ export class ConvTranspose2d extends SingleOutputOperation
private autoPad_?: MLAutoPad;
private outputPadding_?: [number, number];
private outputSizes_?: [number, number];
-private activation_?: MLOperator;
+private activation_?: MLActivation;
private fusedActivation_?: tf.fused.Activation;
private leakyreluAlpha_?: number;
private filterTensor_?: tf.Tensor4D;
@@ -55,7 +55,7 @@
autoPad: MLAutoPad = MLAutoPad.explicit,
outputPadding: [number, number] = [0, 0],
outputSizes: [number, number] = undefined, bias: MLOperand = undefined,
-activation: MLOperator = undefined) {
+activation: MLActivation = undefined) {

utils.assert(
utils.isIntegerArray(padding) && padding.length === 4,
@@ -133,7 +133,7 @@ }
}
}

-isRelu6(activation: MLOperator): boolean {
+isRelu6(activation: MLActivation): boolean {
if (activation instanceof Clamp) {
const clamp = activation;
if (Math.abs(clamp.minValue - 0.0) < 1e-5 &&
23 changes: 12 additions & 11 deletions src/nn/ops/gru.ts
@@ -2,9 +2,9 @@ import * as tf from '@tensorflow/tfjs-core';

import {MLGruCellOptions, MLGruOptions, MLRecurrentNetworkDirection, MLRecurrentNetworkWeightLayout} from '../graph_builder';
import {MLOperand, OutputOperand} from '../operand';
-import {MLOperator, Operation, SingleOutputOperation} from '../operation';
+import {MLActivation, Operation, SingleOutputOperation} from '../operation';
import * as utils from '../utils';
-import {UnaryMLOperator} from './unary';
+import {UnaryMLActivation} from './unary';

export class Gru extends Operation {
private input_: MLOperand;
@@ -19,7 +19,7 @@ export class Gru extends Operation {
private returnSequence_: boolean;
private direction_: MLRecurrentNetworkDirection;
private layout_: MLRecurrentNetworkWeightLayout;
-private activations_: MLOperator[];
+private activations_: MLActivation[];
private needCheckOutputShape_ = true;

constructor(
@@ -58,7 +58,8 @@
layout:
MLRecurrentNetworkWeightLayout = MLRecurrentNetworkWeightLayout.zrn,
activations:
-MLOperator[] = [this.builder.sigmoid(), this.builder.tanh()]): void {
+MLActivation[] = [this.builder.sigmoid(), this.builder.tanh()]):
+void {
utils.validateOptionalOperand(bias);
this.bias_ = bias;
utils.validateOptionalOperand(recurrentBias);
@@ -83,7 +84,7 @@
this.layout_ = layout;
utils.assert(
activations instanceof Array && activations.length === 2 &&
-activations.every(a => a instanceof UnaryMLOperator),
+activations.every(a => a instanceof UnaryMLActivation),
'The activations parameter is invalid.');
this.activations_ = activations;
}
@@ -217,7 +218,7 @@ export class GruCell extends SingleOutputOperation {
private recurrentBias_?: MLOperand;
private resetAfter_: boolean;
private layout_: MLRecurrentNetworkWeightLayout;
-private activations_: MLOperator[];
+private activations_: MLActivation[];
private needCheckOutputShape_ = true;

constructor(
@@ -247,7 +248,7 @@
layout:
MLRecurrentNetworkWeightLayout = MLRecurrentNetworkWeightLayout.zrn,
activations:
-MLOperator[] = [this.builder.sigmoid(), this.builder.tanh()]) {
+MLActivation[] = [this.builder.sigmoid(), this.builder.tanh()]) {
utils.validateOptionalOperand(bias);
this.bias_ = bias;
utils.validateOptionalOperand(recurrentBias);
@@ -262,7 +263,7 @@
this.layout_ = layout;
utils.assert(
activations instanceof Array && activations.length === 2 &&
-activations.every(a => a instanceof UnaryMLOperator),
+activations.every(a => a instanceof UnaryMLActivation),
'The activations parameter is invalid.');
this.activations_ = activations;
}
@@ -281,7 +282,7 @@

static compute(
input: tf.Tensor, weight: tf.Tensor, recurrentWeight: tf.Tensor,
-hiddenState: tf.Tensor, hiddenSize: number, activations: MLOperator[],
+hiddenState: tf.Tensor, hiddenSize: number, activations: MLActivation[],
bias?: tf.Tensor, recurrentBias?: tf.Tensor, resetAfter = true,
layout:
MLRecurrentNetworkWeightLayout = MLRecurrentNetworkWeightLayout.zrn):
@@ -291,8 +292,8 @@
const starts = layout === MLRecurrentNetworkWeightLayout.zrn ?
{z: 0, r: hiddenSize, n: 2 * hiddenSize} :
/*rzn*/ {r: 0, z: hiddenSize, n: 2 * hiddenSize};
-const activation0: UnaryMLOperator = activations[0] as UnaryMLOperator;
-const activation1: UnaryMLOperator = activations[1] as UnaryMLOperator;
+const activation0: UnaryMLActivation = activations[0] as UnaryMLActivation;
+const activation1: UnaryMLActivation = activations[1] as UnaryMLActivation;
// update gate
const z = activation0.runOp(tf.add(
tf.add(
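The GRU changes are type-only, but note the validation: exactly two activations, each a `UnaryMLActivation`, where `activations[0]` drives the update/reset gates and `activations[1]` the candidate state, defaulting to sigmoid and tanh. A sketch passing the defaults explicitly, assuming the spec's `gru(input, weight, recurrentWeight, steps, hiddenSize, options)` signature:

```ts
declare const weight: MLOperand;           // gate weights
declare const recurrentWeight: MLOperand;  // recurrent gate weights
const steps = 4, hiddenSize = 64;          // illustrative values

const outputs: MLOperand[] = builder.gru(
    input, weight, recurrentWeight, steps, hiddenSize,
    {activations: [builder.sigmoid(), builder.tanh()]});  // the defaults
```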
4 changes: 2 additions & 2 deletions src/nn/ops/leaky_relu.ts
@@ -3,9 +3,9 @@ import * as tf from '@tensorflow/tfjs-core';
import {MLOperand} from '../operand';
import * as utils from '../utils';

-import {UnaryMLOperator} from './unary';
+import {UnaryMLActivation} from './unary';

-export class LeakyRelu extends UnaryMLOperator {
+export class LeakyRelu extends UnaryMLActivation {
private alpha_?: number;

get alpha(): number {
12 changes: 6 additions & 6 deletions src/nn/ops/unary.ts
@@ -1,7 +1,7 @@
import * as tf from '@tensorflow/tfjs-core';

import {MLOperand, OutputOperand} from '../operand';
-import {MLOperator, SingleOutputOperation} from '../operation';
+import {MLActivation, SingleOutputOperation} from '../operation';
import * as utils from '../utils';

export abstract class Unary extends SingleOutputOperation {
@@ -92,7 +92,7 @@ export class Tan extends Unary {
}
}

-export abstract class UnaryMLOperator extends Unary implements MLOperator {
+export abstract class UnaryMLActivation extends Unary implements MLActivation {
apply(x: MLOperand): OutputOperand {
this.builder_ = x.builder;
utils.validateOperand(x);
@@ -102,25 +102,25 @@ export abstract class UnaryMLOperator extends Unary implements MLOperator {
}
}

-export class Sigmoid extends UnaryMLOperator {
+export class Sigmoid extends UnaryMLActivation {
runOp(x: tf.Tensor): tf.Tensor {
return tf.sigmoid(x);
}
}

-export class Tanh extends UnaryMLOperator {
+export class Tanh extends UnaryMLActivation {
runOp(x: tf.Tensor): tf.Tensor {
return tf.tanh(x);
}
}

-export class Relu extends UnaryMLOperator {
+export class Relu extends UnaryMLActivation {
runOp(x: tf.Tensor): tf.Tensor {
return tf.relu(x);
}
}

-export class HardSwish extends UnaryMLOperator {
+export class HardSwish extends UnaryMLActivation {
runOp(x: tf.Tensor): tf.Tensor {
return tf.div(
tf.mul(
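The `HardSwish` body is cut off above; the standard formulation it implements is hardSwish(x) = x * relu6(x + 3) / 6. A standalone sketch of that math in tfjs-core terms (the exact truncated body may group the ops differently):

```ts
import * as tf from '@tensorflow/tfjs-core';

// hardSwish(x) = x * relu6(x + 3) / 6
function hardSwish(x: tf.Tensor): tf.Tensor {
  return tf.div(tf.mul(x, tf.relu6(tf.add(x, 3))), 6);
}
```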