Skip to content

Commit

Permalink
[webgpu] Use numbers directly instead of const variables (#7193)
Browse files Browse the repository at this point in the history
* [webgpu] Use numbers directly instead of `const` variables in the shader's global scope

FIXES BUG: #6746
Deno uses Naga for WGSL compilation, but Naga currently uses `let` for global constants (this will be fixed in gfx-rs/naga#1829).
This PR lets the WebGPU backend run pose-detection models on Deno by removing global constants from the shaders.
  • Loading branch information
haoyunfeix authored Dec 21, 2022
1 parent 0d88a30 commit fd3c3f4
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 119 deletions.
11 changes: 6 additions & 5 deletions tfjs-backend-webgpu/src/argminmax_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ export class ArgMinMaxProgram implements WebGPUProgram {
}

getUserCode(): string {
const workgroupSizeX = this.workgroupSize[0];
const getInputShapeLastDim = () => {
if (this.inputShape.length === 1) {
return 'uniforms.xShape';
Expand All @@ -91,8 +92,8 @@ export class ArgMinMaxProgram implements WebGPUProgram {

if (this.type === 'shared') {
const sharedMemorySnippet = `
var<workgroup> xBestIndices : array<i32, ${this.workgroupSize[0]}>;
var<workgroup> xBestValues : array<f32, ${this.workgroupSize[0]}>;
var<workgroup> xBestIndices : array<i32, ${workgroupSizeX}>;
var<workgroup> xBestValues : array<f32, ${workgroupSizeX}>;
`;
const userCode = `
fn DIV_CEIL(a : u32, b : u32) -> u32 {
Expand All @@ -102,14 +103,14 @@ export class ArgMinMaxProgram implements WebGPUProgram {
${sharedMemorySnippet}
${main('index')} {
let outputIndex = index / i32(workgroupSizeX);
let outputIndex = index / ${workgroupSizeX};
let reduceLength = ${getInputShapeLastDim()};
var bestIndex = i32(localId.x);
var bestValue = uniforms.infinityValue;
let outputCoords = getCoordsFromIndex(outputIndex);
for (var k = i32(localId.x); k < reduceLength && outputIndex < uniforms.size;
k = k + i32(workgroupSizeX)) {
k = k + ${workgroupSizeX}) {
let candidate = getX(${splitOutputCoords()} k);
if (!isnan(candidate) && candidate ${this.op} bestValue) {
bestValue = candidate;
Expand All @@ -120,7 +121,7 @@ export class ArgMinMaxProgram implements WebGPUProgram {
xBestIndices[localId.x] = bestIndex;
workgroupBarrier();
var reduceSize = min(u32(reduceLength), workgroupSizeX);
var reduceSize = min(u32(reduceLength), ${workgroupSizeX}u);
for (var currentSize = reduceSize / 2u; reduceSize > 1u;
currentSize = reduceSize / 2u) {
let interval = DIV_CEIL(reduceSize, 2u);
Expand Down
8 changes: 4 additions & 4 deletions tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ export class DepthwiseConv2DVec4Program implements WebGPUProgram {
getUserCode(): string {
const xNumber = (this.workPerThread - 1) * this.convInfo.strideWidth +
this.convInfo.filterWidth;
const strideHeight = this.convInfo.strideHeight;
const strideWidth = this.convInfo.strideWidth;

const userCode = `
${activationFnSnippet(this.activation, this.hasPreluActivation, true, 4)}
Expand All @@ -80,14 +82,12 @@ export class DepthwiseConv2DVec4Program implements WebGPUProgram {
return value;
}
const strideHeight = ${this.convInfo.strideHeight};
const strideWidth = ${this.convInfo.strideWidth};
${main()} {
let batch = i32(globalId.z) / uniforms.outShape[1];
let r = i32(globalId.z) % uniforms.outShape[1];
let c = i32(globalId.y) * ${this.workPerThread};
let d1 = i32(globalId.x) * 4;
let xRCCorner = vec2<i32>(r, c) * vec2<i32>(strideHeight, strideWidth) - uniforms.pad;
let xRCCorner = vec2<i32>(r, c) * vec2<i32>(${strideHeight}, ${strideWidth}) - uniforms.pad;
let xRCorner = xRCCorner.x;
let xCCorner = xRCCorner.y;
Expand All @@ -107,7 +107,7 @@ export class DepthwiseConv2DVec4Program implements WebGPUProgram {
for (var wC = 0; wC < ${this.convInfo.filterWidth}; wC = wC + 1) {
let wValue = getW(wR, wC, d1, 0);
for (var i = 0; i < ${this.workPerThread}; i++) {
dotProd[i] = fma(xVals[i * strideWidth + wC], wValue, dotProd[i]);
dotProd[i] = fma(xVals[i * ${strideWidth} + wC], wValue, dotProd[i]);
}
}
}
Expand Down
Loading

0 comments on commit fd3c3f4

Please sign in to comment.