[webgpu] Support convTranspose vec4 #6603

Merged
merged 3 commits into from
Jul 7, 2022

Conversation

@axinging axinging commented Jul 5, 2022

PERF

hand_detector model on Intel CFL:

| | float | vec4 |
| --- | --- | --- |
| Subsequent average (50 runs) | 55.1 ms | 34.0 ms |
| Best time | 52.8 ms | 32.3 ms |

hand_detector model on Intel TGL:

| | float | vec4 |
| --- | --- | --- |
| Subsequent average (50 runs) | 20.3 ms | 19.5 ms |
| Best time | 18.8 ms | 18.4 ms |
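
For reference, the relative gain implied by the timings above can be computed as a simple ratio of float time to vec4 time. This is only a sketch of the arithmetic; the `speedup` helper and the `timings` record are illustrative names and not part of the PR.

```typescript
// Illustrative helper (not from the PR): ratio of baseline time to optimized time.
function speedup(floatMs: number, vec4Ms: number): number {
  return floatMs / vec4Ms;
}

// Timings copied from the PR description, in milliseconds.
const timings = {
  cflAverage: { float: 55.1, vec4: 34.0 },
  cflBest: { float: 52.8, vec4: 32.3 },
  tglAverage: { float: 20.3, vec4: 19.5 },
  tglBest: { float: 18.8, vec4: 18.4 },
};

for (const [name, t] of Object.entries(timings)) {
  console.log(`${name}: ${speedup(t.float, t.vec4).toFixed(2)}x`);
}
```

On these numbers the gain is much larger on CFL than on TGL, where the two paths are within a few percent of each other.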

@axinging axinging force-pushed the conv2d_backprop_vec4 branch 4 times, most recently from 7660b60 to 1a7c882 Compare July 5, 2022 06:54
@axinging axinging marked this pull request as ready for review July 6, 2022 00:54
axinging commented Jul 6, 2022

@qjia7 @xhcao @gyagp @haoyunfeix PTAL

return 'return W[getIndexFromCoords4D(coord, uniforms.wShape)];';
case 4:
return `
let coord1 = vec4<i32>(coordX, coordY, col + 1, rowInner);
Contributor

Do you need to check whether col + 1, col + 2, or col + 3 is out of bounds?

Contributor Author

I am not sure this is necessary, but I have added the check in the updated commit. PTAL.

Contributor

OK, it seems the check is not needed, since you have already made sure that convInfo.inChannels % 4 === 0.
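
A minimal sketch of the reasoning in this thread, assuming that `col` starts at 0 and advances in steps of 4 along the channel dimension (an assumption about the shader's indexing, not something shown in the diff). Under that assumption, `col + 3` stays in range exactly when the channel count is a multiple of 4. The function name `vec4ReadsStayInBounds` is hypothetical.

```typescript
// Hypothetical check (not from the PR): simulates reading 4 consecutive
// channels starting at col = 0, 4, 8, ... and reports whether the last
// element of every group (col + 3) is a valid channel index.
function vec4ReadsStayInBounds(channels: number): boolean {
  for (let col = 0; col < channels; col += 4) {
    if (col + 3 >= channels) {
      return false; // col + 1, col + 2, or col + 3 would run past the end
    }
  }
  return true;
}

console.log(vec4ReadsStayInBounds(8)); // channel count divisible by 4
console.log(vec4ReadsStayInBounds(6)); // channel count not divisible by 4
```

When the count is not divisible by 4, the last group overruns, which is the case an explicit bounds check would have to cover.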


constructor(convInfo: backend_util.Conv2DInfo) {
this.outputShape = convInfo.inShape;

util.assert(
convInfo.dataFormat === 'channelsLast',
() => 'TODO: NCHW is unimplemented');
this.isChannelsLast = convInfo.dataFormat === 'channelsLast';
Contributor

nit: Currently NCHW is unimplemented. So this.isChannelsLast can be ignored in this PR.

Contributor Author

done

this.isChannelsLast = convInfo.dataFormat === 'channelsLast';
this.isVec4 = (((convInfo.inChannels % 4 === 0) && this.isChannelsLast) ||
(convInfo.outWidth % 4 === 0 && !this.isChannelsLast)) &&
convInfo.outChannels % 4 === 0;
Contributor

    const dimAOuter = convInfo.inShape[1] * convInfo.inShape[2];
    const dimBOuter = convInfo.inShape[3];
    const dimInner =
        convInfo.filterHeight * convInfo.filterWidth * convInfo.outChannels;

So I suppose the vec4 condition should be convInfo.inChannels % 4 === 0 && convInfo.outChannels % 4 === 0
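
A sketch of how the suggested condition relates to the matmul dimensions quoted above, assuming the channelsLast layout that this PR requires. `ConvInfoLike`, `matMulDims`, and `canUseVec4` are illustrative names, not the PR's code; the interface only mirrors the `Conv2DInfo` fields quoted in this thread.

```typescript
// Illustrative subset of the Conv2DInfo fields referenced in this thread.
interface ConvInfoLike {
  inShape: [number, number, number, number]; // [batch, height, width, inChannels]
  inChannels: number;
  outChannels: number;
  filterHeight: number;
  filterWidth: number;
}

// Matmul dimensions as quoted in the review comment.
function matMulDims(info: ConvInfoLike) {
  return {
    dimAOuter: info.inShape[1] * info.inShape[2],
    dimBOuter: info.inShape[3], // equals inChannels for channelsLast
    dimInner: info.filterHeight * info.filterWidth * info.outChannels,
  };
}

// Suggested vec4 condition: both channel counts divisible by 4.
function canUseVec4(info: ConvInfoLike): boolean {
  return info.inChannels % 4 === 0 && info.outChannels % 4 === 0;
}

const example: ConvInfoLike = {
  inShape: [1, 16, 16, 8],
  inChannels: 8,
  outChannels: 12,
  filterHeight: 3,
  filterWidth: 3,
};
console.log(matMulDims(example), canUseVec4(example));
```

With this condition, dimBOuter is a multiple of 4 because it equals inChannels, and dimInner is a multiple of 4 because outChannels is one of its factors.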

Contributor Author

done

}
}

${conv2dTransposeCommonSnippet(elementsSize[2])}
Contributor

${conv2dTransposeCommonSnippet(this.isVec4 ? 4 : 1)} ?

Contributor Author

done

@axinging axinging force-pushed the conv2d_backprop_vec4 branch from 1a7c882 to 8e399a0 Compare July 7, 2022 06:36
@qjia7 qjia7 left a comment

LGTM, thanks!

@gyagp gyagp left a comment

LGTM. Can you also paste the perf gain for this optimization?

axinging commented Jul 7, 2022

> LGTM. Can you also paste the perf gain for this optimization?

Done
