Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebGPU] Fix shader key for ScatterProgram #7932

Merged
merged 10 commits into from
Sep 25, 2023

Conversation

Linchenn
Copy link
Collaborator

@Linchenn Linchenn commented Aug 23, 2023

Issue

This PR fixes the following error because a TensorScatterUpdate op is trying to reuse a cached pipeline but it should not, because the size of their uniforms (the strides, the length of which is dynamic) are different.
image

Solution

We could also add such uniform information into TensorScatterUpdate kernel's shader key instead of all kernel's shader key.

Alternatives considered

We could also add uniform informations into shader key to fix it, but it's unnecessary to apply it to all shader keys since this uniform problem only happens, from my perspective, for scatter program.

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.

@Linchenn Linchenn requested review from qjia7 and pyu10055 August 23, 2023 23:06
@qjia7
Copy link
Contributor

qjia7 commented Aug 24, 2023

In theory, we could also add such uniform information into TensorScatterUpdate kernel's shader key instead of all kernel's shader key. However, the uniformData is passed directly into runWebGPUProgram, and is not passed to ScatterProgram constructor, so TensorScatterUpdate kernel's shader key could not know the size of uniforms at runtime and thus could not add such uniform information.

Does it work to change it as below?

this.shaderKey = `scatter_${indicesRank}_${updatesRank}_${
        this.sliceDimGreaterThanOne}_${outputDtype}_${sumDupeIndices}`;

to

this.shaderKey = `scatter_${indicesRank}_${updatesRank}_${
        this.sliceDimGreaterThanOne}_${outputDtype}_${sumDupeIndices}_${strides.length}`;

@Linchenn
Copy link
Collaborator Author

Does it work to change it as below?

Nice catch! It works and just updated the PR.

@Linchenn Linchenn changed the title [WebGPU] Add uniform info into shader key [WebGPU] Fix shader key for ScatterProgram Sep 19, 2023
Copy link
Contributor

@qjia7 qjia7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@pyu10055 pyu10055 enabled auto-merge (squash) September 25, 2023 22:41
@pyu10055 pyu10055 merged commit 73b2fd1 into tensorflow:master Sep 25, 2023
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants