Skip to content

Commit

Permalink
[webgpu] Update INT_DIV
Browse files Browse the repository at this point in the history
  • Loading branch information
hujiajie committed Jul 13, 2023
1 parent e750c1b commit 4de937a
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 52 deletions.
44 changes: 11 additions & 33 deletions tfjs-backend-webgpu/src/binary_op_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ export enum BinaryOpType {
DIV,
ELU_DER,
EQUAL,
FLOOR_DIV,
GREATER,
GREATER_EQUAL,
INT_DIV,
LESS,
LESS_EQUAL,
LOGICAL_AND,
Expand Down Expand Up @@ -56,6 +56,13 @@ const EQUAL = `
let one = sign(b) * 0 + 1;
let resultTemp = select(zero, one, a == b);
`;
// Shader snippet for floor division of `a` by `b`; the result is left in
// `resultTemp`.
//
// Steps, as written in the snippet:
//   1. `remainder` is `a % b`; it is rounded only when both `a` and `b` are
//      whole numbers (round(x) == x) — presumably to strip floating-point
//      error from the remainder of integral operands; confirm.
//   2. `quotient` is `(a - remainder) / b`.
//   3. When the remainder's sign is the opposite of `b`'s sign, the quotient
//      is decremented by one, which moves it toward negative infinity. A zero
//      remainder has sign 0 and so never triggers the adjustment for b != 0.
//   4. The final value is rounded.
//
// NOTE(review): assumes WGSL `select(f, t, cond)` argument order and that `a`
// and `b` are f32 scalars or vectors supplied by the surrounding shader
// template — verify against the code that consumes this snippet.
const FLOOR_DIV = `
let remainder =
select(a % b, round(a % b), (round(a) == a) & (round(b) == b));
let quotient = (a - remainder) / b;
let resultTemp =
round(select(quotient, quotient - 1, sign(remainder) == -sign(b)));
`;
const GREATER = `
let zero = sign(a) * 0 + 0;
let one = sign(b) * 0 + 1;
Expand All @@ -66,36 +73,6 @@ const GREATER_EQUAL = `
let one = sign(b) * 0 + 1;
let resultTemp = select(zero, one, a >= b);
`;

// Scalar shader snippet for integer division of `a` by `b`.
//
// Both operands are rounded and converted to i32, then passed to the shader
// helper `idiv` together with `s`, the product of the operands' signs. The
// i32 result is converted back to f32 and returned directly from the
// snippet (it does not assign `resultTemp`).
//
// NOTE(review): because the inputs are rounded first, float inputs with a
// non-zero fractional part are not divided exactly. No zero-divisor guard is
// visible in this scalar variant — confirm how `ib == 0` is expected to
// behave.
const INT_DIV = `
let s = sign(a) * sign(b);
let ia = i32(round(a));
let ib = i32(round(b));
return f32(idiv(ia, ib, s));
`;
// vec4 shader snippet for integer division of `a` by `b`.
//
// Both operands are rounded and converted to vec4<i32>. Each of the four
// lanes is divided with the shader helper `idiv` only when that lane's
// divisor is non-zero; lanes with a zero divisor keep the initial value 0.
// `s` carries the per-lane product of the operands' signs and is forwarded
// to `idiv`. The i32 result is converted to vec4<f32> and returned directly
// from the snippet.
//
// NOTE(review): because the inputs are rounded first, float inputs with a
// non-zero fractional part are not divided exactly.
const INT_DIV_VEC4 = `
let ia = vec4<i32>(round(a));
let ib = vec4<i32>(round(b));
let cond = ib != vec4<i32>(0);
var resultTemp = vec4<i32>(0);
let s = sign(a) * sign(b);
// Windows (D3D) wants guaranteed non-zero int division at compile-time.
if (cond[0]) {
resultTemp[0] = idiv(ia[0], ib[0], s[0]);
}
if (cond[1]) {
resultTemp[1] = idiv(ia[1], ib[1], s[1]);
}
if (cond[2]) {
resultTemp[2] = idiv(ia[2], ib[2], s[2]);
}
if (cond[3]) {
resultTemp[3] = idiv(ia[3], ib[3], s[3]);
}
return vec4<f32>(resultTemp);
`;

const LESS = `
let zero = sign(a) * 0 + 0;
let one = sign(b) * 0 + 1;
Expand Down Expand Up @@ -265,14 +242,15 @@ export function getBinaryOpString(
case BinaryOpType.EQUAL:
doOpSnippet = EQUAL;
break;
case BinaryOpType.FLOOR_DIV:
doOpSnippet = FLOOR_DIV;
break;
case BinaryOpType.GREATER:
doOpSnippet = GREATER;
break;
case BinaryOpType.GREATER_EQUAL:
doOpSnippet = GREATER_EQUAL;
break;
case BinaryOpType.INT_DIV:
return useVec4 ? INT_DIV_VEC4 : INT_DIV;
case BinaryOpType.LESS:
doOpSnippet = LESS;
break;
Expand Down
8 changes: 5 additions & 3 deletions tfjs-backend-webgpu/src/kernels/FloorDiv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ import {BinaryOpType} from '../binary_op_util';
import {binaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';
import {floorDivImplCPU} from '../kernel_utils/shared';

// FloorDiv kernel function, built with `binaryKernelFunc` from the
// `BinaryOpType.INT_DIV` shader op and the `floorDivImplCPU` CPU
// implementation, with dtype 'int32'.
// NOTE(review): how `binaryKernelFunc` applies `cpuKernelImpl` and `dtype`
// is not visible here — confirm against kernel_funcs_utils.
export const floorDiv =
binaryKernelFunc({opType: BinaryOpType.INT_DIV,
cpuKernelImpl: floorDivImplCPU, dtype: 'int32'});
// FloorDiv kernel function, built with `binaryKernelFunc` from the
// `BinaryOpType.FLOOR_DIV` shader op and the `floorDivImplCPU` CPU
// implementation, with dtype 'int32'.
// NOTE(review): how `binaryKernelFunc` applies `cpuKernelImpl` and `dtype`
// (in particular for float32 inputs) is not visible here — confirm against
// kernel_funcs_utils.
export const floorDiv = binaryKernelFunc({
opType: BinaryOpType.FLOOR_DIV,
cpuKernelImpl: floorDivImplCPU,
dtype: 'int32'
});

export const floorDivConfig: KernelConfig = {
kernelName: FloorDiv,
Expand Down
7 changes: 0 additions & 7 deletions tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,6 @@ const TEST_FILTERS: TestFilter[] = [
'indices invalid',
],
},
{
startsWith: 'floorDiv ',
excludes: [
// float32 inputs with nonzero fractional part should not be rounded
'floorDiv float32',
],
},

// exclude unsupported kernels and to be fixed cases
{
Expand Down
9 changes: 0 additions & 9 deletions tfjs-backend-webgpu/src/webgpu_program.ts
Original file line number Diff line number Diff line change
Expand Up @@ -417,15 +417,6 @@ const commonSnippet = `
return coords.x*shapeStrides.x + coords.y*shapeStrides.y + coords.z*shapeStrides.z + coords.w*shapeStrides.w + coords.u*shapeStrides.u + coords.v*shapeStrides.v;
}
fn idiv(a: i32, b: i32, sign: f32) -> i32 {
var res: i32 = a / b;
let modulo: i32 = a % b;
if (sign < 0. && modulo != 0) {
res = res - 1;
}
return res;
}
// NaN defination in IEEE 754-1985 is :
// - sign = either 0 or 1.
// - biased exponent = all 1 bits.
Expand Down

0 comments on commit 4de937a

Please sign in to comment.