-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support NCHW dataFormat for fusedConv2d (Core) #6373
Conversation
…nto nchw-fusedconv2d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: 0 of 1 approvals obtained (waiting on @pyu10055 and @qjia7)
tfjs-backend-webgpu/src/setup_test.ts
line 117 at r16 (raw file):
'backProp', // Conv2DBackpropFilter not yet // implemented 'basic in NCHW', // NCHW format is not yet supported
@qjia7 this PR is trying to enable fusedConv2d op from tfjs-core. So I moves the assert to disable NCHW from core to specific backends and let these backends skip the related tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Lin, I don't see data format check for fused conv2d in WASM and CPU? Has it been disabled already?
Reviewable status:
complete! 1 of 1 approvals obtained
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For WASM, fused conv2d has been blocked for NCHW format:
tfjs/tfjs-backend-wasm/src/kernels/FusedConv2D.ts
Lines 128 to 132 in 755bb49
if (dataFormat !== 'NHWC') { | |
throw new Error( | |
`wasm backend FusedConv2D does not support dataFormat:'` + | |
`${dataFormat}'. Please use 'NHWC'.`); | |
} |
For CPU, it currently supports fused conv2d NCHW. I have fixed the bias issue and the PReLU issue.
Thank you!
Reviewable status:
complete! 1 of 1 approvals obtained
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 2 of 2 files at r2, 1 of 2 files at r3, 1 of 1 files at r9, 1 of 1 files at r13, 2 of 3 files at r14, 2 of 3 files at r15, 1 of 1 files at r16, all commit messages.
Reviewable status:complete! 2 of 1 approvals obtained
#6368
Currently, some backends already supports fusedConv2d for
NCHW
format (CPU, WebGL and WebGPU), some want to move this assert from tfjs-core to the specific backends that do not support:Specifically, this PR:
1.1 Core: enables inference, while keeps gradient disabled.
1.2 CPU backend: disables the computation with bias or prelu activation, because cpu's fusedConv2d does not transpose the values properly for NCHW.
1.3 WebGPU, Node and WASM backends: disabled
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change is