From 6d2b160072d72d3830d38d2ba75cf0cb20f6589f Mon Sep 17 00:00:00 2001 From: Yunfei Hao Date: Mon, 14 Feb 2022 02:51:15 +0800 Subject: [PATCH 1/4] [webgpu] GatherV2 fill out of range values with zero --- tfjs-backend-webgpu/src/gather_webgpu.ts | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tfjs-backend-webgpu/src/gather_webgpu.ts b/tfjs-backend-webgpu/src/gather_webgpu.ts index a5ed7d33b66..b1a3e54dae8 100644 --- a/tfjs-backend-webgpu/src/gather_webgpu.ts +++ b/tfjs-backend-webgpu/src/gather_webgpu.ts @@ -37,16 +37,18 @@ export class GatherProgram implements WebGPUProgram { this.dispatchLayout = flatDispatchLayout(this.outputShape); this.dispatch = computeDispatch( this.dispatchLayout, this.outputShape, this.workGroupSize); - this.shaderKey = `gather`; + this.shaderKey = `gather_${aShape}_${outputShape}`; } getUserCode(): string { - const sourceCoords = getSourceCoords(this.aShape, 'i32'); + const sourceCoords = getSourceCoords(this.aShape); const userCode = ` ${getMainHeaderAndGlobalIndexString()} if (index < uniforms.size) { let resRC = getCoordsFromIndex(index); - setOutputAtIndex(index, getA(${sourceCoords})); + let indexZ = i32(getIndices(resRC.x, resRC.z)); + let inBounds = select(0.0, 1.0, indexZ >= 0 && indexZ < ${this.aShape[2]}); + setOutputAtIndex(index, inBounds * getA(${sourceCoords})); } } `; @@ -55,12 +57,12 @@ export class GatherProgram implements WebGPUProgram { } // The input and output are always flattened into rank 4 tensors. -function getSourceCoords(aShape: number[], typePrefix = 'int'): string { +function getSourceCoords(aShape: number[]): string { const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w']; const sourceCoords = []; for (let i = 0; i < aShape.length; i++) { if (i === 2) { - sourceCoords.push(`${typePrefix}(getIndices(resRC.x, resRC.z))`); + sourceCoords.push('indexZ'); } else { sourceCoords.push(`${currentCoords[i]}`); } From a8809b8c9bd08a98e40371d8a611214e5c7b68c9 Mon Sep 17 00:00:00 2001 From: Yunfei Hao Date: Fri, 18 Feb 2022 14:20:34 +0800 Subject: [PATCH 2/4] Add test case --- tfjs-backend-webgpu/src/gather_webgpu.ts | 2 +- tfjs-backend-webgpu/src/kernels/GatherV2.ts | 3 +- tfjs-backend-webgpu/src/webgpu_ops_test.ts | 37 +++++++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) create mode 100644 tfjs-backend-webgpu/src/webgpu_ops_test.ts diff --git a/tfjs-backend-webgpu/src/gather_webgpu.ts b/tfjs-backend-webgpu/src/gather_webgpu.ts index b1a3e54dae8..761e882d93d 100644 --- a/tfjs-backend-webgpu/src/gather_webgpu.ts +++ b/tfjs-backend-webgpu/src/gather_webgpu.ts @@ -37,7 +37,7 @@ export class GatherProgram implements WebGPUProgram { this.dispatchLayout = flatDispatchLayout(this.outputShape); this.dispatch = computeDispatch( this.dispatchLayout, this.outputShape, this.workGroupSize); - this.shaderKey = `gather_${aShape}_${outputShape}`; + this.shaderKey = `gather_${aShape}`; } getUserCode(): string { diff --git a/tfjs-backend-webgpu/src/kernels/GatherV2.ts b/tfjs-backend-webgpu/src/kernels/GatherV2.ts index bbcd4c68577..4356e854808 100644 --- a/tfjs-backend-webgpu/src/kernels/GatherV2.ts +++ b/tfjs-backend-webgpu/src/kernels/GatherV2.ts @@ -31,7 +31,8 @@ export function gatherV2( const {x, indices} = inputs; const {axis, batchDims} = attrs; - // Throw error when any index is out of bound. + // Unlike WebGL, WebGPU won't check if index is out of bound by calling + // backend.readSync() function in debug mode. const parsedAxis = util.parseAxisParam(axis, x.shape)[0]; const shapeInfo = backend_util.segment_util.collectGatherOpShapeInfo( diff --git a/tfjs-backend-webgpu/src/webgpu_ops_test.ts b/tfjs-backend-webgpu/src/webgpu_ops_test.ts new file mode 100644 index 00000000000..59ea9d46fa8 --- /dev/null +++ b/tfjs-backend-webgpu/src/webgpu_ops_test.ts @@ -0,0 +1,37 @@ +/** + * @license + * Copyright 2017 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {test_util} from '@tensorflow/tfjs-core'; +const expectArraysClose = test_util.expectArraysClose; +import * as tf from '@tensorflow/tfjs-core'; + +import {describeWebGPU} from './test_util'; + +describeWebGPU('gather', () => { + it('fills with zero when index is out of bound', async () => { + const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); + const tInt = tf.tensor2d([1, 11, 2, 22], [2, 2], 'int32'); + + const index = tf.tensor1d([0, 1, 100, -1, 2, -4], 'int32'); + const res = tf.gather(t, index); + const resInt = tf.gather(tInt, index); + + const expected = [1, 11, 2, 22, 0, 0, 0, 0, 0, 0, 0, 0]; + expectArraysClose(await res.data(), expected); + expectArraysClose(await resInt.data(), expected); + }); +}); From 8bae974f6a7b05ffeafe975c9486378062bfbb75 Mon Sep 17 00:00:00 2001 From: Yunfei Hao Date: Mon, 21 Feb 2022 11:15:36 +0800 Subject: [PATCH 3/4] address comments --- tfjs-backend-webgpu/src/gather_webgpu.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tfjs-backend-webgpu/src/gather_webgpu.ts b/tfjs-backend-webgpu/src/gather_webgpu.ts index 761e882d93d..26d0db130e1 100644 --- a/tfjs-backend-webgpu/src/gather_webgpu.ts +++ b/tfjs-backend-webgpu/src/gather_webgpu.ts @@ -37,7 +37,7 @@ export class GatherProgram implements WebGPUProgram { this.dispatchLayout = flatDispatchLayout(this.outputShape); this.dispatch = computeDispatch( this.dispatchLayout, this.outputShape, this.workGroupSize); - this.shaderKey = `gather_${aShape}`; + this.shaderKey = `gather`; } getUserCode(): string { @@ -47,7 +47,7 @@ export class GatherProgram implements WebGPUProgram { if (index < uniforms.size) { let resRC = getCoordsFromIndex(index); let indexZ = i32(getIndices(resRC.x, resRC.z)); - let inBounds = select(0.0, 1.0, indexZ >= 0 && indexZ < ${this.aShape[2]}); + let inBounds = select(0.0, 1.0, indexZ >= 0 && indexZ < uniforms.aShape[2]); setOutputAtIndex(index, inBounds * getA(${sourceCoords})); } } From b612dead62397371d2dd4e52991e407cdd7b2a81 Mon Sep 17 00:00:00 2001 From: Yunfei Hao Date: Thu, 24 Feb 2022 02:24:57 +0800 Subject: [PATCH 4/4] Move common case to tfjs-core tests --- tfjs-backend-webgl/src/webgl_ops_test.ts | 15 --------- tfjs-backend-webgpu/src/webgpu_ops_test.ts | 37 ---------------------- tfjs-core/src/ops/gather_test.ts | 17 +++++++++- 3 files changed, 16 insertions(+), 53 deletions(-) delete mode 100644 tfjs-backend-webgpu/src/webgpu_ops_test.ts diff --git a/tfjs-backend-webgl/src/webgl_ops_test.ts b/tfjs-backend-webgl/src/webgl_ops_test.ts index 5ca94f61824..236c3fadf45 100644 --- a/tfjs-backend-webgl/src/webgl_ops_test.ts +++ b/tfjs-backend-webgl/src/webgl_ops_test.ts @@ -873,21 +873,6 @@ describeWithFlags('depthwiseConv2d packed', PACKED_ENVS, () => { }); }); -describeWithFlags('gather', WEBGL_ENVS, () => { - it('fills with zero when index is out of bound', async () => { - const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); - const tInt = tf.tensor2d([1, 11, 2, 22], [2, 2], 'int32'); - - const index = tf.tensor1d([0, 1, 100, -1, 2, -4], 'int32'); - const res = tf.gather(t, index); - const resInt = tf.gather(tInt, index); - - const expected = [1, 11, 2, 22, 0, 0, 0, 0, 0, 0, 0, 0]; - expectArraysClose(await res.data(), expected); - expectArraysClose(await resInt.data(), expected); - }); -}); - describeWithFlags('gather debug', WEBGL_ENVS, () => { it('throws if index is out of bound if in debug mode', async () => { const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); diff --git a/tfjs-backend-webgpu/src/webgpu_ops_test.ts b/tfjs-backend-webgpu/src/webgpu_ops_test.ts deleted file mode 100644 index 59ea9d46fa8..00000000000 --- a/tfjs-backend-webgpu/src/webgpu_ops_test.ts +++ /dev/null @@ -1,37 +0,0 @@ -/** - * @license - * Copyright 2017 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {test_util} from '@tensorflow/tfjs-core'; -const expectArraysClose = test_util.expectArraysClose; -import * as tf from '@tensorflow/tfjs-core'; - -import {describeWebGPU} from './test_util'; - -describeWebGPU('gather', () => { - it('fills with zero when index is out of bound', async () => { - const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); - const tInt = tf.tensor2d([1, 11, 2, 22], [2, 2], 'int32'); - - const index = tf.tensor1d([0, 1, 100, -1, 2, -4], 'int32'); - const res = tf.gather(t, index); - const resInt = tf.gather(tInt, index); - - const expected = [1, 11, 2, 22, 0, 0, 0, 0, 0, 0, 0, 0]; - expectArraysClose(await res.data(), expected); - expectArraysClose(await resInt.data(), expected); - }); -}); diff --git a/tfjs-core/src/ops/gather_test.ts b/tfjs-core/src/ops/gather_test.ts index 5b2a6279866..98fdda13ea6 100644 --- a/tfjs-core/src/ops/gather_test.ts +++ b/tfjs-core/src/ops/gather_test.ts @@ -19,7 +19,7 @@ import * as tf from '../index'; import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose} from '../test_util'; -describeWithFlags('gather', ALL_ENVS, () => { +describeWithFlags('gather', ALL_ENVS, (env) => { it('1D (gather), scalar indices', async () => { const t = tf.tensor1d([1, 2, 3]); @@ -579,4 +579,19 @@ describeWithFlags('gather', ALL_ENVS, () => { expect(numTensorsAfter).toBe(numTensorsBefore); expect(numDataIdAfter).toBe(numDataIdBefore); }); + + it('fills with zero when index is out of bound', async () => { + if (env.backendName === 'webgl' || env.backendName === 'webgpu') { + const t = tf.tensor2d([1, 11, 2, 22], [2, 2]); + const tInt = tf.tensor2d([1, 11, 2, 22], [2, 2], 'int32'); + + const index = tf.tensor1d([0, 1, 100, -1, 2, -4], 'int32'); + const res = tf.gather(t, index); + const resInt = tf.gather(tInt, index); + + const expected = [1, 11, 2, 22, 0, 0, 0, 0, 0, 0, 0, 0]; + expectArraysClose(await res.data(), expected); + expectArraysClose(await resInt.data(), expected); + } + }); });