-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix WebGL Conv2d for NCHW format #6340
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Linchenn Can you add a test to verify the third path to see if it works correctly?
@@ -39,7 +39,8 @@ export function conv2d( | |||
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && | |||
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && | |||
convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && | |||
(convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) { | |||
(convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID') && | |||
$dataFormat === 'channelsLast') { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For 1x1 conv2d branch,
if the data format is NHWC, it becomes xReshaped
* filterReshaped
with transposeA
= false
and transposeB
= false
.
If the data format is NCHW, it becomes
filterReshaped
* xReshaped
with transposeA
= true
and transposeB
= false
.
So you still can go to the conv2dByMatMul
path with some changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your idea about the 1x1 conv2d branch! Just added some reshaping and transposing logics to support the NCHW
format. Also, added some basic tests for NCHW
format and a test for NCHW
format with batch=2.
We already have a test of NCHW
format for the third path.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: 0 of 1 approvals obtained (waiting on @Linchenn)
tfjs-backend-webgl/src/kernels/Conv2D_impl.ts
line 261 at r1 (raw file):
const inputs: TensorInfo[] = [im2ColReshaped, w2Row]; if (bias) { inputs.push(bias);
For NCHW conv2d, the bias is also in the NCHW format, it will require transpose if the matmul computation result is in NHWC format.
This reverts commit e13c41e.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: 0 of 1 approvals obtained (waiting on @pyu10055 and @qjia7)
tfjs-backend-webgl/src/kernels/Conv2D_impl.ts
line 261 at r1 (raw file):
Previously, pyu10055 (Ping Yu) wrote…
For NCHW conv2d, the bias is also in the NCHW format, it will require transpose if the matmul computation result is in NHWC format.
Referring to bias_add op in Python involved in fusedConv2d, the bias is supposed to be a 1-D tensor aligned with the output channels.
I added asserts to check the shape of the bias to be a 1-D tensor or a scalar (the scalar could be broadcasted to be a 1-D tensor) and added tests to check conv2d computation with bias.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: 0 of 1 approvals obtained (waiting on @pyu10055 and @qjia7)
a discussion (no related file):
In WebGL's Conv2d_impl.ts, both conv2dByMatMul and conv2dWithIm2Row take bias
and preluActivationWeights
as parameters.
bias
is supposed to be a 1-D tensor, so the dataFormat would not affect it. Added some tests in this PR.
preluActivationWeights
, as defined in fusedConv2d, the shape could be similar with the input, which means it could have multiple dimensions and dataFormat
may affect it, while the both functions do not consider NCHW
format for it, which could be problematic:
@param preluActivationWeights Tensor of prelu weights to be applied as part of a
prelu
activation, typically the same shape asx
.
In conclusion, preluActivationWeights
deserves more research to check the correctness for NCHW
format in WebGL, but this is unrelated to this PR, so I opened a new ISSUE to track.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status:
complete! 1 of 1 approvals obtained (waiting on @Linchenn and @qjia7)
a discussion (no related file):
Previously, Linchenn wrote…
In WebGL's Conv2d_impl.ts, both conv2dByMatMul and conv2dWithIm2Row take
bias
andpreluActivationWeights
as parameters.
bias
is supposed to be a 1-D tensor, so the dataFormat would not affect it. Added some tests in this PR.
preluActivationWeights
, as defined in fusedConv2d, the shape could be similar with the input, which means it could have multiple dimensions anddataFormat
may affect it, while the both functions do not considerNCHW
format for it, which could be problematic:@param preluActivationWeights Tensor of prelu weights to be applied as part of a
prelu
activation, typically the same shape asx
.In conclusion,
preluActivationWeights
deserves more research to check the correctness forNCHW
format in WebGL, but this is unrelated to this PR, so I opened a new ISSUE to track.
thank you for expanding into bias and activation function, to fully support NCHW format, those are things we need to figure. SG to track them in a separate issue.
tfjs-backend-webgl/src/kernels/Conv2D_impl.ts
line 261 at r1 (raw file):
Previously, Linchenn wrote…
Referring to bias_add op in Python involved in fusedConv2d, the bias is supposed to be a 1-D tensor aligned with the output channels.
I added asserts to check the shape of the bias to be a 1-D tensor or a scalar (the scalar could be broadcasted to be a 1-D tensor) and added tests to check conv2d computation with bias.
thank you for investigating on this, this is great information.
tfjs-backend-webgl/src/kernels/Conv2D_impl.ts
line 78 at r10 (raw file):
bias.shape.length === 0 || bias.shape[0] === convInfo.outChannels, () => `WebGL conv2dByMatMul bias shape (${bias.shape}) does not ` + `match the number of output channels (${convInfo.outChannels})`);
is not compatible with the number of output channels.
since the bias can be scalar, which does not match the channel number.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: 0 of 1 approvals obtained (waiting on @Linchenn and @qjia7)
tfjs-backend-webgl/src/kernels/Conv2D_impl.ts
line 258 at r10 (raw file):
is not compatible with the number of output channels.
since the bias can be scalar, which does not match the channel number.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 1 of 1 files at r3, 1 of 2 files at r4, 1 of 1 files at r6, 1 of 3 files at r8, 1 of 1 files at r9, 1 of 1 files at r10, all commit messages.
Reviewable status:complete! 1 of 1 approvals obtained (waiting on @Linchenn and @qjia7)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status:
complete! 1 of 1 approvals obtained (waiting on @pyu10055 and @qjia7)
a discussion (no related file):
Previously, pyu10055 (Ping Yu) wrote…
thank you for expanding into bias and activation function, to fully support NCHW format, those are things we need to figure. SG to track them in a separate issue.
Thank you Ping!
tfjs-backend-webgl/src/kernels/Conv2D_impl.ts
line 78 at r10 (raw file):
Previously, pyu10055 (Ping Yu) wrote…
is not compatible with the number of output channels.
since the bias can be scalar, which does not match the channel number.
Done.
tfjs-backend-webgl/src/kernels/Conv2D_impl.ts
line 258 at r10 (raw file):
Previously, pyu10055 (Ping Yu) wrote…
is not compatible with the number of output channels.
since the bias can be scalar, which does not match the channel number.
Done.
Based on the findings from #6308 (comment), there are 3 branches to compute Conv2d on WebGL. The first branch and the second branch are problematic for
NCHW
format, so this PR does the following:NCHW
format, so this PR adds a condition for it to exclude computation forNCHW
format. Whether this branch is supposed to supportNCHW
format deserves some further research.NCHW
format, instead of simply reshaping it. This PR fixes it.To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change is