Skip to content

Commit

Permalink
Expose dimRoundingMode attribute of pool op (#5849)
Browse files Browse the repository at this point in the history
* Expose dimRoundingMode attribute of pool op.

* Add conv_util.checkPadOnDimRoundingMode() function with tests.

Co-authored-by: Na Li <linazhao@google.com>
  • Loading branch information
BruceDai and lina128 authored Nov 30, 2021
1 parent 2dcda6f commit 0bab1d8
Show file tree
Hide file tree
Showing 29 changed files with 915 additions and 135 deletions.
9 changes: 2 additions & 7 deletions tfjs-core/src/gradients/DepthwiseConv2dNative_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,8 @@ export const depthwiseConv2dNativeGradConfig: GradConfig = {
`dilations must be 1. Got strides ${strides} and dilations ` +
`'${$dilations}'.`);

if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() =>
`Error in depthwiseConv2d: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}
conv_util.checkPadOnDimRoundingMode(
'depthwiseConv2d', pad, dimRoundingMode);

return {
x: () => depthwiseConv2dNativeBackpropInput(
Expand Down
10 changes: 1 addition & 9 deletions tfjs-core/src/ops/avg_pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,8 @@ function avgPool_<T extends Tensor3D|Tensor4D>(
util.assert(
x4D.rank === 4,
() => `Error in avgPool: x must be rank 4 but got rank ${x4D.rank}.`);

if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in avgPool: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

conv_util.checkPadOnDimRoundingMode('avgPool', pad, dimRoundingMode);
const inputs: AvgPoolInputs = {x: x4D};

const attrs: AvgPoolAttrs = {filterSize, strides, pad, dimRoundingMode};

// tslint:disable-next-line: no-unnecessary-type-assertion
Expand Down
11 changes: 2 additions & 9 deletions tfjs-core/src/ops/avg_pool_3d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import * as util from '../util';

import {checkPadOnDimRoundingMode} from './conv_util';
import {cast} from './cast';
import {op} from './operation';
import {reshape} from './reshape';
Expand Down Expand Up @@ -85,16 +86,8 @@ function avgPool3d_<T extends Tensor4D|Tensor5D>(
dataFormat === 'NDHWC',
() => `Error in avgPool3d: Only NDHWC is currently supported, ` +
`but got dataFormat of ${dataFormat}`);

if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in avgPool3d: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

checkPadOnDimRoundingMode('avgPool3d', pad, dimRoundingMode);
const inputs: AvgPool3DInputs = {x: x5D};

const attrs:
AvgPool3DAttrs = {filterSize, strides, pad, dimRoundingMode, dataFormat};

Expand Down
11 changes: 2 additions & 9 deletions tfjs-core/src/ops/avg_pool_3d_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import * as util from '../util';

import {checkPadOnDimRoundingMode} from './conv_util';
import {op} from './operation';
import {reshape} from './reshape';

Expand Down Expand Up @@ -77,16 +78,8 @@ function avgPool3dGrad_<T extends Tensor4D|Tensor5D>(
input5D.rank === 5,
() => `Error in avgPool3dGrad: input must be rank 5 but got rank ` +
`${input5D.rank}.`);

if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in avgPool3dGrad: pad must be an integer when ` +
`using, dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode);
const inputs: AvgPool3DGradInputs = {dy: dy5D, input: input5D};

const attrs: AvgPool3DGradAttrs = {filterSize, strides, pad, dimRoundingMode};

// tslint:disable-next-line: no-unnecessary-type-assertion
Expand Down
19 changes: 18 additions & 1 deletion tfjs-core/src/ops/avg_pool_3d_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,31 @@ describeWithFlags('avgPool3d', ALL_ENVS, () => {
expect(() => tf.avgPool3d(x as tf.Tensor5D, 2, 1, 'valid')).toThrowError();
});

it('throws when dimRoundingMode is set and pad is not a number', async () => {
it('throws when dimRoundingMode is set and pad is same', async () => {
const x = tf.tensor5d([1], [1, 1, 1, 1, 1]);
const pad = 'same';
const dimRoundingMode = 'round';

expect(() => tf.avgPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when dimRoundingMode is set and pad is valid', async () => {
const x = tf.tensor5d([1], [1, 1, 1, 1, 1]);
const pad = 'valid';
const dimRoundingMode = 'round';

expect(() => tf.avgPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when dimRoundingMode is set and pad is a non-integer number',
async () => {
const x = tf.tensor5d([1], [1, 1, 1, 1, 1]);
const pad = 1.2;
const dimRoundingMode = 'round';

expect(() => tf.avgPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when passed a non-tensor', () => {
expect(() => tf.avgPool3d({} as tf.Tensor5D, 2, 1, 'valid')).toThrowError();
});
Expand Down
65 changes: 64 additions & 1 deletion tfjs-core/src/ops/avg_pool_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
expectArraysClose(await result.data(), [2.5, 3, 4, 4.5, 5.5, 6, 7, 7.5]);
});

it('x=[2,2,3] f=[2,2] s=3 p=1 default dimRoundingMode', () => {
// Feed forward.
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
const result = tf.avgPool(x, 2, 3, 1);

expect(result.shape).toEqual([1, 1, 3]);
});

it('x=[2,2,3] f=[1,1] s=2 p=1 dimRoundingMode=floor', () => {
// Feed forward.
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
Expand All @@ -112,6 +120,30 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
expect(result.shape).toEqual([2, 2, 3]);
});

it('x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=floor', () => {
// Feed forward.
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
const result = tf.avgPool(x, 2, 3, 1, 'floor');

expect(result.shape).toEqual([1, 1, 3]);
});

it('x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=round', () => {
// Feed forward.
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
const result = tf.avgPool(x, 2, 3, 1, 'round');

expect(result.shape).toEqual([2, 2, 3]);
});

it('x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=ceil', () => {
// Feed forward.
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
const result = tf.avgPool(x, 2, 3, 1, 'ceil');

expect(result.shape).toEqual([2, 2, 3]);
});

it('gradient x=[1,1,1] f=[1,1] s=1 [0] => [0]', async () => {
const x = tf.tensor3d([0], [1, 1, 1]);
const dy = tf.tensor3d([0], [1, 1, 1]);
Expand Down Expand Up @@ -178,7 +210,16 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
]);
});

it('throws when dimRoundingMode is set and pad is not a number', () => {
it('throws when dimRoundingMode is set and pad is same', () => {
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);

const pad = 'same';
const dimRoundingMode = 'round';

expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when dimRoundingMode is set and pad is valid', () => {
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);

const pad = 'valid';
Expand All @@ -187,6 +228,28 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when dimRoundingMode is set and pad is a non-integer number',
() => {
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);

const pad = 1.2;
const dimRoundingMode = 'round';

expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when dimRoundingMode is set and pad is explicit by non-integer ' +
'number',
() => {
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);

const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as
tf.backend_util.ExplicitPadding;
const dimRoundingMode = 'round';

expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when passed a non-tensor', () => {
expect(() => tf.avgPool({} as tf.Tensor3D, 2, 1, 'valid'))
.toThrowError(/Argument 'x' passed to 'avgPool' must be a Tensor/);
Expand Down
8 changes: 1 addition & 7 deletions tfjs-core/src/ops/conv1d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,7 @@ function conv1d_<T extends Tensor2D|Tensor3D>(
$filter.rank === 3,
() => `Error in conv1d: filter must be rank 3, but got rank ` +
`${$filter.rank}.`);
if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in conv1d: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

conv_util.checkPadOnDimRoundingMode('conv1d', pad, dimRoundingMode);
util.assert(
x3D.shape[2] === $filter.shape[1],
() => `Error in conv1d: depth of input (${x3D.shape[2]}) must match ` +
Expand Down
84 changes: 84 additions & 0 deletions tfjs-core/src/ops/conv1d_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,90 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
expectArraysClose(await result.data(), await expectedResult.data());
});

it('throws when dimRoundingMode is set and pad is same', () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad = 'same';
const stride = 1;
const dataFormat = 'NWC';
const dilation = 1;
const dimRoundingMode = 'round';

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(
() => tf.conv1d(
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
.toThrowError();
});

it('throws when dimRoundingMode is set and pad is valid', () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad = 'valid';
const stride = 1;
const dataFormat = 'NWC';
const dilation = 1;
const dimRoundingMode = 'round';

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(
() => tf.conv1d(
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
.toThrowError();
});

it('throws when dimRoundingMode is set and pad is a non-integer number',
() => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad = 1.2;
const stride = 1;
const dataFormat = 'NWC';
const dilation = 1;
const dimRoundingMode = 'round';

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(
() => tf.conv1d(
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
.toThrowError();
});

it('throws when dimRoundingMode is set and pad is explicit by non-integer ' +
'number',
() => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as
tf.backend_util.ExplicitPadding;
const stride = 1;
const dataFormat = 'NWC';
const dilation = 1;
const dimRoundingMode = 'round';

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(
() => tf.conv1d(
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
.toThrowError();
});

it('TensorLike', async () => {
const pad = 'same';
const stride = 1;
Expand Down
8 changes: 1 addition & 7 deletions tfjs-core/src/ops/conv2d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,7 @@ function conv2d_<T extends Tensor3D|Tensor4D>(
$filter.rank === 4,
() => `Error in conv2d: filter must be rank 4, but got rank ` +
`${$filter.rank}.`);
if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in conv2d: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

conv_util.checkPadOnDimRoundingMode('conv2d', pad, dimRoundingMode);
const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
util.assert(
inDepth === $filter.shape[2],
Expand Down
8 changes: 1 addition & 7 deletions tfjs-core/src/ops/conv2d_backprop_filter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,7 @@ function conv2DBackpropFilter_<T extends Tensor3D|Tensor4D>(
outDepth === filterShape[3],
() => `Error in conv2dDerFilter: depth of dy (${outDepth}) must ` +
`match output depth for filter (${filterShape[3]}).`);
if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in conv2dDerFilter: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

conv_util.checkPadOnDimRoundingMode('conv2dDerFilter', pad, dimRoundingMode);
const inputs: Conv2DBackpropFilterInputs = {x: x4D, dy: dy4D};
const attrs: Conv2DBackpropFilterAttrs =
{strides, pad, dataFormat, dimRoundingMode, filterShape};
Expand Down
9 changes: 1 addition & 8 deletions tfjs-core/src/ops/conv2d_backprop_input.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,8 @@ function conv2DBackpropInput_<T extends Tensor3D|Tensor4D>(
outDepth === filter.shape[3],
() => `Error in conv2dDerInput: depth of output (${outDepth}) must ` +
`match output depth for filter ${filter.shape[3]}.`);
if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in conv2dDerInput: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

conv_util.checkPadOnDimRoundingMode('conv2dDerInput', pad, dimRoundingMode);
const inputs: Conv2DBackpropInputInputs = {dy: dy4D, filter};

const attrs: Conv2DBackpropInputAttrs =
{strides, pad, dataFormat, dimRoundingMode, inputShape: xShape4D};

Expand Down
Loading

0 comments on commit 0bab1d8

Please sign in to comment.