diff --git a/tfjs-backend-webgpu/src/scatter_webgpu.ts b/tfjs-backend-webgpu/src/scatter_webgpu.ts index e866424a33a..22402030387 100644 --- a/tfjs-backend-webgpu/src/scatter_webgpu.ts +++ b/tfjs-backend-webgpu/src/scatter_webgpu.ts @@ -48,8 +48,9 @@ export class ScatterProgram implements WebGPUProgram { this.dispatch = computeDispatch(this.dispatchLayout, flattenXShape, this.workgroupSize); this.sliceDimGreaterThanOne = sliceDim > 1; - this.shaderKey = `scatter_${indicesRank}_${updatesRank}_${ - this.sliceDimGreaterThanOne}_${outputDtype}_${sumDupeIndices}`; + this.shaderKey = + `scatter_${indicesRank}_${updatesRank}_${this.sliceDimGreaterThanOne}_${ + outputDtype}_${sumDupeIndices}_${strides.length}`; const stridesType = getCoordsDataType(strides.length); this.uniforms = `sliceDim : i32, strides: ${stridesType}, updatesSize: i32,`;