Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose dimRoundingMode attribute of pool op #5849

Merged
merged 4 commits into from
Nov 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions tfjs-core/src/gradients/DepthwiseConv2dNative_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,8 @@ export const depthwiseConv2dNativeGradConfig: GradConfig = {
`dilations must be 1. Got strides ${strides} and dilations ` +
`'${$dilations}'.`);

if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() =>
`Error in depthwiseConv2d: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}
conv_util.checkPadOnDimRoundingMode(
'depthwiseConv2d', pad, dimRoundingMode);

return {
x: () => depthwiseConv2dNativeBackpropInput(
Expand Down
10 changes: 1 addition & 9 deletions tfjs-core/src/ops/avg_pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,8 @@ function avgPool_<T extends Tensor3D|Tensor4D>(
util.assert(
x4D.rank === 4,
() => `Error in avgPool: x must be rank 4 but got rank ${x4D.rank}.`);

if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in avgPool: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

conv_util.checkPadOnDimRoundingMode('avgPool', pad, dimRoundingMode);
const inputs: AvgPoolInputs = {x: x4D};

const attrs: AvgPoolAttrs = {filterSize, strides, pad, dimRoundingMode};

// tslint:disable-next-line: no-unnecessary-type-assertion
Expand Down
11 changes: 2 additions & 9 deletions tfjs-core/src/ops/avg_pool_3d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import * as util from '../util';

import {checkPadOnDimRoundingMode} from './conv_util';
import {cast} from './cast';
import {op} from './operation';
import {reshape} from './reshape';
Expand Down Expand Up @@ -85,16 +86,8 @@ function avgPool3d_<T extends Tensor4D|Tensor5D>(
dataFormat === 'NDHWC',
() => `Error in avgPool3d: Only NDHWC is currently supported, ` +
`but got dataFormat of ${dataFormat}`);

if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in avgPool3d: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

checkPadOnDimRoundingMode('avgPool3d', pad, dimRoundingMode);
const inputs: AvgPool3DInputs = {x: x5D};

const attrs:
AvgPool3DAttrs = {filterSize, strides, pad, dimRoundingMode, dataFormat};

Expand Down
11 changes: 2 additions & 9 deletions tfjs-core/src/ops/avg_pool_3d_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import * as util from '../util';

import {checkPadOnDimRoundingMode} from './conv_util';
import {op} from './operation';
import {reshape} from './reshape';

Expand Down Expand Up @@ -77,16 +78,8 @@ function avgPool3dGrad_<T extends Tensor4D|Tensor5D>(
input5D.rank === 5,
() => `Error in avgPool3dGrad: input must be rank 5 but got rank ` +
`${input5D.rank}.`);

if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in avgPool3dGrad: pad must be an integer when ` +
`using, dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode);
const inputs: AvgPool3DGradInputs = {dy: dy5D, input: input5D};

const attrs: AvgPool3DGradAttrs = {filterSize, strides, pad, dimRoundingMode};

// tslint:disable-next-line: no-unnecessary-type-assertion
Expand Down
19 changes: 18 additions & 1 deletion tfjs-core/src/ops/avg_pool_3d_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,31 @@ describeWithFlags('avgPool3d', ALL_ENVS, () => {
expect(() => tf.avgPool3d(x as tf.Tensor5D, 2, 1, 'valid')).toThrowError();
});

it('throws when dimRoundingMode is set and pad is not a number', async () => {
it('throws when dimRoundingMode is set and pad is same', async () => {
const x = tf.tensor5d([1], [1, 1, 1, 1, 1]);
const pad = 'same';
const dimRoundingMode = 'round';

expect(() => tf.avgPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when dimRoundingMode is set and pad is valid', async () => {
const x = tf.tensor5d([1], [1, 1, 1, 1, 1]);
const pad = 'valid';
const dimRoundingMode = 'round';

expect(() => tf.avgPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when dimRoundingMode is set and pad is a non-integer number',
async () => {
const x = tf.tensor5d([1], [1, 1, 1, 1, 1]);
const pad = 1.2;
const dimRoundingMode = 'round';

expect(() => tf.avgPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when passed a non-tensor', () => {
expect(() => tf.avgPool3d({} as tf.Tensor5D, 2, 1, 'valid')).toThrowError();
});
Expand Down
65 changes: 64 additions & 1 deletion tfjs-core/src/ops/avg_pool_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
expectArraysClose(await result.data(), [2.5, 3, 4, 4.5, 5.5, 6, 7, 7.5]);
});

it('x=[2,2,3] f=[2,2] s=3 p=1 default dimRoundingMode', () => {
// Feed forward.
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
const result = tf.avgPool(x, 2, 3, 1);

expect(result.shape).toEqual([1, 1, 3]);
});

it('x=[2,2,3] f=[1,1] s=2 p=1 dimRoundingMode=floor', () => {
// Feed forward.
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
Expand All @@ -112,6 +120,30 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
expect(result.shape).toEqual([2, 2, 3]);
});

it('x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=floor', () => {
// Feed forward.
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
const result = tf.avgPool(x, 2, 3, 1, 'floor');

expect(result.shape).toEqual([1, 1, 3]);
});

it('x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=round', () => {
// Feed forward.
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
const result = tf.avgPool(x, 2, 3, 1, 'round');

expect(result.shape).toEqual([2, 2, 3]);
});

it('x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=ceil', () => {
// Feed forward.
const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]);
const result = tf.avgPool(x, 2, 3, 1, 'ceil');

expect(result.shape).toEqual([2, 2, 3]);
});

it('gradient x=[1,1,1] f=[1,1] s=1 [0] => [0]', async () => {
const x = tf.tensor3d([0], [1, 1, 1]);
const dy = tf.tensor3d([0], [1, 1, 1]);
Expand Down Expand Up @@ -178,7 +210,16 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
]);
});

it('throws when dimRoundingMode is set and pad is not a number', () => {
it('throws when dimRoundingMode is set and pad is same', () => {
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);

const pad = 'same';
const dimRoundingMode = 'round';

expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when dimRoundingMode is set and pad is valid', () => {
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);

const pad = 'valid';
Expand All @@ -187,6 +228,28 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when dimRoundingMode is set and pad is a non-integer number',
() => {
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);

const pad = 1.2;
const dimRoundingMode = 'round';

expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when dimRoundingMode is set and pad is explicit by non-integer ' +
'number',
() => {
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);

const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as
tf.backend_util.ExplicitPadding;
const dimRoundingMode = 'round';

expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when passed a non-tensor', () => {
expect(() => tf.avgPool({} as tf.Tensor3D, 2, 1, 'valid'))
.toThrowError(/Argument 'x' passed to 'avgPool' must be a Tensor/);
Expand Down
8 changes: 1 addition & 7 deletions tfjs-core/src/ops/conv1d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,7 @@ function conv1d_<T extends Tensor2D|Tensor3D>(
$filter.rank === 3,
() => `Error in conv1d: filter must be rank 3, but got rank ` +
`${$filter.rank}.`);
if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in conv1d: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

conv_util.checkPadOnDimRoundingMode('conv1d', pad, dimRoundingMode);
util.assert(
x3D.shape[2] === $filter.shape[1],
() => `Error in conv1d: depth of input (${x3D.shape[2]}) must match ` +
Expand Down
84 changes: 84 additions & 0 deletions tfjs-core/src/ops/conv1d_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,90 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
expectArraysClose(await result.data(), await expectedResult.data());
});

it('throws when dimRoundingMode is set and pad is same', () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad = 'same';
const stride = 1;
const dataFormat = 'NWC';
const dilation = 1;
const dimRoundingMode = 'round';

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(
() => tf.conv1d(
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
.toThrowError();
});

it('throws when dimRoundingMode is set and pad is valid', () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad = 'valid';
const stride = 1;
const dataFormat = 'NWC';
const dilation = 1;
const dimRoundingMode = 'round';

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(
() => tf.conv1d(
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
.toThrowError();
});

it('throws when dimRoundingMode is set and pad is a non-integer number',
() => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad = 1.2;
const stride = 1;
const dataFormat = 'NWC';
const dilation = 1;
const dimRoundingMode = 'round';

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(
() => tf.conv1d(
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
.toThrowError();
});

it('throws when dimRoundingMode is set and pad is explicit by non-integer ' +
'number',
() => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as
tf.backend_util.ExplicitPadding;
const stride = 1;
const dataFormat = 'NWC';
const dilation = 1;
const dimRoundingMode = 'round';

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(
() => tf.conv1d(
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
.toThrowError();
});

it('TensorLike', async () => {
const pad = 'same';
const stride = 1;
Expand Down
8 changes: 1 addition & 7 deletions tfjs-core/src/ops/conv2d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,7 @@ function conv2d_<T extends Tensor3D|Tensor4D>(
$filter.rank === 4,
() => `Error in conv2d: filter must be rank 4, but got rank ` +
`${$filter.rank}.`);
if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in conv2d: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

conv_util.checkPadOnDimRoundingMode('conv2d', pad, dimRoundingMode);
const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
util.assert(
inDepth === $filter.shape[2],
Expand Down
8 changes: 1 addition & 7 deletions tfjs-core/src/ops/conv2d_backprop_filter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,7 @@ function conv2DBackpropFilter_<T extends Tensor3D|Tensor4D>(
outDepth === filterShape[3],
() => `Error in conv2dDerFilter: depth of dy (${outDepth}) must ` +
`match output depth for filter (${filterShape[3]}).`);
if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in conv2dDerFilter: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

conv_util.checkPadOnDimRoundingMode('conv2dDerFilter', pad, dimRoundingMode);
const inputs: Conv2DBackpropFilterInputs = {x: x4D, dy: dy4D};
const attrs: Conv2DBackpropFilterAttrs =
{strides, pad, dataFormat, dimRoundingMode, filterShape};
Expand Down
9 changes: 1 addition & 8 deletions tfjs-core/src/ops/conv2d_backprop_input.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,8 @@ function conv2DBackpropInput_<T extends Tensor3D|Tensor4D>(
outDepth === filter.shape[3],
() => `Error in conv2dDerInput: depth of output (${outDepth}) must ` +
`match output depth for filter ${filter.shape[3]}.`);
if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in conv2dDerInput: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

conv_util.checkPadOnDimRoundingMode('conv2dDerInput', pad, dimRoundingMode);
const inputs: Conv2DBackpropInputInputs = {dy: dy4D, filter};

const attrs: Conv2DBackpropInputAttrs =
{strides, pad, dataFormat, dimRoundingMode, inputShape: xShape4D};

Expand Down
Loading