Skip to content
This repository has been archived by the owner on Oct 22, 2024. It is now read-only.

Add implementations and tests for reduceLogSum and reduceSumSquare ops #223

Merged
merged 2 commits into from
Apr 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/nn/graph_builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import {LeakyRelu} from './ops/leaky_relu';
import {Pad} from './ops/pad';
import {AveragePool2d, L2Pool2d, MaxPool2d} from './ops/pool2d';
import {PRelu} from './ops/prelu';
import {ReduceL1, ReduceL2, ReduceLogSumExp, ReduceMax, ReduceMean, ReduceMin, ReduceProduct, ReduceSum} from './ops/reduce';
import {ReduceL1, ReduceL2, ReduceLogSum, ReduceLogSumExp, ReduceMax, ReduceMean, ReduceMin, ReduceProduct, ReduceSum, ReduceSumSquare} from './ops/reduce';
import {Resample2d} from './ops/resample2d';
import {Reshape} from './ops/reshape';
import {Slice} from './ops/slice';
Expand Down Expand Up @@ -759,6 +759,14 @@ export class MLGraphBuilder {
return (new ReduceL2(input, options)).output;
}

/**
* [spec](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-reducelogsum)
*/
reduceLogSum(input: MLOperand, options: MLReduceOptions = {}): MLOperand {
this.validateOperandBuilder([input]);
return (new ReduceLogSum(input, options)).output;
}

/**
* [spec](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-reducelogsumexp)
*/
Expand Down Expand Up @@ -806,6 +814,14 @@ export class MLGraphBuilder {
this.validateOperandBuilder([input]);
return (new ReduceSum(input, options)).output;
}

/**
* [spec](https://webmachinelearning.github.io/webnn/#dom-mlgraphbuilder-reducesumsquare)
*/
reduceSumSquare(input: MLOperand, options: MLReduceOptions = {}): MLOperand {
this.validateOperandBuilder([input]);
return (new ReduceSumSquare(input, options)).output;
}
// end of reduction operations

/**
Expand Down
28 changes: 20 additions & 8 deletions src/nn/ops/reduce.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,24 @@ abstract class Reduce extends SingleOutputOperation {
tf.Tensor;
}

export class ReduceL1 extends Reduce {
runOp(input: tf.Tensor, axes: number[], keepDimensions: boolean): tf.Tensor {
return tf.sum(tf.abs(input), axes, keepDimensions);
}
}

export class ReduceL2 extends Reduce {
runOp(input: tf.Tensor, axes: number[], keepDimensions: boolean): tf.Tensor {
return tf.sqrt(tf.sum(tf.pow(input, 2), axes, keepDimensions));
}
}

export class ReduceLogSum extends Reduce {
runOp(input: tf.Tensor, axes: number[], keepDimensions: boolean): tf.Tensor {
return tf.log(tf.sum(input, axes, keepDimensions));
}
}

export class ReduceLogSumExp extends Reduce {
runOp(input: tf.Tensor, axes: number[], keepDimensions: boolean): tf.Tensor {
return tf.logSumExp(input, axes, keepDimensions);
Expand Down Expand Up @@ -100,14 +118,8 @@ export class ReduceSum extends Reduce {
}
}

export class ReduceL1 extends Reduce {
runOp(input: tf.Tensor, axes: number[], keepDimensions: boolean): tf.Tensor {
return tf.sum(tf.abs(input), axes, keepDimensions);
}
}

export class ReduceL2 extends Reduce {
export class ReduceSumSquare extends Reduce {
runOp(input: tf.Tensor, axes: number[], keepDimensions: boolean): tf.Tensor {
return tf.sqrt(tf.sum(tf.pow(input, 2), axes, keepDimensions));
return tf.sum(tf.pow(input, 2), axes, keepDimensions);
}
}
Loading