Skip to content

Commit

Permalink
webgpu: Add non-shared argminmax program (#6778)
Browse files Browse the repository at this point in the history
* webgpu: Add non-shared argminmax program

The performance of ArgMax[1, 1025, 2049, 19] in the Cityscapes architecture
of DeepLabV3 is very poor. With this change, the op's time drops from
22.36ms to 6.3ms.

* Add annotation

Co-authored-by: Linchenn <40653845+Linchenn@users.noreply.github.com>
  • Loading branch information
qjia7 and Linchenn authored Aug 23, 2022
1 parent e2b6a9d commit 25cd4f4
Showing 1 changed file with 49 additions and 18 deletions.
67 changes: 49 additions & 18 deletions tfjs-backend-webgpu/src/argminmax_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {backend_util} from '@tensorflow/tfjs-core';
import {backend_util, util} from '@tensorflow/tfjs-core';
import {getCoordsXYZ, getMainHeaderString as main, WebGPUProgram} from './webgpu_program';
import {computeDispatch, flatDispatchLayout} from './webgpu_util';

Expand All @@ -31,37 +31,42 @@ export class ArgMinMaxProgram implements WebGPUProgram {
reductionFactor: number;
op: string;
size = true;
private type: string;

constructor(inputShape: number[], axis: number, reduceType: 'min'|'max') {
const axes = [axis];
backend_util.assertAxesAreInnerMostDims(
'arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes,
inputShape.length);

this.op = reduceType === 'min' ? '<' : '>';

// |outputShape| is the input shape with the reduced axis removed.
const [outputShape] =
const [outputShape, reduceShape] =
backend_util.computeOutAndReduceShapes(inputShape, axes);

this.outputShape = outputShape.length === 0 ? [1] : outputShape;

this.dispatchLayout = flatDispatchLayout(this.outputShape);
// A workgroup outputs only one value, so we pass [1, 1, 1] when computing
// the dispatch size.
this.dispatch =
computeDispatch(this.dispatchLayout, this.outputShape, [1, 1, 1]);
// The shared algorithm is mainly used for large reduce sizes. It fully
// utilizes the threads in one workgroup to do the reduction. However,
// when the reduce size is very small or the output shape is too large, it's
// better to use the plain algorithm, which reduces the number of workgroups
// and speeds up the op. The thresholds can be further tuned.
if (util.sizeFromShape(reduceShape) < 32 ||
util.sizeFromShape(outputShape) > 1000) {
this.type = 'plain';
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape, this.workGroupSize);
} else {
this.type = 'shared';
// A workgroup outputs only one value, so we pass [1, 1, 1] when computing
// the dispatch size.
this.dispatch =
computeDispatch(this.dispatchLayout, this.outputShape, [1, 1, 1]);
}

this.inputShape = inputShape;
this.shaderKey = `argMinMax${this.op}`;
this.shaderKey = `argMinMax_${this.op}_${this.type}`;
}

getUserCode(): string {
const sharedMemorySnippet = `
var<workgroup> xBestIndices : array<i32, ${this.workGroupSize[0]}>;
var<workgroup> xBestValues : array<f32, ${this.workGroupSize[0]}>;
`;

const getInputShapeLastDim = () => {
if (this.inputShape.length === 1) {
return 'uniforms.xShape';
Expand All @@ -84,7 +89,12 @@ export class ArgMinMaxProgram implements WebGPUProgram {
return snippet;
};

const userCode = `
if (this.type === 'shared') {
const sharedMemorySnippet = `
var<workgroup> xBestIndices : array<i32, ${this.workGroupSize[0]}>;
var<workgroup> xBestValues : array<f32, ${this.workGroupSize[0]}>;
`;
const userCode = `
fn DIV_CEIL(a : u32, b : u32) -> u32 {
return ((a - 1u) / b + 1u);
}
Expand Down Expand Up @@ -131,6 +141,27 @@ export class ArgMinMaxProgram implements WebGPUProgram {
}
}
`;
return userCode;
return userCode;
} else {
const userCode = `
${main('index')} {
if (index < uniforms.size) {
let outputCoords = getCoordsFromIndex(index);
var bestIndex = 0;
var bestValue = getX(${splitOutputCoords()} 0);
let reduceLength = ${getInputShapeLastDim()};
for (var i = 1; i < reduceLength; i++) {
let candidate = getX(${splitOutputCoords()} i);
if (candidate ${this.op} bestValue) {
bestValue = candidate;
bestIndex = i;
}
}
setOutputAtIndexI32(index, bestIndex);
}
}
`;
return userCode;
}
}
}

0 comments on commit 25cd4f4

Please sign in to comment.