Skip to content

Commit

Permalink
Support lstm and lstmCell operations
Browse files Browse the repository at this point in the history
  • Loading branch information
BruceDai committed May 8, 2023
1 parent 4650499 commit 425f377
Show file tree
Hide file tree
Showing 5 changed files with 808 additions and 0 deletions.
67 changes: 67 additions & 0 deletions src/nn/graph_builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {HardSigmoid} from './ops/hard_sigmoid';
import {InstanceNormalization} from './ops/instance_norm';
import {LeakyRelu} from './ops/leaky_relu';
import {Linear} from './ops/linear';
import {Lstm, LstmCell} from './ops/lstm';
import {Pad} from './ops/pad';
import {AveragePool2d, L2Pool2d, MaxPool2d} from './ops/pool2d';
import {PRelu} from './ops/prelu';
Expand Down Expand Up @@ -211,9 +212,43 @@ export interface MLLinearOptions {
beta?: number;
}

/**
* [spec](https://webmachinelearning.github.io/webnn/#enumdef-mllstmweightlayout)
*/
export enum MLLstmWeightLayout {
'iofg' = 'iofg', // input-output-forget-cell gate ordering
'ifgo' = 'ifgo' // input-forget-cell-output gate ordering
}

/**
* [spec](https://webmachinelearning.github.io/webnn/#enumdef-mlpaddingmode)
*/
export interface MLLstmOptions {
bias?: MLOperand;
recurrentBias?: MLOperand;
peepholeWeight?: MLOperand;
initialHiddenState?: MLOperand;
initialCellState?: MLOperand;
returnSequence?: boolean;
direction?: MLRecurrentNetworkDirection;
layout?: MLLstmWeightLayout;
activations?: MLActivation[];
}

/**
* [spec](https://webmachinelearning.github.io/webnn/#dictdef-mllstmcelloptions)
*/
export interface MLLstmCellOptions {
bias?: MLOperand;
recurrentBias?: MLOperand;
peepholeWeight?: MLOperand;
layout?: MLLstmWeightLayout;
activations?: MLActivation[];
}

/**
* [spec](https://webmachinelearning.github.io/webnn/#dictdef-mllstmoptions)
*/
export enum MLPaddingMode {
'constant' = 'constant',
'edge' = 'edge',
Expand Down Expand Up @@ -769,6 +804,38 @@ export class MLGraphBuilder {
}
}

/**
* [spec](https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-lstm)
*/
lstm(input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand,
steps: number, hiddenSize: number, options: MLLstmOptions = {}):
MLOperand[] {
this.validateOperandBuilder([
input, weight, recurrentWeight, options.bias, options.recurrentBias,
options.peepholeWeight, options.initialHiddenState,
options.initialCellState
]);
return (new Lstm(
input, weight, recurrentWeight, steps, hiddenSize, options))
.outputs;
}

/**
* [spec](https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-lstmcell)
*/
lstmCell(input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand,
hiddenState: MLOperand, cellState: MLOperand, hiddenSize: number,
options: MLLstmCellOptions = {}): MLOperand[] {
this.validateOperandBuilder([
input, weight, recurrentWeight, hiddenState, cellState, options.bias,
options.recurrentBias, options.peepholeWeight
]);
return (new LstmCell(
input, weight, recurrentWeight, hiddenState, cellState,
hiddenSize, options))
.outputs;
}

/**
* [spec](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-matmul)
*/
Expand Down
Loading

0 comments on commit 425f377

Please sign in to comment.