
Fix WebGL Conv2d for NCHW format #6340

Merged · 21 commits · Apr 28, 2022

Conversation

Linchenn
Collaborator

@Linchenn Linchenn commented Apr 20, 2022

Based on the findings from #6308 (comment), there are 3 branches to compute Conv2d on WebGL. The first and second branches are problematic for the NCHW format, so this PR does the following:

  1. The first branch does not consider the NCHW format, so this PR adds a condition to exclude NCHW inputs from it. Whether this branch should support the NCHW format deserves further research.
  2. The second branch is supposed to transpose the result of the matrix multiplication into NCHW format, instead of simply reshaping it. This PR fixes that.
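The reshape-vs-transpose distinction in item 2 can be sketched with plain index math. This is a minimal sketch in plain TypeScript with hypothetical helper names, not the tfjs implementation: the matmul result is laid out channels-last, so producing NCHW requires actually permuting elements, not just relabeling the shape.

```typescript
// Reinterpreting the flat channels-last buffer as NCHW moves no data (the bug).
function reshapeToNCHW(data: number[], h: number, w: number, c: number): number[] {
  return data.slice();  // same element order, only the shape label changes
}

// Correctly permute [H*W, C] (channels-last) into [C, H*W] (channels-first).
function transposeToNCHW(data: number[], h: number, w: number, c: number): number[] {
  const out = new Array<number>(data.length);
  for (let hw = 0; hw < h * w; hw++) {
    for (let ch = 0; ch < c; ch++) {
      out[ch * h * w + hw] = data[hw * c + ch];
    }
  }
  return out;
}

// 1x2 spatial grid, 2 channels; channels-last buffer is [p0c0, p0c1, p1c0, p1c1].
const fromMatMul = [1, 10, 2, 20];
console.log(reshapeToNCHW(fromMatMul, 1, 2, 2));    // [1, 10, 2, 20] -- channels still interleaved
console.log(transposeToNCHW(fromMatMul, 1, 2, 2));  // [1, 2, 10, 20] -- contiguous channel planes
```

The two calls return different buffers for the same input, which is exactly why a reshape cannot substitute for a transpose here.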

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.



@Linchenn Linchenn requested a review from pyu10055 April 20, 2022 23:30
@Linchenn Linchenn requested a review from qjia7 April 22, 2022 23:31
Contributor

@qjia7 qjia7 left a comment


@Linchenn Can you add a test to verify that the third path works correctly?

@@ -39,7 +39,8 @@ export function conv2d(
   if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
       convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
       convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
-      (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) {
+      (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID') &&
+      $dataFormat === 'channelsLast') {
Contributor

For the 1x1 conv2d branch:
if the data format is NHWC, it becomes xReshaped * filterReshaped with transposeA = false and transposeB = false.
If the data format is NCHW, it becomes filterReshaped * xReshaped with transposeA = true and transposeB = false.
So you can still go down the conv2dByMatMul path with some changes.
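The two matmul layouts above can be demonstrated with a tiny reference matmul. This is a sketch in plain TypeScript with hypothetical helpers, not the tfjs code:

```typescript
type Matrix = number[][];

// Naive matmul; when transposeA is true, a is transposed before multiplying.
function matmul(a: Matrix, b: Matrix, transposeA = false): Matrix {
  const A = transposeA ? a[0].map((_, i) => a.map(row => row[i])) : a;
  const out: Matrix = A.map(() => new Array<number>(b[0].length).fill(0));
  for (let i = 0; i < A.length; i++) {
    for (let k = 0; k < b.length; k++) {
      for (let j = 0; j < b[0].length; j++) {
        out[i][j] += A[i][k] * b[k][j];
      }
    }
  }
  return out;
}

// A 1x1 filter reshaped to [inChannels, outChannels]; inC = 2, outC = 3.
const filter = [[1, 2, 3], [4, 5, 6]];

// NHWC: xReshaped is [H*W, inC]; the product [H*W, outC] is already
// channels-last, so no output transpose is needed.
const xNHWC = [[1, 10], [2, 20]];     // 2 pixels, 2 input channels
const yNHWC = matmul(xNHWC, filter);  // [[41, 52, 63], [82, 104, 126]]

// NCHW: xReshaped is [inC, H*W]; filter^T * x gives [outC, H*W], which is
// already channels-first.
const xNCHW = [[1, 2], [10, 20]];     // same pixels, channels-first
const yNCHW = matmul(filter, xNCHW, /* transposeA= */ true);
// [[41, 82], [52, 104], [63, 126]] -- the transpose of yNHWC
```

Both paths compute the same logical result; only the operand order and the transposeA flag change with the data format.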

Collaborator Author

Thank you for your idea about the 1x1 conv2d branch! I just added some reshaping and transposing logic to support the NCHW format, plus some basic tests for the NCHW format and a test for the NCHW format with batch=2.

We already have a test of NCHW format for the third path.

Collaborator

@pyu10055 pyu10055 left a comment


Reviewable status: 0 of 1 approvals obtained (waiting on @Linchenn)


tfjs-backend-webgl/src/kernels/Conv2D_impl.ts line 261 at r1 (raw file):

  const inputs: TensorInfo[] = [im2ColReshaped, w2Row];
  if (bias) {
    inputs.push(bias);
  }

For NCHW conv2d, the bias is also in NCHW format; it will require a transpose if the matmul computation result is in NHWC format.

Collaborator Author

@Linchenn Linchenn left a comment


Reviewable status: 0 of 1 approvals obtained (waiting on @pyu10055 and @qjia7)


tfjs-backend-webgl/src/kernels/Conv2D_impl.ts line 261 at r1 (raw file):

Previously, pyu10055 (Ping Yu) wrote…

For NCHW conv2d, the bias is also in NCHW format; it will require a transpose if the matmul computation result is in NHWC format.

Referring to the bias_add op in Python that fusedConv2d involves, the bias is supposed to be a 1-D tensor aligned with the output channels.

I added asserts checking that the bias's shape is a 1-D tensor or a scalar (the scalar can be broadcast to a 1-D tensor), and added tests checking conv2d computation with bias.
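A minimal sketch of the kind of shape check described, in plain TypeScript. The helper name and error message are hypothetical; the actual asserts live in Conv2D_impl.ts.

```typescript
// Accept a scalar bias (broadcast to every channel) or a 1-D bias with one
// value per output channel; reject everything else.
function checkConv2dBiasShape(biasShape: number[], outChannels: number): void {
  const ok = biasShape.length === 0 ||  // scalar, broadcast to every channel
      (biasShape.length === 1 && biasShape[0] === outChannels);  // one per channel
  if (!ok) {
    throw new Error(
        `conv2d bias shape (${biasShape}) is not compatible with the ` +
        `number of output channels (${outChannels})`);
  }
}

checkConv2dBiasShape([], 4);   // scalar bias: ok
checkConv2dBiasShape([4], 4);  // 1-D bias matching outChannels: ok
// checkConv2dBiasShape([2, 4], 4);  // would throw: multi-dim bias rejected
```

Because only scalars and 1-D tensors pass, the data format cannot change how the bias is applied.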

Collaborator Author

@Linchenn Linchenn left a comment


Reviewable status: 0 of 1 approvals obtained (waiting on @pyu10055 and @qjia7)

a discussion (no related file):
In WebGL's Conv2d_impl.ts, both conv2dByMatMul and conv2dWithIm2Row take bias and preluActivationWeights as parameters.

bias is supposed to be a 1-D tensor, so the dataFormat does not affect it. I added some tests in this PR.

preluActivationWeights, as defined in fusedConv2d, may have a shape similar to the input's, which means it can have multiple dimensions and dataFormat may affect it, yet neither function considers the NCHW format for it, which could be problematic:

@param preluActivationWeights Tensor of prelu weights to be applied as part of a prelu activation, typically the same shape as x.

In conclusion, preluActivationWeights deserves more research to verify its correctness for the NCHW format in WebGL, but this is unrelated to this PR, so I opened a new issue to track it.
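Why prelu weights are layout-sensitive can be shown with a small sketch in plain TypeScript (hypothetical helpers, not the tfjs code): the same per-channel alpha vector lines up with the last axis of channels-last data but with axis 1 of channels-first data.

```typescript
// prelu(x) = x if x > 0, else alpha * x, with one alpha per channel.

function preluChannelsLast(
    x: number[], alpha: number[], channels: number): number[] {
  // x is a flat [..., channels] buffer: the channel index is i % channels.
  return x.map((v, i) => v > 0 ? v : alpha[i % channels] * v);
}

function preluChannelsFirst(
    x: number[], alpha: number[], spatialSize: number): number[] {
  // x is a flat [channels, spatialSize] buffer: the channel index is
  // floor(i / spatialSize), so alpha is indexed differently.
  return x.map((v, i) => v > 0 ? v : alpha[Math.floor(i / spatialSize)] * v);
}

const alpha = [0.1, 0.5];
// The same logical tensor (2 pixels, 2 channels) in its two layouts:
console.log(preluChannelsLast([-1, -1, 2, -2], alpha, 2));   // [-0.1, -0.5, 2, -1]
console.log(preluChannelsFirst([-1, 2, -1, -2], alpha, 2));  // [-0.1, 2, -0.5, -1]
```

Applying the channels-last indexing to a channels-first buffer would pair values with the wrong alpha, which is the kind of bug the tracking issue covers.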


@Linchenn Linchenn requested a review from pyu10055 April 28, 2022 21:15
Collaborator

@pyu10055 pyu10055 left a comment


Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @Linchenn and @qjia7)

a discussion (no related file):

Previously, Linchenn wrote… (the discussion above on bias and preluActivationWeights)

Thank you for expanding into bias and the activation function; to fully support the NCHW format, those are things we need to figure out. SG to track them in a separate issue.



tfjs-backend-webgl/src/kernels/Conv2D_impl.ts line 261 at r1 (raw file):

Previously, Linchenn wrote… (the reply above about bias_add and the 1-D bias asserts)

Thank you for investigating this; this is great information.


tfjs-backend-webgl/src/kernels/Conv2D_impl.ts line 78 at r10 (raw file):

        bias.shape.length === 0 || bias.shape[0] === convInfo.outChannels,
        () => `WebGL conv2dByMatMul bias shape (${bias.shape}) does not ` +
            `match the number of output channels (${convInfo.outChannels})`);

Suggest wording this as "is not compatible with the number of output channels", since the bias can be a scalar, which does not literally match the channel number.

Collaborator

@pyu10055 pyu10055 left a comment


Reviewable status: 0 of 1 approvals obtained (waiting on @Linchenn and @qjia7)


tfjs-backend-webgl/src/kernels/Conv2D_impl.ts line 258 at r10 (raw file):

Suggest wording this as "is not compatible with the number of output channels", since the bias can be a scalar, which does not literally match the channel number.

Collaborator

@pyu10055 pyu10055 left a comment


Reviewed 1 of 1 files at r3, 1 of 2 files at r4, 1 of 1 files at r6, 1 of 3 files at r8, 1 of 1 files at r9, 1 of 1 files at r10, all commit messages.
Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @Linchenn and @qjia7)

Collaborator Author

@Linchenn Linchenn left a comment


Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @pyu10055 and @qjia7)

a discussion (no related file):

Previously, pyu10055 (Ping Yu) wrote… (the reply above about tracking bias and activation support separately)

Thank you Ping!



tfjs-backend-webgl/src/kernels/Conv2D_impl.ts line 78 at r10 (raw file):

Previously, pyu10055 (Ping Yu) wrote… (the wording suggestion above)

Done.


tfjs-backend-webgl/src/kernels/Conv2D_impl.ts line 258 at r10 (raw file):

Previously, pyu10055 (Ping Yu) wrote… (the wording suggestion above)

Done.

@Linchenn Linchenn merged commit 996991e into tensorflow:master Apr 28, 2022
@Linchenn Linchenn deleted the fixConv2d branch April 28, 2022 22:40
3 participants