-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[webgpu] Support convTranspose vec4 #6603
Conversation
7660b60
to
1a7c882
Compare
@qjia7 @xhcao @gyagp @haoyunfeix PTAL |
return 'return W[getIndexFromCoords4D(coord, uniforms.wShape)];'; | ||
case 4: | ||
return ` | ||
let coord1 = vec4<i32>(coordX, coordY, col + 1, rowInner); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you need to check whether col + 1|col + 2|col + 3
out of boundary?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure if this necessary, but in the updated commit, I do add this check, PTAL
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, it seem that it's not needed since you have make sure that convInfo.inChannels % 4 === 0
|
||
constructor(convInfo: backend_util.Conv2DInfo) { | ||
this.outputShape = convInfo.inShape; | ||
|
||
util.assert( | ||
convInfo.dataFormat === 'channelsLast', | ||
() => 'TODO: NCHW is unimplemented'); | ||
this.isChannelsLast = convInfo.dataFormat === 'channelsLast'; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Currently NCHW
is unimplemented. So this.isChannelsLast
can be ignored in this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
this.isChannelsLast = convInfo.dataFormat === 'channelsLast'; | ||
this.isVec4 = (((convInfo.inChannels % 4 === 0) && this.isChannelsLast) || | ||
(convInfo.outWidth % 4 === 0 && !this.isChannelsLast)) && | ||
convInfo.outChannels % 4 === 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const dimAOuter = convInfo.inShape[1] * convInfo.inShape[2];
const dimBOuter = convInfo.inShape[3];
const dimInner =
convInfo.filterHeight * convInfo.filterWidth * convInfo.outChannels;
So I suppose the vec4
condition should be convInfo.inChannels % 4 === 0 && convInfo.outChannels % 4 === 0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
} | ||
} | ||
|
||
${conv2dTransposeCommonSnippet(elementsSize[2])} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
${conv2dTransposeCommonSnippet(this.isVec4? 4 : 1)}
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
1a7c882
to
8e399a0
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
return 'return W[getIndexFromCoords4D(coord, uniforms.wShape)];'; | ||
case 4: | ||
return ` | ||
let coord1 = vec4<i32>(coordX, coordY, col + 1, rowInner); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, it seem that it's not needed since you have make sure that convInfo.inChannels % 4 === 0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Can you also paste the perf gain for this optimization?
Done |
PERF
hand_detector model on intel CFL:
hand_detector model on intel TGL:
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change isdata:image/s3,"s3://crabby-images/d0bb7/d0bb7f7625ca5bf5c3cf7a2b7a514cf841ab8395" alt="Reviewable"