diff --git a/tfjs-core/src/gradients/DepthwiseConv2dNative_grad.ts b/tfjs-core/src/gradients/DepthwiseConv2dNative_grad.ts index 88d6bd344bc..c2b977ec4f8 100644 --- a/tfjs-core/src/gradients/DepthwiseConv2dNative_grad.ts +++ b/tfjs-core/src/gradients/DepthwiseConv2dNative_grad.ts @@ -58,13 +58,8 @@ export const depthwiseConv2dNativeGradConfig: GradConfig = { `dilations must be 1. Got strides ${strides} and dilations ` + `'${$dilations}'.`); - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => - `Error in depthwiseConv2d: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } + conv_util.checkPadOnDimRoundingMode( + 'depthwiseConv2d', pad, dimRoundingMode); return { x: () => depthwiseConv2dNativeBackpropInput( diff --git a/tfjs-core/src/ops/avg_pool.ts b/tfjs-core/src/ops/avg_pool.ts index 97b611990c4..f29853c63a8 100644 --- a/tfjs-core/src/ops/avg_pool.ts +++ b/tfjs-core/src/ops/avg_pool.ts @@ -72,16 +72,8 @@ function avgPool_( util.assert( x4D.rank === 4, () => `Error in avgPool: x must be rank 4 but got rank ${x4D.rank}.`); - - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in avgPool: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - + conv_util.checkPadOnDimRoundingMode('avgPool', pad, dimRoundingMode); const inputs: AvgPoolInputs = {x: x4D}; - const attrs: AvgPoolAttrs = {filterSize, strides, pad, dimRoundingMode}; // tslint:disable-next-line: no-unnecessary-type-assertion diff --git a/tfjs-core/src/ops/avg_pool_3d.ts b/tfjs-core/src/ops/avg_pool_3d.ts index ce09ab9475e..215969ff168 100644 --- a/tfjs-core/src/ops/avg_pool_3d.ts +++ b/tfjs-core/src/ops/avg_pool_3d.ts @@ -24,6 +24,7 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; +import {checkPadOnDimRoundingMode} from './conv_util'; import {cast} from './cast'; import {op} from './operation'; import {reshape} from './reshape'; @@ -85,16 +86,8 @@ function avgPool3d_( dataFormat === 'NDHWC', () => `Error in avgPool3d: Only NDHWC is currently supported, ` + `but got dataFormat of ${dataFormat}`); - - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in avgPool3d: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - + checkPadOnDimRoundingMode('avgPool3d', pad, dimRoundingMode); const inputs: AvgPool3DInputs = {x: x5D}; - const attrs: AvgPool3DAttrs = {filterSize, strides, pad, dimRoundingMode, dataFormat}; diff --git a/tfjs-core/src/ops/avg_pool_3d_grad.ts b/tfjs-core/src/ops/avg_pool_3d_grad.ts index fb7038fdea8..e633c927a07 100644 --- a/tfjs-core/src/ops/avg_pool_3d_grad.ts +++ b/tfjs-core/src/ops/avg_pool_3d_grad.ts @@ -25,6 +25,7 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; +import {checkPadOnDimRoundingMode} from './conv_util'; import {op} from './operation'; import {reshape} from './reshape'; @@ -77,16 +78,8 @@ function avgPool3dGrad_( input5D.rank === 5, () => `Error in avgPool3dGrad: input must be rank 5 but got rank ` + `${input5D.rank}.`); - - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in avgPool3dGrad: pad must be an integer when ` + - `using, dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - + checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode); const inputs: AvgPool3DGradInputs = {dy: dy5D, input: input5D}; - const attrs: AvgPool3DGradAttrs = {filterSize, strides, pad, dimRoundingMode}; // tslint:disable-next-line: no-unnecessary-type-assertion diff --git a/tfjs-core/src/ops/avg_pool_3d_test.ts b/tfjs-core/src/ops/avg_pool_3d_test.ts index c14321a44d5..308eca25042 100644 --- a/tfjs-core/src/ops/avg_pool_3d_test.ts +++ b/tfjs-core/src/ops/avg_pool_3d_test.ts @@ -157,7 +157,15 @@ describeWithFlags('avgPool3d', ALL_ENVS, () => { expect(() => tf.avgPool3d(x as tf.Tensor5D, 2, 1, 'valid')).toThrowError(); }); - it('throws when dimRoundingMode is set and pad is not a number', async () => { + it('throws when dimRoundingMode is set and pad is same', async () => { + const x = tf.tensor5d([1], [1, 1, 1, 1, 1]); + const pad = 'same'; + const dimRoundingMode = 'round'; + + expect(() => tf.avgPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is valid', async () => { const x = tf.tensor5d([1], [1, 1, 1, 1, 1]); const pad = 'valid'; const dimRoundingMode = 'round'; @@ -165,6 +173,15 @@ describeWithFlags('avgPool3d', ALL_ENVS, () => { expect(() => tf.avgPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError(); }); + it('throws when dimRoundingMode is set and pad is a non-integer number', + async () => { + const x = tf.tensor5d([1], [1, 1, 1, 1, 1]); + const pad = 1.2; + const dimRoundingMode = 'round'; + + expect(() => tf.avgPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError(); + }); + it('throws when passed a non-tensor', () => { expect(() => tf.avgPool3d({} as tf.Tensor5D, 2, 1, 'valid')).toThrowError(); }); diff --git a/tfjs-core/src/ops/avg_pool_test.ts b/tfjs-core/src/ops/avg_pool_test.ts index d49dd53ebe0..a4c0a21fb66 100644 --- a/tfjs-core/src/ops/avg_pool_test.ts +++ b/tfjs-core/src/ops/avg_pool_test.ts @@ -104,6 +104,14 @@ describeWithFlags('avgPool', ALL_ENVS, () => { expectArraysClose(await result.data(), [2.5, 3, 4, 4.5, 5.5, 6, 7, 7.5]); }); + it('x=[2,2,3] f=[2,2] s=3 p=1 default dimRoundingMode', () => { + // Feed forward. + const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]); + const result = tf.avgPool(x, 2, 3, 1); + + expect(result.shape).toEqual([1, 1, 3]); + }); + it('x=[2,2,3] f=[1,1] s=2 p=1 dimRoundingMode=floor', () => { // Feed forward. const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]); @@ -112,6 +120,30 @@ describeWithFlags('avgPool', ALL_ENVS, () => { expect(result.shape).toEqual([2, 2, 3]); }); + it('x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=floor', () => { + // Feed forward. + const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]); + const result = tf.avgPool(x, 2, 3, 1, 'floor'); + + expect(result.shape).toEqual([1, 1, 3]); + }); + + it('x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=round', () => { + // Feed forward. + const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]); + const result = tf.avgPool(x, 2, 3, 1, 'round'); + + expect(result.shape).toEqual([2, 2, 3]); + }); + + it('x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=ceil', () => { + // Feed forward. + const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]); + const result = tf.avgPool(x, 2, 3, 1, 'ceil'); + + expect(result.shape).toEqual([2, 2, 3]); + }); + it('gradient x=[1,1,1] f=[1,1] s=1 [0] => [0]', async () => { const x = tf.tensor3d([0], [1, 1, 1]); const dy = tf.tensor3d([0], [1, 1, 1]); @@ -178,7 +210,16 @@ describeWithFlags('avgPool', ALL_ENVS, () => { ]); }); - it('throws when dimRoundingMode is set and pad is not a number', () => { + it('throws when dimRoundingMode is set and pad is same', () => { + const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]); + + const pad = 'same'; + const dimRoundingMode = 'round'; + + expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is valid', () => { const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]); const pad = 'valid'; @@ -187,6 +228,28 @@ describeWithFlags('avgPool', ALL_ENVS, () => { expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError(); }); + it('throws when dimRoundingMode is set and pad is a non-integer number', + () => { + const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]); + + const pad = 1.2; + const dimRoundingMode = 'round'; + + expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is explicit by non-integer ' + + 'number', + () => { + const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]); + + const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as + tf.backend_util.ExplicitPadding; + const dimRoundingMode = 'round'; + + expect(() => tf.avgPool(x, 2, 1, pad, dimRoundingMode)).toThrowError(); + }); + it('throws when passed a non-tensor', () => { expect(() => tf.avgPool({} as tf.Tensor3D, 2, 1, 'valid')) .toThrowError(/Argument 'x' passed to 'avgPool' must be a Tensor/); diff --git a/tfjs-core/src/ops/conv1d.ts b/tfjs-core/src/ops/conv1d.ts index 6a2e1315a3f..4a04106870a 100644 --- a/tfjs-core/src/ops/conv1d.ts +++ b/tfjs-core/src/ops/conv1d.ts @@ -74,13 +74,7 @@ function conv1d_( $filter.rank === 3, () => `Error in conv1d: filter must be rank 3, but got rank ` + `${$filter.rank}.`); - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in conv1d: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - + conv_util.checkPadOnDimRoundingMode('conv1d', pad, dimRoundingMode); util.assert( x3D.shape[2] === $filter.shape[1], () => `Error in conv1d: depth of input (${x3D.shape[2]}) must match ` + diff --git a/tfjs-core/src/ops/conv1d_test.ts b/tfjs-core/src/ops/conv1d_test.ts index c613d9e8665..d1eef74dbe7 100644 --- a/tfjs-core/src/ops/conv1d_test.ts +++ b/tfjs-core/src/ops/conv1d_test.ts @@ -134,6 +134,90 @@ describeWithFlags('conv1d', ALL_ENVS, () => { expectArraysClose(await result.data(), await expectedResult.data()); }); + it('throws when dimRoundingMode is set and pad is same', () => { + const inputDepth = 1; + const inputShape: [number, number, number] = [2, 2, inputDepth]; + const outputDepth = 1; + const fSize = 1; + const pad = 'same'; + const stride = 1; + const dataFormat = 'NWC'; + const dilation = 1; + const dimRoundingMode = 'round'; + + const x = tf.tensor3d([1, 2, 3, 4], inputShape); + const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]); + + expect( + () => tf.conv1d( + x, w, stride, pad, dataFormat, dilation, dimRoundingMode)) + .toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is valid', () => { + const inputDepth = 1; + const inputShape: [number, number, number] = [2, 2, inputDepth]; + const outputDepth = 1; + const fSize = 1; + const pad = 'valid'; + const stride = 1; + const dataFormat = 'NWC'; + const dilation = 1; + const dimRoundingMode = 'round'; + + const x = tf.tensor3d([1, 2, 3, 4], inputShape); + const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]); + + expect( + () => tf.conv1d( + x, w, stride, pad, dataFormat, dilation, dimRoundingMode)) + .toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is a non-integer number', + () => { + const inputDepth = 1; + const inputShape: [number, number, number] = [2, 2, inputDepth]; + const outputDepth = 1; + const fSize = 1; + const pad = 1.2; + const stride = 1; + const dataFormat = 'NWC'; + const dilation = 1; + const dimRoundingMode = 'round'; + + const x = tf.tensor3d([1, 2, 3, 4], inputShape); + const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]); + + expect( + () => tf.conv1d( + x, w, stride, pad, dataFormat, dilation, dimRoundingMode)) + .toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is explicit by non-integer ' + + 'number', + () => { + const inputDepth = 1; + const inputShape: [number, number, number] = [2, 2, inputDepth]; + const outputDepth = 1; + const fSize = 1; + const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as + tf.backend_util.ExplicitPadding; + const stride = 1; + const dataFormat = 'NWC'; + const dilation = 1; + const dimRoundingMode = 'round'; + + const x = tf.tensor3d([1, 2, 3, 4], inputShape); + const w = tf.tensor3d([3], [fSize, inputDepth, outputDepth]); + + expect( + () => tf.conv1d( + x, w, stride, pad, dataFormat, dilation, dimRoundingMode)) + .toThrowError(); + }); + it('TensorLike', async () => { const pad = 'same'; const stride = 1; diff --git a/tfjs-core/src/ops/conv2d.ts b/tfjs-core/src/ops/conv2d.ts index 8d1c8de086b..97e0d293ea7 100644 --- a/tfjs-core/src/ops/conv2d.ts +++ b/tfjs-core/src/ops/conv2d.ts @@ -84,13 +84,7 @@ function conv2d_( $filter.rank === 4, () => `Error in conv2d: filter must be rank 4, but got rank ` + `${$filter.rank}.`); - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in conv2d: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - + conv_util.checkPadOnDimRoundingMode('conv2d', pad, dimRoundingMode); const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1]; util.assert( inDepth === $filter.shape[2], diff --git a/tfjs-core/src/ops/conv2d_backprop_filter.ts b/tfjs-core/src/ops/conv2d_backprop_filter.ts index 36b47159a4c..97d13a32330 100644 --- a/tfjs-core/src/ops/conv2d_backprop_filter.ts +++ b/tfjs-core/src/ops/conv2d_backprop_filter.ts @@ -81,13 +81,7 @@ function conv2DBackpropFilter_( outDepth === filterShape[3], () => `Error in conv2dDerFilter: depth of dy (${outDepth}) must ` + `match output depth for filter (${filterShape[3]}).`); - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in conv2dDerFilter: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - + conv_util.checkPadOnDimRoundingMode('conv2dDerFilter', pad, dimRoundingMode); const inputs: Conv2DBackpropFilterInputs = {x: x4D, dy: dy4D}; const attrs: Conv2DBackpropFilterAttrs = {strides, pad, dataFormat, dimRoundingMode, filterShape}; diff --git a/tfjs-core/src/ops/conv2d_backprop_input.ts b/tfjs-core/src/ops/conv2d_backprop_input.ts index 82d33bf4d5a..827efe019dc 100644 --- a/tfjs-core/src/ops/conv2d_backprop_input.ts +++ b/tfjs-core/src/ops/conv2d_backprop_input.ts @@ -92,15 +92,8 @@ function conv2DBackpropInput_( outDepth === filter.shape[3], () => `Error in conv2dDerInput: depth of output (${outDepth}) must ` + `match output depth for filter ${filter.shape[3]}.`); - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in conv2dDerInput: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - + conv_util.checkPadOnDimRoundingMode('conv2dDerInput', pad, dimRoundingMode); const inputs: Conv2DBackpropInputInputs = {dy: dy4D, filter}; - const attrs: Conv2DBackpropInputAttrs = {strides, pad, dataFormat, dimRoundingMode, inputShape: xShape4D}; diff --git a/tfjs-core/src/ops/conv2d_test.ts b/tfjs-core/src/ops/conv2d_test.ts index e2264e135fa..acd82c3aa85 100644 --- a/tfjs-core/src/ops/conv2d_test.ts +++ b/tfjs-core/src/ops/conv2d_test.ts @@ -543,7 +543,27 @@ describeWithFlags('conv2d', ALL_ENVS, () => { expect(() => tf.conv2d(x, w, stride, pad, dataFormat)).toThrowError(); }); - it('throws when dimRoundingMode is set and pad is not a number', () => { + it('throws when dimRoundingMode is set and pad is same', () => { + const inputDepth = 1; + const inputShape: [number, number, number] = [2, 2, inputDepth]; + const outputDepth = 1; + const fSize = 2; + const pad = 'same'; + const stride = 1; + const dataFormat = 'NHWC'; + const dilation = 1; + const dimRoundingMode = 'round'; + + const x = tf.tensor3d([1, 2, 3, 4], inputShape); + const w = tf.randomNormal([fSize, fSize, inputDepth, outputDepth]); + + expect( + () => tf.conv2d( + x, w, stride, pad, dataFormat, dilation, dimRoundingMode)) + .toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is valid', () => { const inputDepth = 1; const inputShape: [number, number, number] = [2, 2, inputDepth]; const outputDepth = 1; @@ -558,11 +578,56 @@ describeWithFlags('conv2d', ALL_ENVS, () => { const w = tf.randomNormal([fSize, fSize, inputDepth, outputDepth]); expect( - () => - tf.conv2d(x, w, stride, pad, dataFormat, dilation, dimRoundingMode)) + () => tf.conv2d( + x, w, stride, pad, dataFormat, dilation, dimRoundingMode)) + .toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is a non-integer number', + () => { + const inputDepth = 1; + const inputShape: [number, number, number] = [2, 2, inputDepth]; + const outputDepth = 1; + const fSize = 2; + const pad = 1.2; + const stride = 1; + const dataFormat = 'NHWC'; + const dilation = 1; + const dimRoundingMode = 'round'; + + const x = tf.tensor3d([1, 2, 3, 4], inputShape); + const w = tf.randomNormal([fSize, fSize, inputDepth, outputDepth]); + + expect( + () => tf.conv2d( + x, w, stride, pad, dataFormat, dilation, dimRoundingMode)) .toThrowError(); }); + it('throws when dimRoundingMode is set and pad is explicit by non-integer ' + + 'number', + () => { + const inputDepth = 1; + const inputShape: [number, number, number] = [2, 2, inputDepth]; + const outputDepth = 1; + const fSize = 2; + const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as + tf.backend_util.ExplicitPadding; + const stride = 1; + const dataFormat = 'NHWC'; + const dilation = 1; + const dimRoundingMode = 'round'; + + const x = tf.tensor3d([1, 2, 3, 4], inputShape); + const w = + tf.randomNormal([fSize, fSize, inputDepth, outputDepth]); + + expect( + () => tf.conv2d( + x, w, stride, pad, dataFormat, dilation, dimRoundingMode)) + .toThrowError(); + }); + it('throws when both stride and dilation are greater than 1', () => { const inputDepth = 1; const inputShape: [number, number, number] = [2, 2, inputDepth]; diff --git a/tfjs-core/src/ops/conv2d_transpose_test.ts b/tfjs-core/src/ops/conv2d_transpose_test.ts index 7ce210b4c94..e36ec475011 100644 --- a/tfjs-core/src/ops/conv2d_transpose_test.ts +++ b/tfjs-core/src/ops/conv2d_transpose_test.ts @@ -136,6 +136,120 @@ describeWithFlags('conv2dTranspose', ALL_ENVS, () => { expectArraysClose(await result.data(), expected); }); + it('throws when dimRoundingMode is set and pad is same', async () => { + const origInputDepth = 1; + const origOutputDepth = 4; + const inputShape: [number, number, number, number] = + [1, 2, 2, origOutputDepth]; + const fSize = 2; + const origPad = 'same'; + const origStride = 2; + const dimRoundingMode = 'round'; + + const x = tf.tensor4d( + [ + 1.24, 1.66, 0.9, 1.39, 0.16, 0.27, 0.42, 0.61, 0.04, 0.17, 0.34, 0.28, + 0., 0.06, 0.14, 0.24 + ], + inputShape); + const w = tf.tensor4d( + [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.], + [fSize, fSize, origInputDepth, origOutputDepth]); + + expect( + () => tf.conv2dTranspose( + x, w, [1, 3, 3, 1], origStride, origPad, dimRoundingMode)) + .toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is valid', async () => { + const origInputDepth = 1; + const origOutputDepth = 4; + const inputShape: [number, number, number, number] = + [1, 2, 2, origOutputDepth]; + const fSize = 2; + const origPad = 'valid'; + const origStride = 2; + const dimRoundingMode = 'round'; + + const x = tf.tensor4d( + [ + 1.24, 1.66, 0.9, 1.39, 0.16, 0.27, 0.42, 0.61, 0.04, 0.17, 0.34, 0.28, + 0., 0.06, 0.14, 0.24 + ], + inputShape); + const w = tf.tensor4d( + [0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.], + [fSize, fSize, origInputDepth, origOutputDepth]); + + expect( + () => tf.conv2dTranspose( + x, w, [1, 3, 3, 1], origStride, origPad, dimRoundingMode)) + .toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is a non-integer number', + async () => { + const origInputDepth = 1; + const origOutputDepth = 4; + const inputShape: [number, number, number, number] = + [1, 2, 2, origOutputDepth]; + const fSize = 2; + const origPad = 1.2; + const origStride = 2; + const dimRoundingMode = 'round'; + + const x = tf.tensor4d( + [ + 1.24, 1.66, 0.9, 1.39, 0.16, 0.27, 0.42, 0.61, 0.04, 0.17, 0.34, + 0.28, 0., 0.06, 0.14, 0.24 + ], + inputShape); + const w = tf.tensor4d( + [ + 0., 1., 2., 3., 4., 5., 6., 7., 8., + 9., 10., 11., 12., 13., 14., 15. + ], + [fSize, fSize, origInputDepth, origOutputDepth]); + + expect( + () => tf.conv2dTranspose( + x, w, [1, 3, 3, 1], origStride, origPad, dimRoundingMode)) + .toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is explicit by non-integer ' + + 'number', + async () => { + const origInputDepth = 1; + const origOutputDepth = 4; + const inputShape: [number, number, number, number] = + [1, 2, 2, origOutputDepth]; + const fSize = 2; + const origPad = [[0, 0], [0, 1.1], [0, 1], [0, 0]] as + tf.backend_util.ExplicitPadding; + const origStride = 2; + const dimRoundingMode = 'round'; + + const x = tf.tensor4d( + [ + 1.24, 1.66, 0.9, 1.39, 0.16, 0.27, 0.42, 0.61, 0.04, 0.17, 0.34, + 0.28, 0., 0.06, 0.14, 0.24 + ], + inputShape); + const w = tf.tensor4d( + [ + 0., 1., 2., 3., 4., 5., 6., 7., 8., + 9., 10., 11., 12., 13., 14., 15. + ], + [fSize, fSize, origInputDepth, origOutputDepth]); + + expect( + () => tf.conv2dTranspose( + x, w, [1, 3, 3, 1], origStride, origPad, dimRoundingMode)) + .toThrowError(); + }); + // Reference (Python) TensorFlow code: // // ```py diff --git a/tfjs-core/src/ops/conv_util.ts b/tfjs-core/src/ops/conv_util.ts index 840ddc75b16..efbaebfcbf8 100644 --- a/tfjs-core/src/ops/conv_util.ts +++ b/tfjs-core/src/ops/conv_util.ts @@ -15,6 +15,8 @@ * ============================================================================= */ +import * as util from '../util'; + type PadType = 'SAME'|'VALID'|'NUMBER'|'EXPLICIT'; // For NHWC should be in the following form: @@ -597,3 +599,45 @@ export function convertConv2DDataFormat(dataFormat: 'NHWC'|'NCHW'): throw new Error(`Unknown dataFormat ${dataFormat}`); } } + +/** + * Check validity of pad when using dimRoundingMode. + * @param opDesc A string of op description + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid` output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_docs/python/tf/nn/convolution]( + * https://www.tensorflow.org/api_docs/python/tf/nn/convolution) + * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is + * provided, it will default to truncate. + * @throws unknown padding parameter + */ +export function checkPadOnDimRoundingMode( + opDesc: string, pad: 'valid'|'same'|number|ExplicitPadding, + dimRoundingMode?: 'floor'|'round'|'ceil') { + if (dimRoundingMode != null) { + if (typeof pad === 'string') { + throw Error( + `Error in ${opDesc}: pad must be an integer when using ` + + `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); + } else if (typeof pad === 'number') { + util.assert( + util.isInt(pad), + () => `Error in ${opDesc}: pad must be an integer when using ` + + `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); + } else if (typeof pad === 'object') { + (pad as ExplicitPadding).forEach(p => {p.forEach(v =>{ + util.assert( + util.isInt(v), + () => `Error in ${opDesc}: pad must be an integer when using ` + + `dimRoundingMode ${dimRoundingMode} but got pad ${v}.`); + }); + }); + } else { + throw Error(`Error in ${opDesc}: Unknown padding parameter: ${pad}`); + } + } +} diff --git a/tfjs-core/src/ops/depthwise_conv2d.ts b/tfjs-core/src/ops/depthwise_conv2d.ts index 87f52a86567..a7044c56241 100644 --- a/tfjs-core/src/ops/depthwise_conv2d.ts +++ b/tfjs-core/src/ops/depthwise_conv2d.ts @@ -23,7 +23,7 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {ExplicitPadding} from './conv_util'; +import * as conv_util from './conv_util'; import {op} from './operation'; import {reshape} from './reshape'; @@ -75,7 +75,7 @@ import {reshape} from './reshape'; function depthwiseConv2d_( x: T|TensorLike, filter: Tensor4D|TensorLike, strides: [number, number]|number, - pad: 'valid'|'same'|number|ExplicitPadding, + pad: 'valid'|'same'|number|conv_util.ExplicitPadding, dataFormat: 'NHWC'|'NCHW' = 'NHWC', dilations: [number, number]|number = [1, 1], dimRoundingMode?: 'floor'|'round'|'ceil'): T { @@ -102,14 +102,7 @@ function depthwiseConv2d_( () => `Error in depthwiseConv2d: number of input channels ` + `(${x4D.shape[3]}) must match the inChannels dimension in ` + `filter ${$filter.shape[2]}.`); - - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in depthwiseConv2d: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - + conv_util.checkPadOnDimRoundingMode('depthwiseConv2d', pad, dimRoundingMode); const inputs: DepthwiseConv2dNativeInputs = {x: x4D, filter: $filter}; const attrs: DepthwiseConv2dNativeAttrs = {strides, pad, dataFormat, dilations, dimRoundingMode}; diff --git a/tfjs-core/src/ops/depthwise_conv2d_test.ts b/tfjs-core/src/ops/depthwise_conv2d_test.ts index dc31b7bb1b7..aeaf433c20d 100644 --- a/tfjs-core/src/ops/depthwise_conv2d_test.ts +++ b/tfjs-core/src/ops/depthwise_conv2d_test.ts @@ -803,6 +803,106 @@ describeWithFlags('depthwiseConv2D', ALL_ENVS, () => { expect(() => tf.depthwiseConv2d(x, w, stride, pad)).toThrowError(errRegex); }); + it('throws when dimRoundingMode is set and pad is same', () => { + const fSize = 2; + const pad = 'same'; + const stride = 1; + const chMul = 1; + const inDepth = 1; + const dimRoundingMode = 'round'; + + const x = tf.tensor4d( + [ + 0.230664, 0.987388, 0.0685208, 0.419224, 0.887861, 0.731641, + 0.0741907, 0.409265, 0.351377 + ], + [1, 3, 3, inDepth]); + const w = tf.tensor4d( + [0.303873, 0.229223, 0.144333, 0.803373], + [fSize, fSize, inDepth, chMul], + ); + expect( + () => tf.depthwiseConv2d( + x, w, stride, pad, 'NHWC', 1, dimRoundingMode)) + .toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is valid', () => { + const fSize = 2; + const pad = 'valid'; + const stride = 1; + const chMul = 1; + const inDepth = 1; + const dimRoundingMode = 'round'; + + const x = tf.tensor4d( + [ + 0.230664, 0.987388, 0.0685208, 0.419224, 0.887861, 0.731641, + 0.0741907, 0.409265, 0.351377 + ], + [1, 3, 3, inDepth]); + const w = tf.tensor4d( + [0.303873, 0.229223, 0.144333, 0.803373], + [fSize, fSize, inDepth, chMul], + ); + expect( + () => tf.depthwiseConv2d( + x, w, stride, pad, 'NHWC', 1, dimRoundingMode)) + .toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is a non-integer number', + () => { + const fSize = 2; + const pad = 1.2; + const stride = 1; + const chMul = 1; + const inDepth = 1; + const dimRoundingMode = 'round'; + + const x = tf.tensor4d( + [ + 0.230664, 0.987388, 0.0685208, 0.419224, 0.887861, 0.731641, + 0.0741907, 0.409265, 0.351377 + ], + [1, 3, 3, inDepth]); + const w = tf.tensor4d( + [0.303873, 0.229223, 0.144333, 0.803373], + [fSize, fSize, inDepth, chMul], + ); + expect( + () => tf.depthwiseConv2d( + x, w, stride, pad, 'NHWC', 1, dimRoundingMode)) + .toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is explicit by non-integer ' + + 'number', + () => { + const fSize = 2; + const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as + tf.backend_util.ExplicitPadding; + const stride = 1; + const chMul = 1; + const inDepth = 1; + const dimRoundingMode = 'round'; + + const x = tf.tensor4d( + [ + 0.230664, 0.987388, 0.0685208, 0.419224, 0.887861, 0.731641, + 0.0741907, 0.409265, 0.351377 + ], + [1, 3, 3, inDepth]); + const w = tf.tensor4d( + [0.303873, 0.229223, 0.144333, 0.803373], + [fSize, fSize, inDepth, chMul], + ); + expect( + () => tf.depthwiseConv2d( + x, w, stride, pad, 'NHWC', 1, dimRoundingMode)) + .toThrowError(); + }); + it('accepts a tensor-like object', async () => { const pad = 'valid'; const stride = 1; diff --git a/tfjs-core/src/ops/fused/conv2d.ts b/tfjs-core/src/ops/fused/conv2d.ts index 1fb833df4d0..af7f5db69a1 100644 --- a/tfjs-core/src/ops/fused/conv2d.ts +++ b/tfjs-core/src/ops/fused/conv2d.ts @@ -149,13 +149,7 @@ function fusedConv2d_({ $filter.rank === 4, () => `Error in fused conv2d: filter must be rank 4, but got rank ` + `${$filter.rank}.`); - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in fused conv2d: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - + conv_util.checkPadOnDimRoundingMode('fused conv2d', pad, dimRoundingMode); util.assert( x4D.shape[3] === $filter.shape[2], () => `Error in conv2d: depth of input (${x4D.shape[3]}) must match ` + diff --git a/tfjs-core/src/ops/fused/depthwise_conv2d.ts b/tfjs-core/src/ops/fused/depthwise_conv2d.ts index 21f96c7fae8..24d3f136a08 100644 --- a/tfjs-core/src/ops/fused/depthwise_conv2d.ts +++ b/tfjs-core/src/ops/fused/depthwise_conv2d.ts @@ -154,14 +154,8 @@ function fusedDepthwiseConv2d_({ () => 'Error in fused depthwiseConv2d: Either strides or dilations must ' + `be 1. Got strides ${strides} and dilations '${dilations}'`); - - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in fused depthwiseConv2d: pad must be an integer when ` + - `using dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - + conv_util.checkPadOnDimRoundingMode( + 'fused depthwiseConv2d', pad, dimRoundingMode); const convInfo = conv_util.computeConv2DInfo( x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */); diff --git a/tfjs-core/src/ops/fused/fused_conv2d_test.ts b/tfjs-core/src/ops/fused/fused_conv2d_test.ts index 112e4da2f01..781616fc2fd 100644 --- a/tfjs-core/src/ops/fused/fused_conv2d_test.ts +++ b/tfjs-core/src/ops/fused/fused_conv2d_test.ts @@ -610,6 +610,154 @@ describeWithFlags('fused conv2d', ALL_ENVS, () => { ])); }); + it('throws when dimRoundingMode is set and pad is same', () => { + const inputDepth = 16; + const xSize = 8; + const inputShape: [number, number, number, number] = + [1, xSize, xSize, inputDepth]; + const outputDepth = 8; + const fSize = 3; + const pad = 'same'; + const stride: [number, number] = [2, 2]; + + const inputs = generateCaseInputs( + 1 * xSize * xSize * inputDepth, + fSize * fSize * inputDepth * outputDepth); + const x = tf.tensor4d(inputs.input, inputShape); + const w = + tf.tensor4d(inputs.filter, [fSize, fSize, inputDepth, outputDepth]); + const bias = tf.tensor1d([1, 4, 2, 3, 9, 6, 5, 8]); + const leakyreluAlpha = 0.3; + + expect( + () => tf.fused.conv2d( + { + x, + filter: w, + strides: stride, + pad, + dataFormat: 'NHWC', + dilations: [1, 1], + activation: 'leakyrelu', + leakyreluAlpha, + bias, + dimRoundingMode: 'round' + })) + .toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is valid', () => { + const inputDepth = 16; + const xSize = 8; + const inputShape: [number, number, number, number] = + [1, xSize, xSize, inputDepth]; + const outputDepth = 8; + const fSize = 3; + const pad = 'valid'; + const stride: [number, number] = [2, 2]; + + const inputs = generateCaseInputs( + 1 * xSize * xSize * inputDepth, + fSize * fSize * inputDepth * outputDepth); + const x = tf.tensor4d(inputs.input, inputShape); + const w = + tf.tensor4d(inputs.filter, [fSize, fSize, inputDepth, outputDepth]); + const bias = tf.tensor1d([1, 4, 2, 3, 9, 6, 5, 8]); + const leakyreluAlpha = 0.3; + + expect( + () => tf.fused.conv2d( + { + x, + filter: w, + strides: stride, + pad, + dataFormat: 'NHWC', + dilations: [1, 1], + activation: 'leakyrelu', + leakyreluAlpha, + bias, + dimRoundingMode: 'round' + })) + .toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is a non-integer number', + () => { + const inputDepth = 16; + const xSize = 8; + const inputShape: [number, number, number, number] = + [1, xSize, xSize, inputDepth]; + const outputDepth = 8; + const fSize = 3; + const pad = 1.2; + const stride: [number, number] = [2, 2]; + + const inputs = generateCaseInputs( + 1 * xSize * xSize * inputDepth, + fSize * fSize * inputDepth * outputDepth); + const x = tf.tensor4d(inputs.input, inputShape); + const w = + tf.tensor4d(inputs.filter, [fSize, fSize, inputDepth, outputDepth]); + const bias = tf.tensor1d([1, 4, 2, 3, 9, 6, 5, 8]); + const leakyreluAlpha = 0.3; + + expect( + () => tf.fused.conv2d( + { + x, + filter: w, + strides: stride, + pad, + dataFormat: 'NHWC', + dilations: [1, 1], + activation: 'leakyrelu', + leakyreluAlpha, + bias, + dimRoundingMode: 'round' + })) + .toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is explicit by non-integer ' + + 'number', + () => { + const inputDepth = 16; + const xSize = 8; + const inputShape: [number, number, number, number] = + [1, xSize, xSize, inputDepth]; + const outputDepth = 8; + const fSize = 3; + const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as + tf.backend_util.ExplicitPadding; + const stride: [number, number] = [2, 2]; + + const inputs = generateCaseInputs( + 1 * xSize * xSize * inputDepth, + fSize * fSize * inputDepth * outputDepth); + const x = tf.tensor4d(inputs.input, inputShape); + const w = + tf.tensor4d(inputs.filter, [fSize, fSize, inputDepth, outputDepth]); + const bias = tf.tensor1d([1, 4, 2, 3, 9, 6, 5, 8]); + const leakyreluAlpha = 0.3; + + expect( + () => tf.fused.conv2d( + { + x, + filter: w, + strides: stride, + pad, + dataFormat: 'NHWC', + dilations: [1, 1], + activation: 'leakyrelu', + leakyreluAlpha, + bias, + dimRoundingMode: 'round' + })) + .toThrowError(); + }); + it('basic with bias', async () => { const inputDepth = 2; const inShape: [number, number, number, number] = [2, 2, 2, inputDepth]; diff --git a/tfjs-core/src/ops/max_pool.ts b/tfjs-core/src/ops/max_pool.ts index ac07edebd62..7240004b4f9 100644 --- a/tfjs-core/src/ops/max_pool.ts +++ b/tfjs-core/src/ops/max_pool.ts @@ -75,15 +75,8 @@ function maxPool_( conv_util.eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' + `Got strides ${strides} and dilations '${dilations}'`); - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in maxPool: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - + conv_util.checkPadOnDimRoundingMode('maxPool', pad, dimRoundingMode); const inputs: MaxPoolInputs = {x: x4D}; - const attrs: MaxPoolAttrs = {filterSize, strides, pad, dimRoundingMode}; // tslint:disable-next-line: no-unnecessary-type-assertion diff --git a/tfjs-core/src/ops/max_pool_3d.ts b/tfjs-core/src/ops/max_pool_3d.ts index 443b6464a6b..cf21bc0b3e5 100644 --- a/tfjs-core/src/ops/max_pool_3d.ts +++ b/tfjs-core/src/ops/max_pool_3d.ts @@ -24,6 +24,7 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; +import {checkPadOnDimRoundingMode} from './conv_util'; import {op} from './operation'; import {reshape} from './reshape'; @@ -83,15 +84,8 @@ function maxPool3d_( dataFormat === 'NDHWC', () => `Error in maxPool3d: Only NDHWC is currently supported, ` + `but got dataFormat of ${dataFormat}`); - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in maxPool3d: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - + checkPadOnDimRoundingMode('maxPool3d', pad, dimRoundingMode); const inputs: MaxPool3DInputs = {x: x5D}; - const attrs: MaxPool3DAttrs = {filterSize, strides, pad, dimRoundingMode, dataFormat}; diff --git a/tfjs-core/src/ops/max_pool_3d_grad.ts b/tfjs-core/src/ops/max_pool_3d_grad.ts index e1fe11e1748..b156179d551 100644 --- a/tfjs-core/src/ops/max_pool_3d_grad.ts +++ b/tfjs-core/src/ops/max_pool_3d_grad.ts @@ -24,6 +24,7 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; +import {checkPadOnDimRoundingMode} from './conv_util'; import {op} from './operation'; import {reshape} from './reshape'; @@ -87,16 +88,9 @@ function maxPool3dGrad_( output5D.rank === 5, () => `Error in maxPool3dGrad: output must be rank 5 but got rank ` + `${output5D.rank}.`); - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in maxPool3dGrad: pad must be an integer when ` + - `using, dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - + checkPadOnDimRoundingMode('maxPool3dGrad', pad, dimRoundingMode); const inputs: MaxPool3DGradInputs = {dy: dy5D, input: input5D, output: output5D}; - const attrs: MaxPool3DGradAttrs = {filterSize, strides, pad, dimRoundingMode}; // tslint:disable-next-line: no-unnecessary-type-assertion diff --git a/tfjs-core/src/ops/max_pool_3d_test.ts b/tfjs-core/src/ops/max_pool_3d_test.ts index 9bca4012ff8..6fbb0472108 100644 --- a/tfjs-core/src/ops/max_pool_3d_test.ts +++ b/tfjs-core/src/ops/max_pool_3d_test.ts @@ -154,7 +154,15 @@ describeWithFlags('maxPool3d', ALL_ENVS, () => { expect(() => tf.maxPool3d(x as tf.Tensor5D, 2, 1, 'valid')).toThrowError(); }); - it('throws when dimRoundingMode is set and pad is not a number', async () => { + it('throws when dimRoundingMode is set and pad is same', () => { + const x = tf.tensor5d([1], [1, 1, 1, 1, 1]); + const pad = 'same'; + const dimRoundingMode = 'round'; + + expect(() => tf.maxPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is valid', () => { const x = tf.tensor5d([1], [1, 1, 1, 1, 1]); const pad = 'valid'; const dimRoundingMode = 'round'; @@ -162,6 +170,15 @@ describeWithFlags('maxPool3d', ALL_ENVS, () => { expect(() => tf.maxPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError(); }); + it('throws when dimRoundingMode is set and pad is a non-integer number', + () => { + const x = tf.tensor5d([1], [1, 1, 1, 1, 1]); + const pad = 1.2; + const dimRoundingMode = 'round'; + + expect(() => tf.maxPool3d(x, 2, 1, pad, dimRoundingMode)).toThrowError(); + }); + it('throws when passed a non-tensor', () => { expect(() => tf.maxPool3d({} as tf.Tensor5D, 2, 1, 'valid')).toThrowError(); }); diff --git a/tfjs-core/src/ops/max_pool_grad.ts b/tfjs-core/src/ops/max_pool_grad.ts index 485ed38428c..7dfa1d313a7 100644 --- a/tfjs-core/src/ops/max_pool_grad.ts +++ b/tfjs-core/src/ops/max_pool_grad.ts @@ -24,7 +24,7 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {ExplicitPadding} from './conv_util'; +import * as conv_util from './conv_util'; import {op} from './operation'; /** @@ -52,7 +52,7 @@ function maxPoolGrad_( dy: Tensor4D|TensorLike, input: Tensor4D|TensorLike, output: Tensor4D|TensorLike, filterSize: [number, number]|number, strides: [number, number]|number, - pad: 'valid'|'same'|number|ExplicitPadding, + pad: 'valid'|'same'|number|conv_util.ExplicitPadding, dimRoundingMode?: 'floor'|'round'|'ceil'): Tensor4D { const $dy = convertToTensor(dy, 'dy', 'maxPoolGrad'); const $input = convertToTensor(input, 'input', 'maxPoolGrad'); @@ -71,15 +71,8 @@ function maxPoolGrad_( $input.rank === 4, () => `Error in maxPoolGrad: input must be rank 4 but got rank ` + `${$input.rank}.`); - if (dimRoundingMode != null) { - util.assert( - util.isInt(pad as number), - () => `Error in maxPoolGrad: pad must be an integer when using, ` + - `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); - } - + conv_util.checkPadOnDimRoundingMode('maxPoolGrad', pad, dimRoundingMode); const inputs: MaxPoolGradInputs = {dy: $dy, input: $input, output: $output}; - const attrs: MaxPoolGradAttrs = {filterSize, strides, pad, dimRoundingMode}; // tslint:disable-next-line: no-unnecessary-type-assertion diff --git a/tfjs-core/src/ops/max_pool_test.ts b/tfjs-core/src/ops/max_pool_test.ts index b7461b8ca4b..2827f4211f5 100644 --- a/tfjs-core/src/ops/max_pool_test.ts +++ b/tfjs-core/src/ops/max_pool_test.ts @@ -111,6 +111,14 @@ describeWithFlags('maxPool', ALL_ENVS, () => { expectArraysClose(await result.data(), [4, 4, 4, 4]); }); + it('x=[2,2,3] f=[2,2] s=3 p=1 default dimRoundingMode', () => { + // Feed forward. + const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]); + const result = tf.maxPool(x, 2, 3, 1); + + expect(result.shape).toEqual([1, 1, 3]); + }); + it('x=[2,2,3] f=[1,1] s=2 p=1 dimRoundingMode=floor', () => { // Feed forward. const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]); @@ -119,6 +127,30 @@ describeWithFlags('maxPool', ALL_ENVS, () => { expect(result.shape).toEqual([2, 2, 3]); }); + it('x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=floor', () => { + // Feed forward. + const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]); + const result = tf.maxPool(x, 2, 3, 1, 'floor'); + + expect(result.shape).toEqual([1, 1, 3]); + }); + + it('x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=round', () => { + // Feed forward. + const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]); + const result = tf.maxPool(x, 2, 3, 1, 'round'); + + expect(result.shape).toEqual([2, 2, 3]); + }); + + it('x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=ceil', () => { + // Feed forward. + const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]); + const result = tf.maxPool(x, 2, 3, 1, 'ceil'); + + expect(result.shape).toEqual([2, 2, 3]); + }); + it('throws when x is not rank 3', () => { // tslint:disable-next-line:no-any const x: any = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8, 9], [3, 3]); @@ -126,7 +158,16 @@ describeWithFlags('maxPool', ALL_ENVS, () => { expect(() => tf.maxPool(x, 2, 1, 0)).toThrowError(); }); - it('throws when dimRoundingMode is set and pad is not a number', () => { + it('throws when dimRoundingMode is set and pad is same', () => { + const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]); + + const pad = 'same'; + const dimRoundingMode = 'round'; + + expect(() => tf.maxPool(x, 2, 1, pad, dimRoundingMode)).toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is valid', () => { const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]); const pad = 'valid'; @@ -135,6 +176,28 @@ describeWithFlags('maxPool', ALL_ENVS, () => { expect(() => tf.maxPool(x, 2, 1, pad, dimRoundingMode)).toThrowError(); }); + it('throws when dimRoundingMode is set and pad is a non-integer number', + () => { + const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]); + + const pad = 1.2; + const dimRoundingMode = 'round'; + + expect(() => tf.maxPool(x, 2, 1, pad, dimRoundingMode)).toThrowError(); + }); + + it('throws when dimRoundingMode is set and pad is explicit by non-integer ' + + 'number', + () => { + const x = tf.tensor3d([1, 2, 3, 4], [2, 2, 1]); + + const pad = [[0, 0], [0, 2.1], [1, 1], [0, 0]] as + tf.backend_util.ExplicitPadding; + const dimRoundingMode = 'round'; + + expect(() => tf.maxPool(x, 2, 1, pad, dimRoundingMode)).toThrowError(); + }); + it('throws when passed a non-tensor', () => { expect(() => tf.maxPool({} as tf.Tensor3D, 2, 1, 'valid')) .toThrowError(/Argument 'x' passed to 'maxPool' must be a Tensor/); diff --git a/tfjs-core/src/ops/pool.ts b/tfjs-core/src/ops/pool.ts index ee61277d84e..5c5098e70b7 100644 --- a/tfjs-core/src/ops/pool.ts +++ b/tfjs-core/src/ops/pool.ts @@ -51,6 +51,8 @@ import {spaceToBatchND} from './space_to_batch_nd'; * 1, then all values of `strides` must be 1. * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If * `strides` is a single number, then `strideHeight == strideWidth`. + * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. If none is + * provided, it will default to truncate. * * @doc {heading: 'Operations', subheading: 'Convolution'} */ @@ -58,7 +60,8 @@ function pool_( input: T|TensorLike, windowShape: [number, number]|number, poolingType: 'avg'|'max', pad: 'valid'|'same'|number|conv_util.ExplicitPadding, - dilations?: [number, number]|number, strides?: [number, number]|number) { + dilations?: [number, number]|number, strides?: [number, number]|number, + dimRoundingMode?: 'floor'|'round'|'ceil') { if (dilations == null) { dilations = [1, 1]; } @@ -109,8 +112,10 @@ function pool_( isDilationOne ? x4D : spaceToBatchND(x4D, dilation, adjustedPadding); const forwardOp = poolingType === 'avg' ? - () => avgPool(convertedX, windowShape, strides, convertedPad) : - () => maxPool(convertedX, windowShape, strides, convertedPad); + () => avgPool(convertedX, windowShape, strides, convertedPad, + dimRoundingMode) : + () => maxPool(convertedX, windowShape, strides, convertedPad, + dimRoundingMode); const y = forwardOp(); const res = isDilationOne ? y : batchToSpaceND(y, dilation, adjustedCrops); diff --git a/tfjs-core/src/ops/pool_test.ts b/tfjs-core/src/ops/pool_test.ts index 5ffeeec6903..69294dc45b4 100644 --- a/tfjs-core/src/ops/pool_test.ts +++ b/tfjs-core/src/ops/pool_test.ts @@ -97,6 +97,70 @@ describeWithFlags('pool', ALL_ENVS, () => { expectArraysClose(await result.data(), [5, 5, 8, 8, 8, 8, 8, 8]); }); + it('max x=[3,3,1] f=[3,3] s=3 d=1 p=explicit defualt dimRoundingMode', + async () => { + // Feed forward. + const x = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8], [3, 3, 1]); + + const windowShape = 3; + const padding = + [[0, 0], [2, 2], [1, 1], [0, 0]] as tf.backend_util.ExplicitPadding; + const dilationRate: number = undefined; + const strides = 3; + + const result = + tf.pool(x, windowShape, 'max', padding, dilationRate, strides); + expect(result.shape).toEqual([2, 1, 1]); + }); + + it('max x=[3,3,1] f=[3,3] s=3 d=1 p=explicit dimRoundingMode=floor', + async () => { + // Feed forward. + const x = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8], [3, 3, 1]); + + const windowShape = 3; + const padding = + [[0, 0], [2, 2], [1, 1], [0, 0]] as tf.backend_util.ExplicitPadding; + const dilationRate: number = undefined; + const strides = 3; + + const result = tf.pool(x, windowShape, 'max', padding, dilationRate, + strides, 'floor'); + expect(result.shape).toEqual([2, 1, 1]); + }); + + it('max x=[3,3,1] f=[3,3] s=3 d=1 p=explicit dimRoundingMode=round', + async () => { + // Feed forward. + const x = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8], [3, 3, 1]); + + const windowShape = 3; + const padding = + [[0, 0], [2, 2], [1, 1], [0, 0]] as tf.backend_util.ExplicitPadding; + const dilationRate: number = undefined; + const strides = 3; + + const result = tf.pool(x, windowShape, 'max', padding, dilationRate, + strides, 'round'); + expect(result.shape).toEqual([2, 2, 1]); + }); + + it('max x=[3,3,1] f=[3,3] s=3 d=1 p=explicit dimRoundingMode=ceil', + async () => { + // Feed forward. + const x = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8], [3, 3, 1]); + + const windowShape = 3; + const padding = + [[0, 0], [2, 2], [1, 1], [0, 0]] as tf.backend_util.ExplicitPadding; + const dilationRate: number = undefined; + const strides = 3; + + const result = tf.pool(x, windowShape, 'max', padding, dilationRate, + strides, 'ceil'); + expect(result.shape).toEqual([3, 2, 1]); + }); + it('max x=[2,2,3] f=[1,1] s=2 p=1 fractional outputs default rounding', async () => { // Feed forward. @@ -191,6 +255,70 @@ describeWithFlags('pool', ALL_ENVS, () => { await result.data(), [2.5, 3, 4, 4.5, 5.5, 6, 7, 7.5]); }); + it('avg x=[3,3,1] f=[3,3] s=3 d=1 p=explicit defualt dimRoundingMode', + async () => { + // Feed forward. + const x = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8], [3, 3, 1]); + + const windowShape = 3; + const padding = + [[0, 0], [2, 2], [1, 1], [0, 0]] as tf.backend_util.ExplicitPadding; + const dilationRate: number = undefined; + const strides = 3; + + const result = + tf.pool(x, windowShape, 'avg', padding, dilationRate, strides); + expect(result.shape).toEqual([2, 1, 1]); + }); + + it('avg x=[3,3,1] f=[3,3] s=3 d=1 p=explicit dimRoundingMode=floor', + async () => { + // Feed forward. + const x = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8], [3, 3, 1]); + + const windowShape = 3; + const padding = + [[0, 0], [2, 2], [1, 1], [0, 0]] as tf.backend_util.ExplicitPadding; + const dilationRate: number = undefined; + const strides = 3; + + const result = tf.pool(x, windowShape, 'avg', padding, dilationRate, + strides, 'floor'); + expect(result.shape).toEqual([2, 1, 1]); + }); + + it('avg x=[3,3,1] f=[3,3] s=3 d=1 p=explicit dimRoundingMode=round', + async () => { + // Feed forward. + const x = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8], [3, 3, 1]); + + const windowShape = 3; + const padding = + [[0, 0], [2, 2], [1, 1], [0, 0]] as tf.backend_util.ExplicitPadding; + const dilationRate: number = undefined; + const strides = 3; + + const result = tf.pool(x, windowShape, 'avg', padding, dilationRate, + strides, 'round'); + expect(result.shape).toEqual([2, 2, 1]); + }); + + it('avg x=[3,3,1] f=[3,3] s=3 d=1 p=explicit dimRoundingMode=ceil', + async () => { + // Feed forward. + const x = tf.tensor3d([0, 1, 2, 3, 4, 5, 6, 7, 8], [3, 3, 1]); + + const windowShape = 3; + const padding = + [[0, 0], [2, 2], [1, 1], [0, 0]] as tf.backend_util.ExplicitPadding; + const dilationRate: number = undefined; + const strides = 3; + + const result = tf.pool(x, windowShape, 'avg', padding, dilationRate, + strides, 'ceil'); + expect(result.shape).toEqual([3, 2, 1]); + }); + it('avg x=[2,2,3] f=[1,1] s=2 p=1 fractional outputs default rounding', async () => { // Feed forward. diff --git a/tfjs-core/src/public/chained_ops/pool.ts b/tfjs-core/src/public/chained_ops/pool.ts index 4f2da1c0fae..0906aceed8b 100644 --- a/tfjs-core/src/public/chained_ops/pool.ts +++ b/tfjs-core/src/public/chained_ops/pool.ts @@ -25,7 +25,8 @@ declare module '../../tensor' { windowShape: [number, number]|number, poolingType: 'avg'|'max', padding: 'valid'|'same'|number|ExplicitPadding, diationRate?: [number, number]|number, - strides?: [number, number]|number): T; + strides?: [number, number]|number, + dimRoundingMode?: 'floor'|'round'|'ceil'): T; } } @@ -33,7 +34,9 @@ getGlobalTensorClass().prototype.pool = function( this: T, windowShape: [number, number]|number, poolingType: 'max'|'avg', padding: 'valid'|'same'|number|ExplicitPadding, dilationRate?: [number, number]|number, - strides?: [number, number]|number): T { + strides?: [number, number]|number, + dimRoundingMode?: 'floor'|'round'|'ceil'): T { this.throwIfDisposed(); - return pool(this, windowShape, poolingType, padding, dilationRate, strides); + return pool(this, windowShape, poolingType, padding, dilationRate, strides, + dimRoundingMode); }; diff --git a/tfjs-node/src/run_tests.ts b/tfjs-node/src/run_tests.ts index 57b52c26ad2..2aa8e4b472e 100644 --- a/tfjs-node/src/run_tests.ts +++ b/tfjs-node/src/run_tests.ts @@ -66,10 +66,20 @@ const IGNORE_LIST: string[] = [ 'scatterND test-tensorflow {} should sum the duplicated indices', 'scatterND test-tensorflow {} should work for tensorLike input', // https://github.com/tensorflow/tfjs/issues/1077 + // tslint:disable-next-line:max-line-length + 'maxPool test-tensorflow {} x=[2,2,3] f=[2,2] s=3 p=1 default dimRoundingMode', 'maxPool test-tensorflow {} x=[2,2,3] f=[1,1] s=2 p=1 dimRoundingMode=floor', + 'maxPool test-tensorflow {} x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=floor', + 'maxPool test-tensorflow {} x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=round', + 'maxPool test-tensorflow {} x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=ceil', // Node backend which uses TF 2.4.0 doesn't support explicit padding 'avgPool test-tensorflow {} x=[3,3,1] f=[3,3] s=1 p=explicit', + // tslint:disable-next-line:max-line-length + 'avgPool test-tensorflow {} x=[2,2,3] f=[2,2] s=3 p=1 default dimRoundingMode', 'avgPool test-tensorflow {} x=[2,2,3] f=[1,1] s=2 p=1 dimRoundingMode=floor', + 'avgPool test-tensorflow {} x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=floor', + 'avgPool test-tensorflow {} x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=round', + 'avgPool test-tensorflow {} x=[2,2,3] f=[2,2] s=3 p=1 dimRoundingMode=ceil', // Node backend which uses TF 2.4.0 doesn't support explicit padding 'avgPool test-tensorflow {} gradient x=[3,3,1] f=[3,3] s=1 p=explicit', // tslint:disable-next-line:max-line-length @@ -105,6 +115,22 @@ const IGNORE_LIST: string[] = [ // Node backend which uses TF 2.4.0 doesn't support explicit padding 'pool test-tensorflow {} max x=[3,3,1] f=[3,3] s=1 d=1 p=explicit', // tslint:disable-next-line:max-line-length + 'pool test-tensorflow {} max x=[3,3,1] f=[3,3] s=3 d=1 p=explicit defualt dimRoundingMode', + // tslint:disable-next-line:max-line-length + 'pool test-tensorflow {} max x=[3,3,1] f=[3,3] s=3 d=1 p=explicit dimRoundingMode=floor', + // tslint:disable-next-line:max-line-length + 'pool test-tensorflow {} max x=[3,3,1] f=[3,3] s=3 d=1 p=explicit dimRoundingMode=round', + // tslint:disable-next-line:max-line-length + 'pool test-tensorflow {} max x=[3,3,1] f=[3,3] s=3 d=1 p=explicit dimRoundingMode=ceil', + // tslint:disable-next-line:max-line-length + 'pool test-tensorflow {} avg x=[3,3,1] f=[3,3] s=3 d=1 p=explicit defualt dimRoundingMode', + // tslint:disable-next-line:max-line-length + 'pool test-tensorflow {} avg x=[3,3,1] f=[3,3] s=3 d=1 p=explicit dimRoundingMode=floor', + // tslint:disable-next-line:max-line-length + 'pool test-tensorflow {} avg x=[3,3,1] f=[3,3] s=3 d=1 p=explicit dimRoundingMode=round', + // tslint:disable-next-line:max-line-length + 'pool test-tensorflow {} avg x=[3,3,1] f=[3,3] s=3 d=1 p=explicit dimRoundingMode=ceil', + // tslint:disable-next-line:max-line-length 'pool test-tensorflow {} max x=[2,2,3] f=[1,1] s=2 p=1 fractional outputs default rounding', // Node backend which uses TF 2.4.0 doesn't support explicit padding 'pool test-tensorflow {} avg x=[3,3,1] f=[3,3] s=1 d=1 p=explicit',