Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 11 additions & 33 deletions tfjs-backend-webgpu/src/binary_op_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ export enum BinaryOpType {
DIV,
ELU_DER,
EQUAL,
FLOOR_DIV,
GREATER,
GREATER_EQUAL,
INT_DIV,
LESS,
LESS_EQUAL,
LOGICAL_AND,
Expand Down Expand Up @@ -56,6 +56,13 @@ const EQUAL = `
let one = sign(b) * 0 + 1;
let resultTemp = select(zero, one, a == b);
`;
const FLOOR_DIV = `
let remainder =
select(a % b, round(a % b), (round(a) == a) & (round(b) == b));
let quotient = (a - remainder) / b;
let resultTemp =
round(select(quotient, quotient - 1, sign(remainder) == -sign(b)));
`;
const GREATER = `
let zero = sign(a) * 0 + 0;
let one = sign(b) * 0 + 1;
Expand All @@ -66,36 +73,6 @@ const GREATER_EQUAL = `
let one = sign(b) * 0 + 1;
let resultTemp = select(zero, one, a >= b);
`;

const INT_DIV = `
let s = sign(a) * sign(b);
let ia = i32(round(a));
let ib = i32(round(b));
return f32(idiv(ia, ib, s));
`;
const INT_DIV_VEC4 = `
let ia = vec4<i32>(round(a));
let ib = vec4<i32>(round(b));
let cond = ib != vec4<i32>(0);
var resultTemp = vec4<i32>(0);
let s = sign(a) * sign(b);

// Windows (D3D) wants guaranteed non-zero int division at compile-time.
if (cond[0]) {
resultTemp[0] = idiv(ia[0], ib[0], s[0]);
}
if (cond[1]) {
resultTemp[1] = idiv(ia[1], ib[1], s[1]);
}
if (cond[2]) {
resultTemp[2] = idiv(ia[2], ib[2], s[2]);
}
if (cond[3]) {
resultTemp[3] = idiv(ia[3], ib[3], s[3]);
}
return vec4<f32>(resultTemp);
`;

const LESS = `
let zero = sign(a) * 0 + 0;
let one = sign(b) * 0 + 1;
Expand Down Expand Up @@ -265,14 +242,15 @@ export function getBinaryOpString(
case BinaryOpType.EQUAL:
doOpSnippet = EQUAL;
break;
case BinaryOpType.FLOOR_DIV:
doOpSnippet = FLOOR_DIV;
break;
case BinaryOpType.GREATER:
doOpSnippet = GREATER;
break;
case BinaryOpType.GREATER_EQUAL:
doOpSnippet = GREATER_EQUAL;
break;
case BinaryOpType.INT_DIV:
return useVec4 ? INT_DIV_VEC4 : INT_DIV;
case BinaryOpType.LESS:
doOpSnippet = LESS;
break;
Expand Down
8 changes: 5 additions & 3 deletions tfjs-backend-webgpu/src/kernels/FloorDiv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ import {BinaryOpType} from '../binary_op_util';
import {binaryKernelFunc} from '../kernel_utils/kernel_funcs_utils';
import {floorDivImplCPU} from '../kernel_utils/shared';

export const floorDiv =
binaryKernelFunc({opType: BinaryOpType.INT_DIV,
cpuKernelImpl: floorDivImplCPU, dtype: 'int32'});
export const floorDiv = binaryKernelFunc({
opType: BinaryOpType.FLOOR_DIV,
cpuKernelImpl: floorDivImplCPU,
dtype: 'int32'

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it state as 'int32'? This is the last thing that confuse me....

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the result dtype.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, still LGTM.

});

export const floorDivConfig: KernelConfig = {
kernelName: FloorDiv,
Expand Down
7 changes: 0 additions & 7 deletions tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,6 @@ const TEST_FILTERS: TestFilter[] = [
'indices invalid',
],
},
{
startsWith: 'floorDiv ',
excludes: [
// float32 inputs with nonzero fractional part should not be rounded
'floorDiv float32',
],
},

// exclude unsupported kernels and to be fixed cases
{
Expand Down
9 changes: 0 additions & 9 deletions tfjs-backend-webgpu/src/webgpu_program.ts
Original file line number Diff line number Diff line change
Expand Up @@ -417,15 +417,6 @@ const commonSnippet = `
return coords.x*shapeStrides.x + coords.y*shapeStrides.y + coords.z*shapeStrides.z + coords.w*shapeStrides.w + coords.u*shapeStrides.u + coords.v*shapeStrides.v;
}

fn idiv(a: i32, b: i32, sign: f32) -> i32 {
var res: i32 = a / b;
let modulo: i32 = a % b;
if (sign < 0. && modulo != 0) {
res = res - 1;
}
return res;
}

// NaN defination in IEEE 754-1985 is :
// - sign = either 0 or 1.
// - biased exponent = all 1 bits.
Expand Down