Skip to content

Commit

Permalink
Add conv_util.checkPadOnDimRoundingMode() function with tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
BruceDai committed Nov 24, 2021
1 parent 424c41c commit 0af59ea
Show file tree
Hide file tree
Showing 25 changed files with 683 additions and 161 deletions.
9 changes: 2 additions & 7 deletions tfjs-core/src/gradients/DepthwiseConv2dNative_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,8 @@ export const depthwiseConv2dNativeGradConfig: GradConfig = {
`dilations must be 1. Got strides ${strides} and dilations ` +
`'${$dilations}'.`);

if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() =>
`Error in depthwiseConv2d: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}
conv_util.checkPadOnDimRoundingMode(
'depthwiseConv2d', pad, dimRoundingMode);

return {
x: () => depthwiseConv2dNativeBackpropInput(
Expand Down
26 changes: 1 addition & 25 deletions tfjs-core/src/ops/avg_pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,32 +72,8 @@ function avgPool_<T extends Tensor3D|Tensor4D>(
util.assert(
x4D.rank === 4,
() => `Error in avgPool: x must be rank 4 but got rank ${x4D.rank}.`);

if (dimRoundingMode != null) {
if (typeof pad === 'string') {
throw Error(
`Error in avgPool: pad must be an integer when using ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
} else if (typeof pad === 'number') {
util.assert(
util.isInt(pad),
() => `Error in avgPool: pad must be an integer when using ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
} else if (typeof pad === 'object') {
(pad as conv_util.ExplicitPadding).forEach(p => {p.forEach(v =>{
util.assert(
util.isInt(v),
() => `Error in avgPool: pad must be an integer when using ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${v}.`);
});
});
} else {
throw Error(`Unknown padding parameter: ${pad}`);
}
}

conv_util.checkPadOnDimRoundingMode('avgPool', pad, dimRoundingMode);
const inputs: AvgPoolInputs = {x: x4D};

const attrs: AvgPoolAttrs = {filterSize, strides, pad, dimRoundingMode};

// tslint:disable-next-line: no-unnecessary-type-assertion
Expand Down
11 changes: 2 additions & 9 deletions tfjs-core/src/ops/avg_pool_3d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import * as util from '../util';

import {checkPadOnDimRoundingMode} from './conv_util';
import {cast} from './cast';
import {op} from './operation';
import {reshape} from './reshape';
Expand Down Expand Up @@ -85,16 +86,8 @@ function avgPool3d_<T extends Tensor4D|Tensor5D>(
dataFormat === 'NDHWC',
() => `Error in avgPool3d: Only NDHWC is currently supported, ` +
`but got dataFormat of ${dataFormat}`);

if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in avgPool3d: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

checkPadOnDimRoundingMode('avgPool3d', pad, dimRoundingMode);
const inputs: AvgPool3DInputs = {x: x5D};

const attrs:
AvgPool3DAttrs = {filterSize, strides, pad, dimRoundingMode, dataFormat};

Expand Down
11 changes: 2 additions & 9 deletions tfjs-core/src/ops/avg_pool_3d_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import * as util from '../util';

import {checkPadOnDimRoundingMode} from './conv_util';
import {op} from './operation';
import {reshape} from './reshape';

Expand Down Expand Up @@ -77,16 +78,8 @@ function avgPool3dGrad_<T extends Tensor4D|Tensor5D>(
input5D.rank === 5,
() => `Error in avgPool3dGrad: input must be rank 5 but got rank ` +
`${input5D.rank}.`);

if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in avgPool3dGrad: pad must be an integer when ` +
`using, dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode);
const inputs: AvgPool3DGradInputs = {dy: dy5D, input: input5D};

const attrs: AvgPool3DGradAttrs = {filterSize, strides, pad, dimRoundingMode};

// tslint:disable-next-line: no-unnecessary-type-assertion
Expand Down
19 changes: 18 additions & 1 deletion tfjs-core/src/ops/avg_pool_3d_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,31 @@ describeWithFlags('avgPool3d', ALL_ENVS, () => {
expect(() => tf.avgPool3d(x as tf.Tensor5D, 2, 1, 'valid')).toThrowError();
});

it('throws when dimRoundingMode is set and pad is not a number', async () => {
it('throws when dimRoundingMode is set and pad is same', async () => {
const x = tf.tensor5d([1], [1, 1, 1, 1, 1]);
const pad = 'same';
const dimRoundingMode = 'round';

expect(() => tf.avgPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when dimRoundingMode is set and pad is valid', async () => {
const x = tf.tensor5d([1], [1, 1, 1, 1, 1]);
const pad = 'valid';
const dimRoundingMode = 'round';

expect(() => tf.avgPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when dimRoundingMode is set and pad is a non-integer number',
async () => {
const x = tf.tensor5d([1], [1, 1, 1, 1, 1]);
const pad = 1.2;
const dimRoundingMode = 'round';

expect(() => tf.avgPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when passed a non-tensor', () => {
expect(() => tf.avgPool3d({} as tf.Tensor5D, 2, 1, 'valid')).toThrowError();
});
Expand Down
33 changes: 32 additions & 1 deletion tfjs-core/src/ops/avg_pool_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,16 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
]);
});

it('throws when dimRoundingMode is set and pad is not a number', () => {
it('throws when dimRoundingMode is set and pad is same', () => {
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);

const pad = 'same';
const dimRoundingMode = 'round';

expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when dimRoundingMode is set and pad is valid', () => {
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);

const pad = 'valid';
Expand All @@ -219,6 +228,28 @@ describeWithFlags('avgPool', ALL_ENVS, () => {
expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when dimRoundingMode is set and pad is a non-integer number',
() => {
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);

const pad = 1.2;
const dimRoundingMode = 'round';

expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when dimRoundingMode is set and pad is explicit by non-integer ' +
'number',
() => {
const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]);

const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as
tf.backend_util.ExplicitPadding;
const dimRoundingMode = 'round';

expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError();
});

it('throws when passed a non-tensor', () => {
expect(() => tf.avgPool({} as tf.Tensor3D, 2, 1, 'valid'))
.toThrowError(/Argument 'x' passed to 'avgPool' must be a Tensor/);
Expand Down
8 changes: 1 addition & 7 deletions tfjs-core/src/ops/conv1d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,7 @@ function conv1d_<T extends Tensor2D|Tensor3D>(
$filter.rank === 3,
() => `Error in conv1d: filter must be rank 3, but got rank ` +
`${$filter.rank}.`);
if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in conv1d: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

conv_util.checkPadOnDimRoundingMode('conv1d', pad, dimRoundingMode);
util.assert(
x3D.shape[2] === $filter.shape[1],
() => `Error in conv1d: depth of input (${x3D.shape[2]}) must match ` +
Expand Down
84 changes: 84 additions & 0 deletions tfjs-core/src/ops/conv1d_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,90 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
expectArraysClose(await result.data(), await expectedResult.data());
});

it('throws when dimRoundingMode is set and pad is same', () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad = 'same';
const stride = 1;
const dataFormat = 'NWC';
const dilation = 1;
const dimRoundingMode = 'round';

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(
() => tf.conv1d(
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
.toThrowError();
});

it('throws when dimRoundingMode is set and pad is valid', () => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad = 'valid';
const stride = 1;
const dataFormat = 'NWC';
const dilation = 1;
const dimRoundingMode = 'round';

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(
() => tf.conv1d(
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
.toThrowError();
});

it('throws when dimRoundingMode is set and pad is a non-integer number',
() => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad = 1.2;
const stride = 1;
const dataFormat = 'NWC';
const dilation = 1;
const dimRoundingMode = 'round';

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(
() => tf.conv1d(
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
.toThrowError();
});

it('throws when dimRoundingMode is set and pad is explicit by non-integer ' +
'number',
() => {
const inputDepth = 1;
const inputShape: [number, number, number] = [2, 2, inputDepth];
const outputDepth = 1;
const fSize = 1;
const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as
tf.backend_util.ExplicitPadding;
const stride = 1;
const dataFormat = 'NWC';
const dilation = 1;
const dimRoundingMode = 'round';

const x = tf.tensor3d([1, 2, 3, 4], inputShape);
const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]);

expect(
() => tf.conv1d(
x, w, stride, pad, dataFormat, dilation, dimRoundingMode))
.toThrowError();
});

it('TensorLike', async () => {
const pad = 'same';
const stride = 1;
Expand Down
8 changes: 1 addition & 7 deletions tfjs-core/src/ops/conv2d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,7 @@ function conv2d_<T extends Tensor3D|Tensor4D>(
$filter.rank === 4,
() => `Error in conv2d: filter must be rank 4, but got rank ` +
`${$filter.rank}.`);
if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in conv2d: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

conv_util.checkPadOnDimRoundingMode('conv2d', pad, dimRoundingMode);
const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
util.assert(
inDepth === $filter.shape[2],
Expand Down
8 changes: 1 addition & 7 deletions tfjs-core/src/ops/conv2d_backprop_filter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,7 @@ function conv2DBackpropFilter_<T extends Tensor3D|Tensor4D>(
outDepth === filterShape[3],
() => `Error in conv2dDerFilter: depth of dy (${outDepth}) must ` +
`match output depth for filter (${filterShape[3]}).`);
if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in conv2dDerFilter: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

conv_util.checkPadOnDimRoundingMode('conv2dDerFilter', pad, dimRoundingMode);
const inputs: Conv2DBackpropFilterInputs = {x: x4D, dy: dy4D};
const attrs: Conv2DBackpropFilterAttrs =
{strides, pad, dataFormat, dimRoundingMode, filterShape};
Expand Down
9 changes: 1 addition & 8 deletions tfjs-core/src/ops/conv2d_backprop_input.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,8 @@ function conv2DBackpropInput_<T extends Tensor3D|Tensor4D>(
outDepth === filter.shape[3],
() => `Error in conv2dDerInput: depth of output (${outDepth}) must ` +
`match output depth for filter ${filter.shape[3]}.`);
if (dimRoundingMode != null) {
util.assert(
util.isInt(pad as number),
() => `Error in conv2dDerInput: pad must be an integer when using, ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}

conv_util.checkPadOnDimRoundingMode('conv2dDerInput', pad, dimRoundingMode);
const inputs: Conv2DBackpropInputInputs = {dy: dy4D, filter};

const attrs: Conv2DBackpropInputAttrs =
{strides, pad, dataFormat, dimRoundingMode, inputShape: xShape4D};

Expand Down
Loading

0 comments on commit 0af59ea

Please sign in to comment.