diff --git a/tfjs-backend-cpu/src/kernels/FusedConv2D.ts b/tfjs-backend-cpu/src/kernels/FusedConv2D.ts
index 29657d0a12f..2040c674da5 100644
--- a/tfjs-backend-cpu/src/kernels/FusedConv2D.ts
+++ b/tfjs-backend-cpu/src/kernels/FusedConv2D.ts
@@ -48,18 +48,20 @@ export function fusedConv2D(args: {
   if (bias) {
     const resultOld = result;
+    // For NCHW format, if the bias is a 1-D tensor, it must be aligned with
+    // the channel dimension of the conv2d result; if the bias is a scalar,
+    // the bias_add is computed as if the bias were broadcast to the shape of
+    // the conv2d result.
     if (dataFormat === 'NCHW' && bias.shape.length === 1 &&
         bias.shape[0] !== 1) {
-      // For NCHW format, if bias is a 1-D tensor, it is supposed to be aligned
-      // to the channel of the conv2d's result; if bias is a scalar, the
-      // bias_add is computed as if the bias was broadcasted to the shape of the
-      // conv2d's result.
       const reshapedBias = reshape(
           {inputs: {x: bias}, backend, attrs: {shape: [bias.shape[0], 1, 1]}});
       result =
           add({inputs: {a: result, b: reshapedBias}, backend}) as TensorInfo;
       backend.disposeIntermediateTensorInfo(reshapedBias);
     } else {
+      // This branch handles NHWC, and NCHW with a scalar bias; the only other
+      // case, NCHW with a 1-D bias, is handled above.
       result = add({inputs: {a: result, b: bias}, backend}) as TensorInfo;
     }
     backend.disposeIntermediateTensorInfo(resultOld);
   }
@@ -67,8 +69,25 @@ export function fusedConv2D(args: {
   if (activation) {
     const resultOld = result;
-    result = applyActivation(
-        backend, result, activation, preluActivationWeights, leakyreluAlpha);
+    // For NCHW format, if the PReLU activation weights are a 1-D tensor, they
+    // must be aligned with the channel dimension of the conv2d result. In all
+    // other cases, whether NCHW or NHWC, the conv2d result is already aligned
+    // with the activation weights.
+    if (dataFormat === 'NCHW' && activation === 'prelu' &&
+        preluActivationWeights.shape.length === 1 &&
+        preluActivationWeights.shape[0] !== 1) {
+      const reshapedAlpha = reshape({
+        inputs: {x: preluActivationWeights},
+        backend,
+        attrs: {shape: [preluActivationWeights.shape[0], 1, 1]}
+      });
+      result = applyActivation(
+          backend, result, activation, reshapedAlpha, leakyreluAlpha);
+      backend.disposeIntermediateTensorInfo(reshapedAlpha);
+    } else {
+      result = applyActivation(
+          backend, result, activation, preluActivationWeights, leakyreluAlpha);
+    }
     backend.disposeIntermediateTensorInfo(resultOld);
   }
diff --git a/tfjs-backend-webgl/src/kernels/Conv2D_impl.ts b/tfjs-backend-webgl/src/kernels/Conv2D_impl.ts
index 33af03a3912..a18f8196687 100644
--- a/tfjs-backend-webgl/src/kernels/Conv2D_impl.ts
+++ b/tfjs-backend-webgl/src/kernels/Conv2D_impl.ts
@@ -15,7 +15,7 @@
  * =============================================================================
  */

-import {backend_util, broadcast_util, TensorInfo, util} from '@tensorflow/tfjs-core';
+import {backend_util, TensorInfo, util} from '@tensorflow/tfjs-core';

 // import {assertAndGetBroadcastShape} from
 // '../../../tfjs-core/src/ops/broadcast_util';
@@ -41,41 +41,6 @@ type Conv2DConfig = {
   activation?: backend_util.Activation
 };

-function fitPreluActivationWeightsIntoNhwcFormat(
-    alpha: TensorInfo, outputShape: [number, number, number, number],
-    isChannelsLast: boolean, backend: MathBackendWebGL) {
-  // PReLU's activation weights could be a scalar, a 1-D tensor or a 3-D
-  // tensor.
-  const alphaShape = alpha.shape;
-  util.assert(
-      alphaShape.length <= 1 || alphaShape.length === 3,
-      () => `WebGL conv2d only supports scalar, 1-D Tensor or 3-D ` +
-          `Tensor PReLU activation weights but got a tensor of ` +
-          `rank-${alphaShape.length}.`);
-  if (alphaShape.length === 1) {
-    const outputChannels = isChannelsLast ? outputShape[3] : outputShape[1];
-    util.assert(
-        alphaShape[0] === 1 || alphaShape[0] === outputChannels,
-        () => `WebGL conv2d PReLU activation weights (${alphaShape}) is ` +
-            `not compatible with the number of output channels ` +
-            `(${outputChannels}).`);
-  } else if (alphaShape.length === 3) {
-    try {
-      broadcast_util.assertAndGetBroadcastShape(alphaShape, outputShape);
-    } catch (e) {
-      const errMsg = `WebGL conv2d PReLU activation weights (${alphaShape}) ` +
-          `is not compatible with the output shape of the conv2d ` +
-          `(${outputShape}).`;
-      throw Error(errMsg);
-    }
-    if (!isChannelsLast) {
-      // If PReLU's activation weights is NCHW format, then convert it to NHWC.
-      return transpose({inputs: {x: alpha}, backend, attrs: {perm: [1, 2, 0]}});
-    }
-  }
-  return alpha;
-}
-
 // For 1x1 kernels that iterate through every point in the input, convolution
 // can be expressed as matrix multiplication (without need for memory
 // remapping).
@@ -103,18 +68,17 @@ export function conv2dByMatMul({
   let out: TensorInfo;
   const intermediates: TensorInfo[] = [];

-  if (preluActivationWeights != null) {
-    const preluActivationWeightsInNhwcFormat =
-        fitPreluActivationWeightsIntoNhwcFormat(
-            preluActivationWeights, convInfo.outShape, isChannelsLast, backend);
-
-    if (preluActivationWeightsInNhwcFormat.dataId !==
-        preluActivationWeights.dataId) {
-      // preluActivationWeightsInNhwcFormat is a new tensor, temporarily
-      // generated to be compatible with the following matmul computation.
-      intermediates.push(preluActivationWeightsInNhwcFormat);
-      preluActivationWeights = preluActivationWeightsInNhwcFormat;
-    }
+  if (preluActivationWeights != null && !isChannelsLast &&
+      preluActivationWeights.shape.length === 3) {
+    // If the PReLU activation weights are in NCHW format, convert them to
+    // NHWC for the following computation.
+    const preluActivationWeightsInNhwcFormat = transpose({
+      inputs: {x: preluActivationWeights},
+      backend,
+      attrs: {perm: [1, 2, 0]}
+    });
+    intermediates.push(preluActivationWeightsInNhwcFormat);
+    preluActivationWeights = preluActivationWeightsInNhwcFormat;
   }

   // TODO: Once reduction ops are packed, batchMatMul will always be packed
@@ -287,18 +251,17 @@ export function conv2dWithIm2Row({
   const intermediates: TensorInfo[] = [];

-  if (preluActivationWeights != null) {
-    const preluActivationWeightsInNhwcFormat =
-        fitPreluActivationWeightsIntoNhwcFormat(
-            preluActivationWeights, convInfo.outShape, isChannelsLast, backend);
-
-    if (preluActivationWeightsInNhwcFormat.dataId !==
-        preluActivationWeights.dataId) {
-      // preluActivationWeightsInNhwcFormat is a new tensor, temporarily
-      // generated to be compatible with the following matmul computation.
-      intermediates.push(preluActivationWeightsInNhwcFormat);
-      preluActivationWeights = preluActivationWeightsInNhwcFormat;
-    }
+  if (preluActivationWeights != null && !isChannelsLast &&
+      preluActivationWeights.shape.length === 3) {
+    // If the PReLU activation weights are in NCHW format, convert them to
+    // NHWC for the following computation.
+    const preluActivationWeightsInNhwcFormat = transpose({
+      inputs: {x: preluActivationWeights},
+      backend,
+      attrs: {perm: [1, 2, 0]}
+    });
+    intermediates.push(preluActivationWeightsInNhwcFormat);
+    preluActivationWeights = preluActivationWeightsInNhwcFormat;
   }

   const xSqueezed =
diff --git a/tfjs-core/src/ops/fused/conv2d.ts b/tfjs-core/src/ops/fused/conv2d.ts
index a99bdc85258..cfb1c61076c 100644
--- a/tfjs-core/src/ops/fused/conv2d.ts
+++ b/tfjs-core/src/ops/fused/conv2d.ts
@@ -193,6 +193,39 @@ function fusedConv2d_({
   let $preluActivationWeights: Tensor;
   if (preluActivationWeights != null) {
+    // PReLU's activation weights could be a scalar, a 1-D tensor or a 3-D
+    // tensor.
+    const alphaShape = preluActivationWeights.shape;
+    util.assert(
+        alphaShape.length <= 1 || alphaShape.length === 3,
+        () => `Error in fused conv2d: only supports scalar, 1-D Tensor or ` +
+            `3-D Tensor PReLU activation weights but got a tensor of ` +
+            `rank-${alphaShape.length}.`);
+
+    if (alphaShape.length === 1) {
+      // Whether the data format is NCHW or NHWC, the 1-D PReLU activation
+      // weights tensor must be aligned with the output channels of the
+      // conv2d result.
+      util.assert(
+          alphaShape[0] === 1 || alphaShape[0] === convInfo.outChannels,
+          () => `Error in fused conv2d: PReLU activation weights ` +
+              `(${alphaShape}) is not compatible with the number of output ` +
+              `channels (${convInfo.outChannels}).`);
+    } else if (alphaShape.length === 3) {
+      // Whether the data format is NCHW or NHWC, the PReLU activation weights
+      // tensor must have a shape that is compatible with the conv2d result.
+      try {
+        broadcast_util.assertAndGetBroadcastShape(
+            alphaShape, convInfo.outShape);
+      } catch (e) {
+        const errMsg =
+            `Error in fused conv2d: PReLU activation weights (${alphaShape}) ` +
+            `is not compatible with the output shape of the conv2d ` +
+            `(${convInfo.outShape}).`;
+        throw Error(errMsg);
+      }
+    }
+
     $preluActivationWeights = convertToTensor(
         preluActivationWeights, 'prelu weights', 'fused conv2d');
   }
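For reviewers, here is a minimal usage sketch of the case this patch addresses: fused conv2d in NCHW format with a per-channel 1-D bias and 1-D PReLU weights. The shapes and values are illustrative only, not taken from the test suite, and it assumes a backend where NCHW fused conv2d is available (which this change provides).

```ts
import * as tf from '@tensorflow/tfjs';

// NCHW input: batch 1, 4 channels, 3x3 spatial.
const x = tf.randomNormal([1, 4, 3, 3]) as tf.Tensor4D;
// conv2d filters are [height, width, inChannels, outChannels] regardless of
// dataFormat: here a 1x1 kernel mapping 4 channels to 2.
const filter = tf.randomNormal([1, 1, 4, 2]) as tf.Tensor4D;
// 1-D bias and PReLU alpha, one value per output channel. Without this
// patch, these would be broadcast against the trailing (spatial) axes of the
// NCHW result rather than its channel axis.
const bias = tf.tensor1d([0.1, -0.1]);
const alpha = tf.tensor1d([0.25, 0.5]);

const y = tf.fused.conv2d({
  x,
  filter,
  strides: 1,
  pad: 'same',
  dataFormat: 'NCHW',
  bias,
  activation: 'prelu',
  preluActivationWeights: alpha,
});
console.log(y.shape);  // [1, 2, 3, 3]
```

With this change, the 1-D `bias` and `preluActivationWeights` above are reshaped to `[2, 1, 1]` in the CPU kernel (and 3-D NCHW weights are transposed to NHWC where the WebGL matmul paths need it), so they broadcast against the channel axis of the `[1, 2, 3, 3]` output, and the new asserts in tfjs-core reject incompatible shapes up front with a descriptive error instead of failing deep inside a backend kernel.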