Fix CPU fusedConv2d (activation) for NCHW format #6400

Merged · 7 commits · May 12, 2022

Changes from all commits
31 changes: 25 additions & 6 deletions tfjs-backend-cpu/src/kernels/FusedConv2D.ts
@@ -48,27 +48,46 @@ export function fusedConv2D(args: {

   if (bias) {
     const resultOld = result;
+    // For NCHW format, if bias is a 1-D tensor, it is supposed to be aligned
+    // to the channel of the conv2d's result; if the bias is a scalar, the
+    // bias_add is computed as if the bias was broadcasted to the shape of the
+    // conv2d's result.
     if (dataFormat === 'NCHW' && bias.shape.length === 1 &&
         bias.shape[0] !== 1) {
-      // For NCHW format, if bias is a 1-D tensor, it is supposed to be aligned
-      // to the channel of the conv2d's result; if bias is a scalar, the
-      // bias_add is computed as if the bias was broadcasted to the shape of
-      // the conv2d's result.
       const reshapedBias = reshape(
           {inputs: {x: bias}, backend, attrs: {shape: [bias.shape[0], 1, 1]}});
       result =
           add({inputs: {a: result, b: reshapedBias}, backend}) as TensorInfo;
       backend.disposeIntermediateTensorInfo(reshapedBias);
     } else {
+      // This condition handles NHWC and NCHW (scalar case). The only other
+      // case for NCHW (1D case) is handled above.
       result = add({inputs: {a: result, b: bias}, backend}) as TensorInfo;
     }
     backend.disposeIntermediateTensorInfo(resultOld);
   }

   if (activation) {
     const resultOld = result;
-    result = applyActivation(
-        backend, result, activation, preluActivationWeights, leakyreluAlpha);
+    // For NCHW format, if the PReLU activation weights is a 1-D tensor, it is
+    // supposed to be aligned with the channel of the conv2d's result. For all
+    // other cases, whether the data format is NCHW or NHWC, the conv2d result
+    // is already aligned with the activation weights.
+    if (dataFormat === 'NCHW' && activation === 'prelu' &&
+        preluActivationWeights.shape.length === 1 &&
+        preluActivationWeights.shape[0] !== 1) {
+      const reshapedAlpha = reshape({
+        inputs: {x: preluActivationWeights},
+        backend,
+        attrs: {shape: [preluActivationWeights.shape[0], 1, 1]}
+      });
+      result = applyActivation(
+          backend, result, activation, reshapedAlpha, leakyreluAlpha);
+      backend.disposeIntermediateTensorInfo(reshapedAlpha);
+    } else {
+      result = applyActivation(
+          backend, result, activation, preluActivationWeights, leakyreluAlpha);
+    }
     backend.disposeIntermediateTensorInfo(resultOld);
   }

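To see why the reshape is needed, here is a minimal sketch of the broadcasting rule that the comments above describe, written against the public tfjs API; the shapes are illustrative and not taken from this PR.

import * as tf from '@tensorflow/tfjs';

// conv2d output in NCHW layout: [batch, channels, height, width].
const result = tf.ones([1, 3, 2, 2]);
// One bias value per output channel.
const bias = tf.tensor1d([0.1, 0.2, 0.3]);

// Broadcasting aligns trailing dimensions, so adding a [3] bias directly
// would try to match it against the width axis (2 vs. 3) and fail.
// Reshaping to [3, 1, 1] aligns the bias with the channel axis instead.
const biased = tf.add(result, tf.reshape(bias, [3, 1, 1]));
biased.print();  // each 2x2 channel plane is offset by its own bias value

The same reasoning applies to the 1-D PReLU activation weights handled in the activation branch.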
83 changes: 23 additions & 60 deletions tfjs-backend-webgl/src/kernels/Conv2D_impl.ts
@@ -15,7 +15,7 @@
* =============================================================================
*/

-import {backend_util, broadcast_util, TensorInfo, util} from '@tensorflow/tfjs-core';
+import {backend_util, TensorInfo, util} from '@tensorflow/tfjs-core';

-// import {assertAndGetBroadcastShape} from
-// '../../../tfjs-core/src/ops/broadcast_util';
@@ -41,6 +41,6 @@ type Conv2DConfig = {
activation?: backend_util.Activation
};

-function fitPreluActivationWeightsIntoNhwcFormat(
-    alpha: TensorInfo, outputShape: [number, number, number, number],
-    isChannelsLast: boolean, backend: MathBackendWebGL) {
-  // PReLU's activation weights could be a scalar, a 1-D tensor or a 3-D
-  // tensor.
-  const alphaShape = alpha.shape;
-  util.assert(
-      alphaShape.length <= 1 || alphaShape.length === 3,
-      () => `WebGL conv2d only supports scalar, 1-D Tensor or 3-D ` +
-          `Tensor PReLU activation weights but got a tensor of ` +
-          `rank-${alphaShape.length}.`);
-  if (alphaShape.length === 1) {
-    const outputChannels = isChannelsLast ? outputShape[3] : outputShape[1];
-    util.assert(
-        alphaShape[0] === 1 || alphaShape[0] === outputChannels,
-        () => `WebGL conv2d PReLU activation weights (${alphaShape}) is ` +
-            `not compatible with the number of output channels ` +
-            `(${outputChannels}).`);
-  } else if (alphaShape.length === 3) {
-    try {
-      broadcast_util.assertAndGetBroadcastShape(alphaShape, outputShape);
-    } catch (e) {
-      const errMsg = `WebGL conv2d PReLU activation weights (${alphaShape}) ` +
-          `is not compatible with the output shape of the conv2d ` +
-          `(${outputShape}).`;
-      throw Error(errMsg);
-    }
-    if (!isChannelsLast) {
-      // If PReLU's activation weights is NCHW format, then convert it to NHWC.
-      return transpose({inputs: {x: alpha}, backend, attrs: {perm: [1, 2, 0]}});
-    }
-  }
-  return alpha;
-}

// For 1x1 kernels that iterate through every point in the input, convolution
// can be expressed as matrix multiplication (without need for memory
// remapping).
@@ -103,18 +68,17 @@ export function conv2dByMatMul({
let out: TensorInfo;
const intermediates: TensorInfo[] = [];

-  if (preluActivationWeights != null) {
-    const preluActivationWeightsInNhwcFormat =
-        fitPreluActivationWeightsIntoNhwcFormat(
-            preluActivationWeights, convInfo.outShape, isChannelsLast, backend);
-
-    if (preluActivationWeightsInNhwcFormat.dataId !==
-        preluActivationWeights.dataId) {
-      // preluActivationWeightsInNhwcFormat is a new tensor, temporarily
-      // generated to be compatible with the following matmul computation.
-      intermediates.push(preluActivationWeightsInNhwcFormat);
-      preluActivationWeights = preluActivationWeightsInNhwcFormat;
-    }
+  if (preluActivationWeights != null && !isChannelsLast &&
+      preluActivationWeights.shape.length === 3) {
+    // If PReLU's activation weights is in NCHW format, convert it to NHWC for
+    // the following computation.
+    const preluActivationWeightsInNhwcFormat = transpose({
+      inputs: {x: preluActivationWeights},
+      backend,
+      attrs: {perm: [1, 2, 0]}
+    });
+    intermediates.push(preluActivationWeightsInNhwcFormat);
+    preluActivationWeights = preluActivationWeightsInNhwcFormat;
   }

// TODO: Once reduction ops are packed, batchMatMul will always be packed
@@ -287,18 +251,17 @@ export function conv2dWithIm2Row({

const intermediates: TensorInfo[] = [];

-  if (preluActivationWeights != null) {
-    const preluActivationWeightsInNhwcFormat =
-        fitPreluActivationWeightsIntoNhwcFormat(
-            preluActivationWeights, convInfo.outShape, isChannelsLast, backend);
-
-    if (preluActivationWeightsInNhwcFormat.dataId !==
-        preluActivationWeights.dataId) {
-      // preluActivationWeightsInNhwcFormat is a new tensor, temporarily
-      // generated to be compatible with the following matmul computation.
-      intermediates.push(preluActivationWeightsInNhwcFormat);
-      preluActivationWeights = preluActivationWeightsInNhwcFormat;
-    }
+  if (preluActivationWeights != null && !isChannelsLast &&
+      preluActivationWeights.shape.length === 3) {
+    // If PReLU's activation weights is in NCHW format, convert it to NHWC for
+    // the following computation.
+    const preluActivationWeightsInNhwcFormat = transpose({
+      inputs: {x: preluActivationWeights},
+      backend,
+      attrs: {perm: [1, 2, 0]}
+    });
+    intermediates.push(preluActivationWeightsInNhwcFormat);
+    preluActivationWeights = preluActivationWeightsInNhwcFormat;
   }

const xSqueezed =
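The unchanged comment above conv2dByMatMul notes that a 1x1 kernel lets convolution be computed as a matrix multiply. A minimal sketch of that equivalence against the public tfjs API, with illustrative shapes:

import * as tf from '@tensorflow/tfjs';

// A 1x1 kernel visits each input point exactly once, so the NHWC input can
// be flattened to [batch * height * width, inChannels] and multiplied by the
// filter viewed as an [inChannels, outChannels] matrix.
const x = tf.randomNormal([1, 4, 4, 8]) as tf.Tensor4D;        // NHWC input
const filter = tf.randomNormal([1, 1, 8, 16]) as tf.Tensor4D;  // 1x1 kernel

const viaConv = tf.conv2d(x, filter, 1, 'same');
const viaMatMul = tf.reshape(
    tf.matMul(tf.reshape(x, [16, 8]), tf.reshape(filter, [8, 16])),
    [1, 4, 4, 16]);
// Up to float rounding, viaConv and viaMatMul hold the same values.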
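Both call sites now inline the only NCHW-specific work left in this backend: transposing 3-D PReLU weights into NHWC order. A small sketch of what perm: [1, 2, 0] does, again with illustrative shapes:

import * as tf from '@tensorflow/tfjs';

// 3-D PReLU weights laid out channel-first: [channels, height, width].
const alphaChw = tf.ones([4, 2, 3]);

// perm [1, 2, 0] moves the channel axis to the end, producing
// [height, width, channels] to match the NHWC layout of the matmul path.
const alphaHwc = tf.transpose(alphaChw, [1, 2, 0]);
console.log(alphaHwc.shape);  // [2, 3, 4]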
33 changes: 33 additions & 0 deletions tfjs-core/src/ops/fused/conv2d.ts
@@ -193,6 +193,39 @@ function fusedConv2d_<T extends Tensor3D|Tensor4D>({

let $preluActivationWeights: Tensor;
if (preluActivationWeights != null) {
+    // PReLU's activation weights could be a scalar, a 1-D tensor or a 3-D
+    // tensor.
+    const alphaShape = preluActivationWeights.shape;
+    util.assert(
+        alphaShape.length <= 1 || alphaShape.length === 3,
+        () => `Error in fused conv2d: only supports scalar, 1-D Tensor or ` +
+            `3-D Tensor PReLU activation weights but got a tensor of ` +
+            `rank-${alphaShape.length}.`);
+
+    if (alphaShape.length === 1) {
+      // Whether the data format is NCHW or NHWC, the 1-D PReLU activation
+      // weights tensor should be aligned with the output channels of the
+      // conv2d result.
+      util.assert(
+          alphaShape[0] === 1 || alphaShape[0] === convInfo.outChannels,
+          () => `Error in fused conv2d: PReLU activation weights ` +
+              `(${alphaShape}) is not compatible with the number of output ` +
+              `channels (${convInfo.outChannels}).`);
+    } else if (alphaShape.length === 3) {
+      // Whether the data format is NCHW or NHWC, the PReLU activation weights
+      // tensor should have a shape that is compatible with the shape of the
+      // conv2d result.
+      try {
+        broadcast_util.assertAndGetBroadcastShape(
+            alphaShape, convInfo.outShape);
+      } catch (e) {
+        const errMsg =
+            `Error in fused conv2d: PReLU activation weights (${alphaShape}) ` +
+            `is not compatible with the output shape of the conv2d ` +
+            `(${convInfo.outShape}).`;
+        throw Error(errMsg);
+      }
+    }

$preluActivationWeights = convertToTensor(
preluActivationWeights, 'prelu weights', 'fused conv2d');
}
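With the validation moved into the core op, both backends can assume well-formed PReLU weights. A usage sketch of the op after this change; it assumes a backend that supports the NCHW data format for fused conv2d (the CPU backend, per this PR), and the shapes are illustrative:

import * as tf from '@tensorflow/tfjs';

const x = tf.randomNormal([1, 8, 5, 5]) as tf.Tensor4D;        // NCHW input
const filter = tf.randomNormal([3, 3, 8, 16]) as tf.Tensor4D;  // [h, w, in, out]
const alpha = tf.fill([16], 0.25);  // 1-D weights: one slope per out channel

// Passes the new 1-D check: alpha.shape[0] === convInfo.outChannels (16).
const y = tf.fused.conv2d({
  x,
  filter,
  strides: 1,
  pad: 'same',
  dataFormat: 'NCHW',
  activation: 'prelu',
  preluActivationWeights: alpha
});
console.log(y.shape);  // [1, 16, 5, 5]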
Expand Down