Skip to content

Commit

Permalink
[WebGPU] Fix shader key for ScatterProgram (tensorflow#7932)
Browse files Browse the repository at this point in the history
* Add uniforms

* Update webgpu_program.ts

* Update webgpu_program.ts

* Revert "Add uniforms"

This reverts commit 645802f.

* Revert "Update webgpu_program.ts"

This reverts commit 58ea96f.

* Revert "Update webgpu_program.ts"

This reverts commit 32386ac.

* Add key to scatter webgpu program
  • Loading branch information
Linchenn authored Sep 25, 2023
1 parent a7bec12 commit 73b2fd1
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tfjs-backend-webgpu/src/scatter_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ export class ScatterProgram implements WebGPUProgram {
this.dispatch =
computeDispatch(this.dispatchLayout, flattenXShape, this.workgroupSize);
this.sliceDimGreaterThanOne = sliceDim > 1;
this.shaderKey = `scatter_${indicesRank}_${updatesRank}_${
this.sliceDimGreaterThanOne}_${outputDtype}_${sumDupeIndices}`;
this.shaderKey =
`scatter_${indicesRank}_${updatesRank}_${this.sliceDimGreaterThanOne}_${
outputDtype}_${sumDupeIndices}_${strides.length}`;
const stridesType = getCoordsDataType(strides.length);
this.uniforms =
`sliceDim : i32, strides: ${stridesType}, updatesSize: i32,`;
Expand Down

0 comments on commit 73b2fd1

Please sign in to comment.