Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Webgpu gather zero #6166

Merged
merged 9 commits into from
Mar 28, 2022
15 changes: 0 additions & 15 deletions tfjs-backend-webgl/src/webgl_ops_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -873,21 +873,6 @@ describeWithFlags('depthwiseConv2d packed', PACKED_ENVS, () => {
});
});

describeWithFlags('gather', WEBGL_ENVS, () => {
it('fills with zero when index is out of bound', async () => {
const t = tf.tensor2d([1, 11, 2, 22], [2, 2]);
const tInt = tf.tensor2d([1, 11, 2, 22], [2, 2], 'int32');

const index = tf.tensor1d([0, 1, 100, -1, 2, -4], 'int32');
const res = tf.gather(t, index);
const resInt = tf.gather(tInt, index);

const expected = [1, 11, 2, 22, 0, 0, 0, 0, 0, 0, 0, 0];
expectArraysClose(await res.data(), expected);
expectArraysClose(await resInt.data(), expected);
});
});

describeWithFlags('gather debug', WEBGL_ENVS, () => {
it('throws if index is out of bound if in debug mode', async () => {
const t = tf.tensor2d([1, 11, 2, 22], [2, 2]);
Expand Down
10 changes: 6 additions & 4 deletions tfjs-backend-webgpu/src/gather_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ export class GatherProgram implements WebGPUProgram {
}

getUserCode(): string {
const sourceCoords = getSourceCoords(this.aShape, 'i32');
const sourceCoords = getSourceCoords(this.aShape);
const userCode = `
${getMainHeaderAndGlobalIndexString()}
if (index < uniforms.size) {
let resRC = getCoordsFromIndex(index);
setOutputAtIndex(index, getA(${sourceCoords}));
let indexZ = i32(getIndices(resRC.x, resRC.z));
let inBounds = select(0.0, 1.0, indexZ >= 0 && indexZ < uniforms.aShape[2]);
setOutputAtIndex(index, inBounds * getA(${sourceCoords}));
}
}
`;
Expand All @@ -55,12 +57,12 @@ export class GatherProgram implements WebGPUProgram {
}

// The input and output are always flattened into rank 4 tensors.
function getSourceCoords(aShape: number[], typePrefix = 'int'): string {
function getSourceCoords(aShape: number[]): string {
const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
const sourceCoords = [];
for (let i = 0; i < aShape.length; i++) {
if (i === 2) {
sourceCoords.push(`${typePrefix}(getIndices(resRC.x, resRC.z))`);
sourceCoords.push('indexZ');
} else {
sourceCoords.push(`${currentCoords[i]}`);
}
Expand Down
3 changes: 2 additions & 1 deletion tfjs-backend-webgpu/src/kernels/GatherV2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ export function gatherV2(
const {x, indices} = inputs;
const {axis, batchDims} = attrs;

// Throw error when any index is out of bound.
// Unlike WebGL, WebGPU won't check if index is out of bound by calling
// backend.readSync() function in debug mode.
const parsedAxis = util.parseAxisParam(axis, x.shape)[0];

const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo(
Expand Down
17 changes: 16 additions & 1 deletion tfjs-core/src/ops/gather_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import * as tf from '../index';
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
import {expectArraysClose} from '../test_util';

describeWithFlags('gather', ALL_ENVS, () => {
describeWithFlags('gather', ALL_ENVS, (env) => {
it('1D (gather), scalar indices', async () => {
const t = tf.tensor1d([1, 2, 3]);

Expand Down Expand Up @@ -579,4 +579,19 @@ describeWithFlags('gather', ALL_ENVS, () => {
expect(numTensorsAfter).toBe(numTensorsBefore);
expect(numDataIdAfter).toBe(numDataIdBefore);
});

it('fills with zero when index is out of bound', async () => {
if (env.backendName === 'webgl' || env.backendName === 'webgpu') {
const t = tf.tensor2d([1, 11, 2, 22], [2, 2]);
const tInt = tf.tensor2d([1, 11, 2, 22], [2, 2], 'int32');

const index = tf.tensor1d([0, 1, 100, -1, 2, -4], 'int32');
const res = tf.gather(t, index);
const resInt = tf.gather(tInt, index);

const expected = [1, 11, 2, 22, 0, 0, 0, 0, 0, 0, 0, 0];
expectArraysClose(await res.data(), expected);
expectArraysClose(await resInt.data(), expected);
}
});
});