From 8bae974f6a7b05ffeafe975c9486378062bfbb75 Mon Sep 17 00:00:00 2001 From: Yunfei Hao Date: Mon, 21 Feb 2022 11:15:36 +0800 Subject: [PATCH] address comments --- tfjs-backend-webgpu/src/gather_webgpu.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tfjs-backend-webgpu/src/gather_webgpu.ts b/tfjs-backend-webgpu/src/gather_webgpu.ts index 761e882d93d..26d0db130e1 100644 --- a/tfjs-backend-webgpu/src/gather_webgpu.ts +++ b/tfjs-backend-webgpu/src/gather_webgpu.ts @@ -37,7 +37,7 @@ export class GatherProgram implements WebGPUProgram { this.dispatchLayout = flatDispatchLayout(this.outputShape); this.dispatch = computeDispatch( this.dispatchLayout, this.outputShape, this.workGroupSize); - this.shaderKey = `gather_${aShape}`; + this.shaderKey = `gather`; } getUserCode(): string { @@ -47,7 +47,7 @@ export class GatherProgram implements WebGPUProgram { if (index < uniforms.size) { let resRC = getCoordsFromIndex(index); let indexZ = i32(getIndices(resRC.x, resRC.z)); - let inBounds = select(0.0, 1.0, indexZ >= 0 && indexZ < ${this.aShape[2]}); + let inBounds = select(0.0, 1.0, indexZ >= 0 && indexZ < uniforms.aShape[2]); setOutputAtIndex(index, inBounds * getA(${sourceCoords})); } }