diff --git a/tfjs-backend-webgpu/src/argminmax_webgpu.ts b/tfjs-backend-webgpu/src/argminmax_webgpu.ts index 187822e0f06..39e33c7fe1c 100644 --- a/tfjs-backend-webgpu/src/argminmax_webgpu.ts +++ b/tfjs-backend-webgpu/src/argminmax_webgpu.ts @@ -67,6 +67,7 @@ export class ArgMinMaxProgram implements WebGPUProgram { } getUserCode(): string { + const workgroupSizeX = this.workgroupSize[0]; const getInputShapeLastDim = () => { if (this.inputShape.length === 1) { return 'uniforms.xShape'; @@ -91,8 +92,8 @@ export class ArgMinMaxProgram implements WebGPUProgram { if (this.type === 'shared') { const sharedMemorySnippet = ` - var xBestIndices : array; - var xBestValues : array; + var xBestIndices : array; + var xBestValues : array; `; const userCode = ` fn DIV_CEIL(a : u32, b : u32) -> u32 { @@ -102,14 +103,14 @@ export class ArgMinMaxProgram implements WebGPUProgram { ${sharedMemorySnippet} ${main('index')} { - let outputIndex = index / i32(workgroupSizeX); + let outputIndex = index / ${workgroupSizeX}; let reduceLength = ${getInputShapeLastDim()}; var bestIndex = i32(localId.x); var bestValue = uniforms.infinityValue; let outputCoords = getCoordsFromIndex(outputIndex); for (var k = i32(localId.x); k < reduceLength && outputIndex < uniforms.size; - k = k + i32(workgroupSizeX)) { + k = k + ${workgroupSizeX}) { let candidate = getX(${splitOutputCoords()} k); if (!isnan(candidate) && candidate ${this.op} bestValue) { bestValue = candidate; @@ -120,7 +121,7 @@ export class ArgMinMaxProgram implements WebGPUProgram { xBestIndices[localId.x] = bestIndex; workgroupBarrier(); - var reduceSize = min(u32(reduceLength), workgroupSizeX); + var reduceSize = min(u32(reduceLength), ${workgroupSizeX}u); for (var currentSize = reduceSize / 2u; reduceSize > 1u; currentSize = reduceSize / 2u) { let interval = DIV_CEIL(reduceSize, 2u); diff --git a/tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts b/tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts index 35786b99e46..c92084b3322 100644 --- a/tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts +++ b/tfjs-backend-webgpu/src/depthwise_conv2d_vec4_webgpu.ts @@ -69,6 +69,8 @@ export class DepthwiseConv2DVec4Program implements WebGPUProgram { getUserCode(): string { const xNumber = (this.workPerThread - 1) * this.convInfo.strideWidth + this.convInfo.filterWidth; + const strideHeight = this.convInfo.strideHeight; + const strideWidth = this.convInfo.strideWidth; const userCode = ` ${activationFnSnippet(this.activation, this.hasPreluActivation, true, 4)} @@ -80,14 +82,12 @@ export class DepthwiseConv2DVec4Program implements WebGPUProgram { return value; } - const strideHeight = ${this.convInfo.strideHeight}; - const strideWidth = ${this.convInfo.strideWidth}; ${main()} { let batch = i32(globalId.z) / uniforms.outShape[1]; let r = i32(globalId.z) % uniforms.outShape[1]; let c = i32(globalId.y) * ${this.workPerThread}; let d1 = i32(globalId.x) * 4; - let xRCCorner = vec2(r, c) * vec2(strideHeight, strideWidth) - uniforms.pad; + let xRCCorner = vec2(r, c) * vec2(${strideHeight}, ${strideWidth}) - uniforms.pad; let xRCorner = xRCCorner.x; let xCCorner = xRCCorner.y; @@ -107,7 +107,7 @@ export class DepthwiseConv2DVec4Program implements WebGPUProgram { for (var wC = 0; wC < ${this.convInfo.filterWidth}; wC = wC + 1) { let wValue = getW(wR, wC, d1, 0); for (var i = 0; i < ${this.workPerThread}; i++) { - dotProd[i] = fma(xVals[i * strideWidth + wC], wValue, dotProd[i]); + dotProd[i] = fma(xVals[i * ${strideWidth} + wC], wValue, dotProd[i]); } } } diff --git a/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts b/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts index b2be27d81b9..6bd029cd50b 100644 --- a/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts +++ b/tfjs-backend-webgpu/src/matmul_packed_webgpu.ts @@ -89,35 +89,36 @@ export function matMulReadWriteFnSource( `; } -const writeDataToSubAVec4Snippet = (transpose: boolean) => { - if (transpose) { - return ` +const writeDataToSubAVec4Snippet = + (transpose: boolean, innerElementSize: number) => { + if (transpose) { + return ` mm_Asub[inputRow][inputCol] = mm_readA(batchA, kStart + inputRow, - globalRowStart / innerElementSize + inputCol); + globalRowStart / ${innerElementSize} + inputCol); `; - } else { - return ` + } else { + return ` mm_Asub[inputRow][inputCol] = mm_readA(batchA, globalRow + innerRow, - kStart / innerElementSize + inputCol); + kStart / ${innerElementSize} + inputCol); `; - } -}; + } + }; const calculateResultSnippet = - (transposeA: boolean, innerElementSize: number) => { + (transposeA: boolean, innerElementSize: number, rowPerThread: number) => { if (transposeA) { return ` - let ACached0 = mm_Asub[k * innerElementSize][localRow]; - let ACached1 = mm_Asub[k * innerElementSize + 1][localRow]; - let ACached2 = mm_Asub[k * innerElementSize + 2][localRow]; + let ACached0 = mm_Asub[k * ${innerElementSize}][localRow]; + let ACached1 = mm_Asub[k * ${innerElementSize} + 1][localRow]; + let ACached2 = mm_Asub[k * ${innerElementSize} + 2][localRow]; ${ - innerElementSize === 3 ? - '' : - 'let ACached3 = mm_Asub[k * innerElementSize + 3][localRow];'} - for (var i = 0; i < rowPerThread; i = i + 1) { + innerElementSize === 3 ? '' : + `let ACached3 = mm_Asub[k * ${ + innerElementSize} + 3][localRow];`} + for (var i = 0; i < ${rowPerThread}; i++) { acc[i] = BCached0 * ACached0[i] + acc[i]; acc[i] = BCached1 * ACached1[i] + acc[i]; acc[i] = BCached2 * ACached2[i] + acc[i]; @@ -128,7 +129,7 @@ const calculateResultSnippet = }`; } else { return ` - for (var i = 0; i < rowPerThread; i = i + 1) { + for (var i = 0; i < ${rowPerThread}; i++) { let ACached = mm_Asub[tileRow + i][k]; acc[i] = BCached0 * ACached.x + acc[i]; acc[i] = BCached1 * ACached.y + acc[i]; @@ -150,6 +151,7 @@ export function makeMatMulPackedVec4Source( const tileAHight = transposeA ? tileInner : tileAOuter; const innerElementSize = tileAWidth / workgroupSize[0]; const rowPerThreadB = tileInner / workgroupSize[1]; + const rowPerThread = workPerThread[1]; util.assert( ((transposeA && innerElementSize === 4 && workPerThread[1] === 4) || (!transposeA && (innerElementSize === 3 || innerElementSize === 4))) && @@ -168,17 +170,12 @@ export function makeMatMulPackedVec4Source( var mm_Bsub : array, ${ tileBOuter / workPerThread[0]}>, ${tileInner}>; - const rowPerThread = ${workPerThread[1]}; - const colPerThread = ${workPerThread[0]}; - const innerElementSize = ${innerElementSize}; - const tileInner = ${tileInner}; - ${main()} { let localRow = i32(localId.y); - let tileRow = ${isVectorA ? '0' : 'localRow * rowPerThread'}; + let tileRow = ${isVectorA ? '0' : `localRow * ${rowPerThread}`}; let tileCol = i32(localId.x); - let globalRow = ${isVectorA ? '0' : 'i32(globalId.y) * rowPerThread'}; + let globalRow = ${isVectorA ? '0' : `i32(globalId.y) * ${rowPerThread}`}; let globalCol = i32(globalId.x); let batch = ${splitK ? '0' : 'i32(globalId.z)'}; let batchA = ${ @@ -189,48 +186,48 @@ export function makeMatMulPackedVec4Source( let numTiles = ${ splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : - '(uniforms.dimInner - 1) / tileInner + 1'}; + `(uniforms.dimInner - 1) / ${tileInner} + 1`}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; - var acc: array, rowPerThread>; + var acc: array, ${rowPerThread}>; // Loop over shared dimension. let tileRowB = localRow * ${rowPerThreadB}; - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < numTiles; t++) { // Load one tile of A into local memory. - for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { + for (var innerRow = 0; innerRow < ${rowPerThread}; innerRow++) { let inputRow = tileRow + innerRow; let inputCol = tileCol; - ${writeDataToSubAVec4Snippet(transposeA)} + ${writeDataToSubAVec4Snippet(transposeA, innerElementSize)} } // Load one tile of B into local memory. - for (var innerRow = 0; innerRow < ${ - rowPerThreadB}; innerRow = innerRow + 1) { + for (var innerRow = 0; innerRow < ${rowPerThreadB}; innerRow++) { let inputRow = tileRowB + innerRow; let inputCol = tileCol; mm_Bsub[inputRow][inputCol] = mm_readB(batchB, kStart + inputRow, globalCol); } - kStart = kStart + tileInner; + kStart = kStart + ${tileInner}; workgroupBarrier(); // Compute acc values for a single thread. - for (var k = 0; k < tileInner / innerElementSize; k = k + 1) { - let BCached0 = mm_Bsub[k * innerElementSize][tileCol]; - let BCached1 = mm_Bsub[k * innerElementSize + 1][tileCol]; - let BCached2 = mm_Bsub[k * innerElementSize + 2][tileCol]; + for (var k = 0; k < ${tileInner / innerElementSize}; k++) { + let BCached0 = mm_Bsub[k * ${innerElementSize}][tileCol]; + let BCached1 = mm_Bsub[k * ${innerElementSize} + 1][tileCol]; + let BCached2 = mm_Bsub[k * ${innerElementSize} + 2][tileCol]; ${ innerElementSize === 3 ? '' : - 'let BCached3 = mm_Bsub[k * innerElementSize + 3][tileCol];'} + `let BCached3 = mm_Bsub[k * ${innerElementSize} + 3][tileCol];`} - ${calculateResultSnippet(transposeA, innerElementSize)} + ${ + calculateResultSnippet(transposeA, innerElementSize, rowPerThread)} } workgroupBarrier(); } - for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { + for (var innerRow = 0; innerRow < ${rowPerThread}; innerRow++) { mm_write(batch, globalRow + innerRow, globalCol, acc[innerRow]); } }`; @@ -281,6 +278,8 @@ export function makeMatMulPackedSource( const rowPerThreadA = tileAHight / workgroupSize[1]; const colPerThreadA = tileAWidth / workgroupSize[0]; const rowPerThreadB = tileInner / workgroupSize[1]; + const rowPerThread = workPerThread[1]; + const colPerThread = workPerThread[0]; const matmulSnippet = sequentialAccessByThreads ? ` let localRow = i32(localId.y); @@ -289,7 +288,7 @@ export function makeMatMulPackedSource( let globalColStart = i32(workgroupId.x) * ${tileBOuter}; // Loop over shared dimension. - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < numTiles; t++) { // Load one tile of A into local memory. for (var inputRow = localRow; inputRow < ${ tileAHight}; inputRow = inputRow + ${workgroupSize[1]}) { @@ -308,21 +307,22 @@ export function makeMatMulPackedSource( globalColStart + inputCol); } } - kStart = kStart + tileInner; + kStart = kStart + ${tileInner}; workgroupBarrier(); // Compute acc values for a single thread. - var BCached : array; - for (var k = 0; k < tileInner; k = k + 1) { - for (var inner = 0; inner < colPerThread; inner = inner + 1) { - BCached[inner] = mm_Bsub[k][localCol + inner * ${workgroupSize[0]}]; + var BCached : array; + for (var k = 0; k < ${tileInner}; k++) { + for (var inner = 0; inner < ${colPerThread}; inner++) { + BCached[inner] = mm_Bsub[k][localCol + inner * ${ + workgroupSize[0]}]; } - for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { + for (var innerRow = 0; innerRow < ${rowPerThread}; innerRow++) { let ACached = ${ transposeA ? `mm_Asub[k][localRow + innerRow * ${workgroupSize[1]}];` : `mm_Asub[localRow + innerRow * ${workgroupSize[1]}][k];`} - for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { + for (var innerCol = 0; innerCol < ${colPerThread}; innerCol++) { acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol]; } @@ -330,32 +330,31 @@ export function makeMatMulPackedSource( } workgroupBarrier(); } - for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { + for (var innerRow = 0; innerRow < ${rowPerThread}; innerRow++) { let gRow = globalRowStart + localRow + innerRow * ${workgroupSize[1]}; - for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { - let gCol = globalColStart + localCol + innerCol * ${workgroupSize[0]}; + for (var innerCol = 0; innerCol < ${colPerThread}; innerCol++) { + let gCol = globalColStart + localCol + innerCol * ${ + workgroupSize[0]}; mm_write(batch, gRow, gCol, acc[innerRow][innerCol]); } } ` : ` - let tileRow = i32(localId.y) * rowPerThread; - let tileCol = i32(localId.x) * colPerThread; + let tileRow = i32(localId.y) * ${rowPerThread}; + let tileCol = i32(localId.x) * ${colPerThread}; - let globalRow = i32(globalId.y) * rowPerThread; - let globalCol = i32(globalId.x) * colPerThread; + let globalRow = i32(globalId.y) * ${rowPerThread}; + let globalCol = i32(globalId.x) * ${colPerThread}; let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; let tileRowA = i32(localId.y) * ${rowPerThreadA}; let tileColA = i32(localId.x) * ${colPerThreadA}; let tileRowB = i32(localId.y) * ${rowPerThreadB}; // Loop over shared dimension. - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < numTiles; t++) { // Load one tile of A into local memory. - for (var innerRow = 0; innerRow < ${ - rowPerThreadA}; innerRow = innerRow + 1) { - for (var innerCol = 0; innerCol < ${ - colPerThreadA}; innerCol = innerCol + 1) { + for (var innerRow = 0; innerRow < ${rowPerThreadA}; innerRow++) { + for (var innerCol = 0; innerCol < ${colPerThreadA}; innerCol++) { let inputRow = tileRowA + innerRow; let inputCol = tileColA + innerCol; ${writeDataToSubASnippet(transposeA)} @@ -363,9 +362,8 @@ export function makeMatMulPackedSource( } // Load one tile of B into local memory. - for (var innerRow = 0; innerRow < ${ - rowPerThreadB}; innerRow = innerRow + 1) { - for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { + for (var innerRow = 0; innerRow < ${rowPerThreadB}; innerRow++) { + for (var innerCol = 0; innerCol < ${colPerThread}; innerCol++) { let inputRow = tileRowB + innerRow; let inputCol = tileCol + innerCol; mm_Bsub[inputRow][inputCol] = mm_readB(batchB, @@ -373,19 +371,19 @@ export function makeMatMulPackedSource( globalCol + innerCol); } } - kStart = kStart + tileInner; + kStart = kStart + ${tileInner}; workgroupBarrier(); // Compute acc values for a single thread. - var BCached : array; - for (var k = 0; k < tileInner; k = k + 1) { - for (var inner = 0; inner < colPerThread; inner = inner + 1) { + var BCached : array; + for (var k = 0; k < ${tileInner}; k++) { + for (var inner = 0; inner < ${colPerThread}; inner++) { BCached[inner] = mm_Bsub[k][tileCol + inner]; } - for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { + for (var innerRow = 0; innerRow < ${rowPerThread}; innerRow++) { ${readDataFromSubASnippet(transposeA)} - for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { + for (var innerCol = 0; innerCol < ${colPerThread}; innerCol++) { acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol]; } } @@ -394,8 +392,8 @@ export function makeMatMulPackedSource( workgroupBarrier(); } - for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { - for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { + for (var innerRow = 0; innerRow < ${rowPerThread}; innerRow++) { + for (var innerCol = 0; innerCol < ${colPerThread}; innerCol++) { mm_write(batch, globalRow + innerRow, globalCol + innerCol, acc[innerRow][innerCol]); } @@ -405,9 +403,6 @@ export function makeMatMulPackedSource( return ` var mm_Asub : array, ${tileAHight}>; var mm_Bsub : array, ${tileInner}>; - const rowPerThread = ${workPerThread[1]}; - const colPerThread = ${workPerThread[0]}; - const tileInner = ${tileInner}; ${main()} { let batch = ${splitK ? '0' : 'i32(globalId.z)'}; @@ -417,14 +412,14 @@ export function makeMatMulPackedSource( splitK || !broadcastBatch ? 'batch' : 'batch % uniforms.bShape[0]'}; let numTiles = ${ splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : - '(uniforms.dimInner - 1) / tileInner + 1'}; + `(uniforms.dimInner - 1) / ${tileInner} + 1`}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; - var acc : array, rowPerThread>; + var acc : array, ${rowPerThread}>; // Without this initialization strange values show up in acc. - for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { - for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) { + for (var innerRow = 0; innerRow < ${rowPerThread}; innerRow++) { + for (var innerCol = 0; innerCol < ${colPerThread}; innerCol++) { acc[innerRow][innerCol] = 0.0; } } @@ -453,8 +448,8 @@ export function makeVectorMatrixProductSource( util.assert( workgroupSize[1] === 1 && workgroupSize[2] === 1, () => `A linear work group size is required. But got ${workgroupSize}.`); + const tileSize = workgroupSize[0] * 4; return ` - const tileSize = ${workgroupSize[0] * 4}; var mm_Asub : array, ${workgroupSize[0]}>; ${main()} { @@ -462,7 +457,7 @@ export function makeVectorMatrixProductSource( let globalCol = i32(globalId.x); let globalRow = i32(globalId.y); - let numTiles = (uniforms.dimInner - 1) / tileSize + 1; + let numTiles = (uniforms.dimInner - 1) / ${tileSize} + 1; let batch = i32(globalId.z); let batchA = batch % uniforms.aShape[0]; let batchB = batch % uniforms.bShape[0]; @@ -470,15 +465,15 @@ export function makeVectorMatrixProductSource( var acc = 0.0; // Loop over shared dimension. - for (var t = 0; t < numTiles; t = t + 1) { + for (var t = 0; t < numTiles; t++) { // Load one tile of A into local memory. - let colA = t * tileSize + tileCol * 4; + let colA = t * ${tileSize} + tileCol * 4; mm_Asub[tileCol] = vec4(${readVectorASnippet(transposeA)}); workgroupBarrier(); // Compute acc values for a single thread. - for (var k = 0; k < tileSize / 4; k = k + 1) { - let rowB = t * tileSize + k * 4; + for (var k = 0; k < ${tileSize / 4}; k++) { + let rowB = t * ${tileSize} + k * 4; let BCached = vec4(mm_readB(batchB, rowB, globalCol), mm_readB(batchB, rowB + 1, globalCol), mm_readB(batchB, rowB + 2, globalCol), diff --git a/tfjs-backend-webgpu/src/matmul_reduce_webgpu.ts b/tfjs-backend-webgpu/src/matmul_reduce_webgpu.ts index e63adcae323..7392ef353fe 100644 --- a/tfjs-backend-webgpu/src/matmul_reduce_webgpu.ts +++ b/tfjs-backend-webgpu/src/matmul_reduce_webgpu.ts @@ -22,9 +22,9 @@ import {matMulReadWriteFnSource} from './matmul_packed_webgpu'; import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program'; import {computeDispatch} from './webgpu_util'; -export function makeMatMulReduceSource(): string { +export function makeMatMulReduceSource(workgroupSizeX: number): string { return ` - var sumValues : array; + var sumValues : array; ${main()} { let coords = getOutputCoords(); let batch = coords[0]; @@ -34,7 +34,7 @@ export function makeMatMulReduceSource(): string { let col = coords[2]; var sum = 0.0; let Length = uniforms.dimInner; - for (var k = i32(localId.x); k < Length; k = k + i32(workgroupSizeX)) { + for (var k = i32(localId.x); k < Length; k = k + ${workgroupSizeX}) { let dataA = mm_readA(batchA, row, k); let dataB = mm_readB(batchB, k, col); sum = sum + dataA * dataB; @@ -42,7 +42,7 @@ export function makeMatMulReduceSource(): string { sumValues[localId.x] = sum; workgroupBarrier(); - for(var currentSize = workgroupSizeX / 2u; currentSize > 1u; + for(var currentSize = ${workgroupSizeX / 2}u; currentSize > 1u; currentSize = currentSize / 2u) { if (localId.x < currentSize) { @@ -108,7 +108,7 @@ export class MatMulReduceProgram implements WebGPUProgram { ${ matMulReadWriteFnSource( this.addBias, this.activation, this.transposeA, this.transposeB)} - ${makeMatMulReduceSource()} + ${makeMatMulReduceSource(this.workgroupSize[0])} `; return userCode; } diff --git a/tfjs-backend-webgpu/src/reduce_webgpu.ts b/tfjs-backend-webgpu/src/reduce_webgpu.ts index d45b9098710..0acd26f2586 100644 --- a/tfjs-backend-webgpu/src/reduce_webgpu.ts +++ b/tfjs-backend-webgpu/src/reduce_webgpu.ts @@ -52,6 +52,7 @@ export class ReduceProgram implements WebGPUProgram { getUserCode(): string { let reduceOp = ``; let initValue = '0.0'; + const workgroupSizeX = this.workgroupSize[0]; if (this.reduceType === 'min' || this.reduceType === 'max') { reduceOp = ` if (isnan(candidate)) { @@ -79,7 +80,7 @@ export class ReduceProgram implements WebGPUProgram { `setOutputAtIndex(outputIndex, bestValue);`; const sharedMemorySnippet = ` - var xBestValues : array; + var xBestValues : array; `; const userCode = ` @@ -97,20 +98,20 @@ export class ReduceProgram implements WebGPUProgram { return offset; } ${main('index')} { - let outputIndex = index / i32(workgroupSizeX); + let outputIndex = index / ${workgroupSizeX}; let offset = getOffset(outputIndex); var bestValue = ${initValue}; let Length = uniforms.reduceSize; - let WorkPerThread = DIV_CEIL(u32(Length), workgroupSizeX); + let WorkPerThread = DIV_CEIL(u32(Length), ${workgroupSizeX}u); for (var k = i32(localId.x); k < Length && outputIndex < uniforms.size; - k = k + i32(workgroupSizeX)) { + k = k + ${workgroupSizeX}) { let candidate = f32(x[offset + k]); ${reduceOp} } xBestValues[localId.x] = bestValue; workgroupBarrier(); - var reduceSize = min(u32(Length), workgroupSizeX); + var reduceSize = min(u32(Length), ${workgroupSizeX}u); for (var currentSize = reduceSize / 2u; reduceSize > 1u; currentSize = reduceSize / 2u) { let interval = DIV_CEIL(reduceSize, 2u); diff --git a/tfjs-backend-webgpu/src/transpose_shared_webgpu.ts b/tfjs-backend-webgpu/src/transpose_shared_webgpu.ts index fa3492d6bca..26a92ad3af0 100644 --- a/tfjs-backend-webgpu/src/transpose_shared_webgpu.ts +++ b/tfjs-backend-webgpu/src/transpose_shared_webgpu.ts @@ -46,13 +46,13 @@ export class TransposeSharedProgram implements WebGPUProgram { this.workgroupSize[0] === this.workgroupSize[1], () => `Must be a square tile, current tile shape is ${ this.workgroupSize[0]} x ${this.workgroupSize[1]}`); + const tileSize = this.workgroupSize[0]; const userCode = ` - const tileSize = ${this.workgroupSize[0]}; var tile : array, ${ this.workgroupSize[0]}>; ${main()} { - var x = i32(workgroupId.x) * tileSize + i32(localId.x); - var y = i32(workgroupId.y) * tileSize + i32(localId.y); + var x = i32(workgroupId.x) * ${tileSize} + i32(localId.x); + var y = i32(workgroupId.y) * ${tileSize} + i32(localId.y); let width = uniforms.outShape[0]; let height = uniforms.outShape[1]; if (x < width && y < height) { @@ -60,8 +60,8 @@ export class TransposeSharedProgram implements WebGPUProgram { } workgroupBarrier(); - x = i32(workgroupId.y) * tileSize + i32(localId.x); - y = i32(workgroupId.x) * tileSize + i32(localId.y); + x = i32(workgroupId.y) * ${tileSize} + i32(localId.x); + y = i32(workgroupId.x) * ${tileSize} + i32(localId.y); if (x < height && y < width) { setOutputAtIndex((y * height + x), tile[localId.x] [localId.y]); diff --git a/tfjs-backend-webgpu/src/webgpu_program.ts b/tfjs-backend-webgpu/src/webgpu_program.ts index 02a391ff30a..4d093fa56bd 100644 --- a/tfjs-backend-webgpu/src/webgpu_program.ts +++ b/tfjs-backend-webgpu/src/webgpu_program.ts @@ -124,10 +124,11 @@ export function getMainHeaderString(...params: string[]): string { return snippet; } -export function getStartHeaderString(useGlobalIndex: boolean): string { +export function getStartHeaderString( + useGlobalIndex: boolean, program: WebGPUProgram): string { let snippet: string; snippet = ` - ${getWorkgroupSizeString()} + ${getWorkgroupSizeString(program)} fn _start(@builtin(local_invocation_id) LocalId : vec3, @builtin(global_invocation_id) GlobalId : vec3, @builtin(local_invocation_index) LocalIndex: u32, @@ -144,9 +145,10 @@ export function getStartHeaderString(useGlobalIndex: boolean): string { return snippet; } -export function getWorkgroupSizeString(): string { +export function getWorkgroupSizeString(program: WebGPUProgram): string { return ` - @compute @workgroup_size(workgroupSizeX, workgroupSizeY, workgroupSizeZ) + @compute @workgroup_size(${program.workgroupSize[0]}, ${ + program.workgroupSize[1]}, ${program.workgroupSize[2]}) `; } @@ -157,9 +159,6 @@ function makeShader( const flatWorkgroupSize = program.workgroupSize[0] * program.workgroupSize[1] * program.workgroupSize[2]; prefixSnippets.push(` - const workgroupSizeX = ${program.workgroupSize[0]}u; - const workgroupSizeY = ${program.workgroupSize[1]}u; - const workgroupSizeZ = ${program.workgroupSize[2]}u; var localId: vec3; var localIndex: u32; @@ -174,7 +173,7 @@ function makeShader( ` return i32(globalId.x);` : ` return i32((workgroupId.z * numWorkgroups.x * numWorkgroups.y + workgroupId.y * numWorkgroups.x + workgroupId.x) * ${ - flatWorkgroupSize} + + flatWorkgroupSize}u + localIndex); `} } @@ -198,7 +197,7 @@ function makeShader( prefixSnippets.join('\n'), getCoordsFromIndexSnippet(outputData.shape), program.getUserCode(), - getStartHeaderString(useGlobalIndex), + getStartHeaderString(useGlobalIndex, program), ].join('\n'); } @@ -258,7 +257,7 @@ function makeShader( getOutputCoordsSnippet(outputData.shape, program.dispatchLayout); const sources = [ - commonSnippet + isInfSnippet, prefixSnippets.join('\n'), + commonSnippet, prefixSnippets.join('\n') + isInfSnippet, getCoordsFromIndexSnippet(outputData.shape), coordsSnippet, getOutputIndexFromCoordsSnippet(outputData.shape.length) ]; @@ -280,7 +279,7 @@ function makeShader( sources.push(inputSnippet); sources.push(program.getUserCode()); const useGlobalIndex = isFlatDispatchLayout(program); - sources.push(getStartHeaderString(useGlobalIndex)); + sources.push(getStartHeaderString(useGlobalIndex, program)); const source = sources.join('\n'); return source; }