Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
haoyunfeix committed Feb 21, 2022
1 parent a8809b8 commit 8bae974
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tfjs-backend-webgpu/src/gather_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export class GatherProgram implements WebGPUProgram {
this.dispatchLayout = flatDispatchLayout(this.outputShape);
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape, this.workGroupSize);
this.shaderKey = `gather_${aShape}`;
this.shaderKey = `gather`;
}

getUserCode(): string {
Expand All @@ -47,7 +47,7 @@ export class GatherProgram implements WebGPUProgram {
if (index < uniforms.size) {
let resRC = getCoordsFromIndex(index);
let indexZ = i32(getIndices(resRC.x, resRC.z));
let inBounds = select(0.0, 1.0, indexZ >= 0 && indexZ < ${this.aShape[2]});
let inBounds = select(0.0, 1.0, indexZ >= 0 && indexZ < uniforms.aShape[2]);
setOutputAtIndex(index, inBounds * getA(${sourceCoords}));
}
}
Expand Down

0 comments on commit 8bae974

Please sign in to comment.