Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[webgpu] Update EQUAL,GREATER,GREATER_EQUAL,LESS,LESS_EQUAL #7751

Merged
merged 3 commits into from
Jun 27, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 35 additions & 15 deletions tfjs-backend-webgpu/src/binary_op_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,21 @@ const COMPLEX_MULTIPLY_REAL = 'let resultTemp = areal * breal - aimag * bimag;';
const COMPLEX_MULTIPLY_IMAG = 'let resultTemp = areal * bimag + aimag * breal;';
const DIV = 'let resultTemp = a / b;';
const ELU_DER = 'let resultTemp = select(a * (b + 1.0), a, b >= b - b);';
const EQUAL = 'return f32(a == b);';
const EQUAL_VEC4 = 'return vec4<f32>(a == b);';
const GREATER = 'return f32(a > b);';
const GREATER_VEC4 = 'return vec4<f32>(a > b);';
const GREATER_EQUAL = 'return f32(a >= b);';
const GREATER_EQUAL_VEC4 = 'return vec4<f32>(a >= b);';
const EQUAL = `
let zero = sign(a) * 0 + 0;
let one = sign(b) * 0 + 1;
let resultTemp = select(zero, one, a == b);
`;
const GREATER = `
let zero = sign(a) * 0 + 0;
let one = sign(b) * 0 + 1;
let resultTemp = select(zero, one, a > b);
`;
const GREATER_EQUAL = `
let zero = sign(a) * 0 + 0;
let one = sign(b) * 0 + 1;
let resultTemp = select(zero, one, a >= b);
`;

const INT_DIV = `
let s = sign(a) * sign(b);
Expand Down Expand Up @@ -87,10 +96,16 @@ const INT_DIV_VEC4 = `
return vec4<f32>(resultTemp);
`;

const LESS = 'return f32(a < b);';
const LESS_VEC4 = 'return vec4<f32>(a < b);';
const LESS_EQUAL = 'return f32(a <= b);';
const LESS_EQUAL_VEC4 = 'return vec4<f32>(a <= b);';
const LESS = `
let zero = sign(a) * 0 + 0;
let one = sign(b) * 0 + 1;
let resultTemp = select(zero, one, a < b);
`;
const LESS_EQUAL = `
let zero = sign(a) * 0 + 0;
let one = sign(b) * 0 + 1;
let resultTemp = select(zero, one, a <= b);
`;
const LOGICAL_AND = 'return f32(a >= 1.0 && b >= 1.0);';
const LOGICAL_AND_VEC4 = `return (vec4<f32>(a >= vec4<f32>(1.0)) *
vec4<f32>(b >= vec4<f32>(1.0)));`;
Expand Down Expand Up @@ -248,17 +263,22 @@ export function getBinaryOpString(
doOpSnippet = ELU_DER;
break;
case BinaryOpType.EQUAL:
return useVec4 ? EQUAL_VEC4 : EQUAL;
doOpSnippet = EQUAL;
break;
case BinaryOpType.GREATER:
return useVec4 ? GREATER_VEC4 : GREATER;
doOpSnippet = GREATER;
break;
case BinaryOpType.GREATER_EQUAL:
return useVec4 ? GREATER_EQUAL_VEC4 : GREATER_EQUAL;
doOpSnippet = GREATER_EQUAL;
break;
case BinaryOpType.INT_DIV:
return useVec4 ? INT_DIV_VEC4 : INT_DIV;
case BinaryOpType.LESS:
return useVec4 ? LESS_VEC4 : LESS;
doOpSnippet = LESS;
break;
case BinaryOpType.LESS_EQUAL:
return useVec4 ? LESS_EQUAL_VEC4 : LESS_EQUAL;
doOpSnippet = LESS_EQUAL;
break;
case BinaryOpType.LOGICAL_AND:
return useVec4 ? LOGICAL_AND_VEC4 : LOGICAL_AND;
case BinaryOpType.LOGICAL_OR:
Expand Down